├── README.md
├── assets
├── .DS_Store
├── fig_3.png
├── fig_5.png
├── restart.png
└── vis.png
├── benchmarks
├── .DS_Store
├── README.md
├── dnnlib
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── util.cpython-38.pyc
│ │ └── util.cpython-39.pyc
│ └── util.py
├── fid.py
├── generate_dpm_solver.py
├── generate_restart.py
├── hyperparam.py
├── params_cifar10_edm.txt
├── params_cifar10_pfgmpp.txt
├── params_cifar10_vp.txt
├── params_imagenet_edm.txt
├── run.sh
├── run_fid.sh
├── run_restart.sh
├── torch_utils
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── distributed.cpython-38.pyc
│ │ ├── distributed.cpython-39.pyc
│ │ ├── misc.cpython-38.pyc
│ │ ├── misc.cpython-39.pyc
│ │ ├── persistence.cpython-39.pyc
│ │ ├── training_stats.cpython-38.pyc
│ │ └── training_stats.cpython-39.pyc
│ ├── distributed.py
│ ├── misc.py
│ ├── persistence.py
│ └── training_stats.py
└── training
│ ├── __init__.py
│ ├── augment.py
│ ├── dataset.py
│ ├── loss.py
│ ├── networks.py
│ └── training_loop.py
└── diffuser
├── .DS_Store
├── .idea
├── .gitignore
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── src.iml
└── vcs.xml
├── README.md
├── aesthetic_score.py
├── clip_score.py
├── coco_data_loader.py
├── data_process.py
├── diffusers
├── .DS_Store
├── __init__.py
├── commands
│ ├── __init__.py
│ ├── diffusers_cli.py
│ └── env.py
├── configuration_utils.py
├── dependency_versions_check.py
├── dependency_versions_table.py
├── experimental
│ ├── README.md
│ ├── __init__.py
│ └── rl
│ │ ├── __init__.py
│ │ └── value_guided_sampling.py
├── image_processor.py
├── loaders.py
├── models
│ ├── README.md
│ ├── __init__.py
│ ├── attention.py
│ ├── attention_flax.py
│ ├── attention_processor.py
│ ├── autoencoder_kl.py
│ ├── controlnet.py
│ ├── controlnet_flax.py
│ ├── cross_attention.py
│ ├── dual_transformer_2d.py
│ ├── embeddings.py
│ ├── embeddings_flax.py
│ ├── modeling_flax_pytorch_utils.py
│ ├── modeling_flax_utils.py
│ ├── modeling_pytorch_flax_utils.py
│ ├── modeling_utils.py
│ ├── prior_transformer.py
│ ├── resnet.py
│ ├── resnet_flax.py
│ ├── t5_film_transformer.py
│ ├── transformer_2d.py
│ ├── transformer_temporal.py
│ ├── unet_1d.py
│ ├── unet_1d_blocks.py
│ ├── unet_2d.py
│ ├── unet_2d_blocks.py
│ ├── unet_2d_blocks_flax.py
│ ├── unet_2d_condition.py
│ ├── unet_2d_condition_flax.py
│ ├── unet_3d_blocks.py
│ ├── unet_3d_condition.py
│ ├── vae.py
│ ├── vae_flax.py
│ └── vq_model.py
├── optimization.py
├── pipeline_utils.py
├── pipelines
│ ├── .DS_Store
│ ├── README.md
│ ├── __init__.py
│ ├── alt_diffusion
│ │ ├── __init__.py
│ │ ├── modeling_roberta_series.py
│ │ ├── pipeline_alt_diffusion.py
│ │ └── pipeline_alt_diffusion_img2img.py
│ ├── audio_diffusion
│ │ ├── __init__.py
│ │ ├── mel.py
│ │ └── pipeline_audio_diffusion.py
│ ├── audioldm
│ │ ├── __init__.py
│ │ └── pipeline_audioldm.py
│ ├── dance_diffusion
│ │ ├── __init__.py
│ │ └── pipeline_dance_diffusion.py
│ ├── ddim
│ │ ├── __init__.py
│ │ └── pipeline_ddim.py
│ ├── ddpm
│ │ ├── __init__.py
│ │ └── pipeline_ddpm.py
│ ├── dit
│ │ ├── __init__.py
│ │ └── pipeline_dit.py
│ ├── latent_diffusion
│ │ ├── __init__.py
│ │ ├── pipeline_latent_diffusion.py
│ │ └── pipeline_latent_diffusion_superresolution.py
│ ├── latent_diffusion_uncond
│ │ ├── __init__.py
│ │ └── pipeline_latent_diffusion_uncond.py
│ ├── onnx_utils.py
│ ├── paint_by_example
│ │ ├── __init__.py
│ │ ├── image_encoder.py
│ │ └── pipeline_paint_by_example.py
│ ├── pipeline_flax_utils.py
│ ├── pipeline_utils.py
│ ├── pndm
│ │ ├── __init__.py
│ │ └── pipeline_pndm.py
│ ├── repaint
│ │ ├── __init__.py
│ │ └── pipeline_repaint.py
│ ├── score_sde_ve
│ │ ├── __init__.py
│ │ └── pipeline_score_sde_ve.py
│ ├── semantic_stable_diffusion
│ │ ├── __init__.py
│ │ └── pipeline_semantic_stable_diffusion.py
│ ├── spectrogram_diffusion
│ │ ├── __init__.py
│ │ ├── continous_encoder.py
│ │ ├── midi_utils.py
│ │ ├── notes_encoder.py
│ │ └── pipeline_spectrogram_diffusion.py
│ ├── stable_diffusion
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── convert_from_ckpt.py
│ │ ├── pipeline_cycle_diffusion.py
│ │ ├── pipeline_flax_stable_diffusion.py
│ │ ├── pipeline_flax_stable_diffusion_controlnet.py
│ │ ├── pipeline_flax_stable_diffusion_img2img.py
│ │ ├── pipeline_flax_stable_diffusion_inpaint.py
│ │ ├── pipeline_onnx_stable_diffusion.py
│ │ ├── pipeline_onnx_stable_diffusion_img2img.py
│ │ ├── pipeline_onnx_stable_diffusion_inpaint.py
│ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py
│ │ ├── pipeline_onnx_stable_diffusion_upscale.py
│ │ ├── pipeline_stable_diffusion.py
│ │ ├── pipeline_stable_diffusion_attend_and_excite.py
│ │ ├── pipeline_stable_diffusion_controlnet.py
│ │ ├── pipeline_stable_diffusion_depth2img.py
│ │ ├── pipeline_stable_diffusion_image_variation.py
│ │ ├── pipeline_stable_diffusion_img2img.py
│ │ ├── pipeline_stable_diffusion_inpaint.py
│ │ ├── pipeline_stable_diffusion_inpaint_legacy.py
│ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py
│ │ ├── pipeline_stable_diffusion_k_diffusion.py
│ │ ├── pipeline_stable_diffusion_latent_upscale.py
│ │ ├── pipeline_stable_diffusion_model_editing.py
│ │ ├── pipeline_stable_diffusion_panorama.py
│ │ ├── pipeline_stable_diffusion_pix2pix_zero.py
│ │ ├── pipeline_stable_diffusion_sag.py
│ │ ├── pipeline_stable_diffusion_upscale.py
│ │ ├── pipeline_stable_unclip.py
│ │ ├── pipeline_stable_unclip_img2img.py
│ │ ├── safety_checker.py
│ │ ├── safety_checker_flax.py
│ │ └── stable_unclip_image_normalizer.py
│ ├── stable_diffusion_safe
│ │ ├── __init__.py
│ │ ├── pipeline_stable_diffusion_safe.py
│ │ └── safety_checker.py
│ ├── stochastic_karras_ve
│ │ ├── __init__.py
│ │ └── pipeline_stochastic_karras_ve.py
│ └── text_to_video_synthesis
│ │ ├── __init__.py
│ │ ├── pipeline_text_to_video_synth.py
│ │ └── pipeline_text_to_video_zero.py
├── schedulers
│ ├── README.md
│ ├── __init__.py
│ ├── scheduling_ddim.py
│ ├── scheduling_ddim_flax.py
│ ├── scheduling_ddim_inverse.py
│ ├── scheduling_ddpm.py
│ ├── scheduling_ddpm_flax.py
│ ├── scheduling_deis_multistep.py
│ ├── scheduling_dpmsolver_multistep.py
│ ├── scheduling_dpmsolver_multistep_flax.py
│ ├── scheduling_dpmsolver_singlestep.py
│ ├── scheduling_euler_ancestral_discrete.py
│ ├── scheduling_euler_discrete.py
│ ├── scheduling_heun_discrete.py
│ ├── scheduling_ipndm.py
│ ├── scheduling_k_dpm_2_ancestral_discrete.py
│ ├── scheduling_k_dpm_2_discrete.py
│ ├── scheduling_karras_ve.py
│ ├── scheduling_karras_ve_flax.py
│ ├── scheduling_lms_discrete.py
│ ├── scheduling_lms_discrete_flax.py
│ ├── scheduling_pndm.py
│ ├── scheduling_pndm_flax.py
│ ├── scheduling_repaint.py
│ ├── scheduling_sde.py
│ ├── scheduling_sde_ve.py
│ ├── scheduling_sde_ve_flax.py
│ ├── scheduling_sde_vp.py
│ ├── scheduling_unclip.py
│ ├── scheduling_unipc_multistep.py
│ ├── scheduling_utils.py
│ ├── scheduling_utils_flax.py
│ └── scheduling_vq_diffusion.py
├── test.py
├── training_utils.py
└── utils
│ ├── __init__.py
│ ├── accelerate_utils.py
│ ├── constants.py
│ ├── deprecation_utils.py
│ ├── doc_utils.py
│ ├── dummy_flax_and_transformers_objects.py
│ ├── dummy_flax_objects.py
│ ├── dummy_note_seq_objects.py
│ ├── dummy_onnx_objects.py
│ ├── dummy_pt_objects.py
│ ├── dummy_torch_and_librosa_objects.py
│ ├── dummy_torch_and_scipy_objects.py
│ ├── dummy_torch_and_transformers_and_k_diffusion_objects.py
│ ├── dummy_torch_and_transformers_and_onnx_objects.py
│ ├── dummy_torch_and_transformers_objects.py
│ ├── dummy_transformers_and_torch_and_note_seq_objects.py
│ ├── dynamic_modules_utils.py
│ ├── hub_utils.py
│ ├── import_utils.py
│ ├── logging.py
│ ├── model_card_template.md
│ ├── outputs.py
│ ├── pil_utils.py
│ ├── testing_utils.py
│ └── torch_utils.py
├── dnnlib
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-39.pyc
│ └── util.cpython-39.pyc
└── util.py
├── eval_clip_score.py
├── ffhq.py
├── fid.py
├── generate.py
├── requirements.txt
├── torch_utils
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-39.pyc
│ ├── distributed.cpython-39.pyc
│ ├── misc.cpython-39.pyc
│ ├── persistence.cpython-39.pyc
│ └── training_stats.cpython-39.pyc
├── distributed.py
├── misc.py
├── persistence.py
└── training_stats.py
├── training
├── __init__.py
├── augment.py
├── dataset.py
├── loss.py
├── networks.py
└── training_loop.py
└── visualization.py
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/fig_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/fig_3.png
--------------------------------------------------------------------------------
/assets/fig_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/fig_5.png
--------------------------------------------------------------------------------
/assets/restart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/restart.png
--------------------------------------------------------------------------------
/assets/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/vis.png
--------------------------------------------------------------------------------
/benchmarks/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/.DS_Store
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | ### Standard Benchmarks (CIFAR-10, ImageNet-64)
2 |
3 | **The working directory for standard benchmarks is under `./benchmarks`**
4 |
5 | #### 1. Preparing datasets and checkpoints
6 |
7 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:
8 |
9 | ```.bash
10 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
11 | --dest=datasets/cifar10-32x32.zip
12 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
13 | ```
14 |
15 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution:
16 |
17 | ```.bash
18 | python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
19 | --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
20 | python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz
21 | ```
22 |
23 | Alternatively, you can download the pre-computed FID reference statistics directly: [CIFAR-10-FID](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz) and [ImageNet-$64\times 64$-FID](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz).
24 |
25 |
26 |
27 | Please download the CIFAR-10 or ImageNet checkpoints from the [EDM](https://github.com/NVlabs/edm) repo or the [PFGM++](https://github.com/Newbeeer/pfgmpp) repo. For example:
28 |
29 | | Dataset | Method | Path |
30 | | ---------------------- | ------------------------------- | ------------------------------------------------------------ |
31 | | CIFAR-10 | VP unconditional | [path](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl) |
32 | | CIFAR-10 | PFGM++ ($D=2048$) unconditional | [path](https://drive.google.com/drive/folders/1sZ7vh7o8kuXfFjK8ROWXxtEZi8Srewgo) |
33 | | ImageNet $64\times 64$ | EDM conditional | [path](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl) |
34 |
35 |
36 |
37 | #### 2. Generate
38 |
39 | Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the command below with `torchrun`. Before generation, please make sure the checkpoint has been downloaded into the `./benchmarks/imgs` folder:
40 |
41 | ```shell
42 | torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \
43 | --restart_info='{restart_config}' --S_min=0.01 --S_max=1 --S_noise 1.003 \
44 | --steps={steps} --seeds=00000-49999 --name={name} (--pfgmpp=1) (--aug_dim={D})
45 |
46 |
47 | restart_config: configuration for Restart (details below)
48 | steps: number of steps in the main backward process, default=18
49 | name: name of experiments (for FID evaluation)
50 | pfgmpp: flag for using PFGM++
51 | D: augmented dimension in PFGM++
52 | ```
53 |
54 | The `restart_info` is in the format $\lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}] \rbrace_{i=0}^{l-1}$, such as `{"0": [3, 2, 0.06, 0.30]}`. Please refer to Table 3 (CIFAR-10) and Table 5 (ImageNet-64) for detailed configurations. For example, for conditional EDM on ImageNet-64 with NFE=203 and FID=1.41, the command line is:
55 |
56 | ```shell
57 | torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \
58 | --restart_info='{"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}' --S_min=0.01 --S_max=1 --S_noise 1.003 \
59 | --steps=36 --seeds=00000-49999 --name=imagenet_edm
60 | ```
61 |
62 | We also provide extensive Restart configurations in `params_cifar10_vp.txt` and `params_imagenet_edm.txt`, corresponding to Table 3 (CIFAR-10) and Table 5 (ImageNet-64) respectively. Each line in these `txt` files has the form $N_{\textrm{main}} \quad \lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}]\rbrace_{i=0}^{l-1}$ (the resulting NFE follows the accounting sketched at the end of this README). To sweep the Restart configurations in the `txt` files, please run
63 |
64 | ```shell
65 | python3 hyperparam.py --dataset {dataset} --method {method}
66 |
67 | dataset: cifar10 | imagenet
68 | method: edm | vp | pfgmpp
69 | ```
70 |
71 | The above sweeping will reproduce the results in the following figure (Fig 3 in the paper):
72 |
73 | ![Fig 3](../assets/fig_3.png)
74 |
75 | #### 3. Evaluation
76 |
77 | For FID evaluation, please run:
78 |
79 | ```shell
80 | python fid.py ./imgs/imgs_{name} stats_path
81 |
82 | name: name of experiments (specified in the generation command line)
83 | stats_path: path to FID statistics, such as ./cifar10-32x32.npz or ./imagenet-64x64.npz
84 | ```
85 |
86 |
87 |
88 |
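89 | For reference, the total NFE of a configuration follows the accounting used in `hyperparam.py`: the main backward pass costs $2 N_{\textrm{main}} - 1$ evaluations, and each Restart interval $i$ adds $2 K_i (N_{\textrm{Restart},i} - 1)$. The helper below is a minimal sketch (not part of the repo) that computes the NFE for one line of a params file:
90 |
91 | ```python
92 | import json
93 |
94 | def nfe_from_params_line(line: str) -> int:
95 |     """NFE for one line of a params_*.txt file, following hyperparam.py."""
96 |     steps_str, restart_json = line.split(' ', 1)
97 |     steps = int(steps_str)
98 |     restart = json.loads(restart_json)
99 |     nfe = 2 * steps - 1                    # main backward pass
100 |     for n_restart, k, *_ in restart.values():
101 |         nfe += 2 * k * (n_restart - 1)     # K_i repetitions of Restart interval i
102 |     return nfe
103 |
104 | # The ImageNet-64 configuration shown above: (2*36 - 1) + 132 = 203
105 | line = '36 {"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}'
106 | print(nfe_from_params_line(line))          # 203
107 | ```
108 |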
--------------------------------------------------------------------------------
/benchmarks/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/benchmarks/dnnlib/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/dnnlib/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/dnnlib/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/dnnlib/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/hyperparam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | # Set up arguments
8 | parser = argparse.ArgumentParser(description='Restart Sampling')
9 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'])
10 | parser.add_argument('--method', type=str, default='vp', choices=['edm', 'vp', 'pfgmpp'])
11 | args = parser.parse_args()
12 |
13 |
14 | devices = '0,1,2,3,4,5,6,7'
15 | os.environ['CUDA_VISIBLE_DEVICES'] = devices
16 | batch_size = 192
17 |
18 |
19 | def get_fid(num, name):
20 |     # Pre-computed FID reference statistics (see README, step 1)
21 |     if args.dataset == 'cifar10':
22 |         fid_stats_path = './cifar10-32x32.npz'
23 |     else:
24 |         fid_stats_path = './imagenet-64x64.npz'
25 |
26 |     for tries in range(2):
27 |         fid_command = f'CUDA_VISIBLE_DEVICES=0 python fid.py ./imgs/imgs_{name} {fid_stats_path} --num={num} > tmp_fid.txt'
28 |         print('----------------------------')
29 |         print(fid_command)
30 |         os.system(fid_command)
31 |         with open("tmp_fid.txt", "r") as f:
32 |             output = f.read()
33 |             print(output)
34 |         try:
35 |             fid_score = float(output.split()[-1])
36 |             return fid_score
37 |         except (ValueError, IndexError):
38 |             print("FID computation failed, trying again")
39 |         print('----------------------------')
40 |     return 1e9
41 |
42 |
43 | runs_dict = dict()
44 |
45 |
46 | def generate(num, name, store=True, **kwargs):
47 |     # If using PFGM++, set up the augmentation dimension and the method flag
48 |     pfgmpp = 1 if args.method == 'pfgmpp' else 0
49 |     aug_dim = 2048 if args.method == 'pfgmpp' else -1
50 |     s_noise = 1.003 if args.dataset == 'imagenet' else 1.0
51 |
52 |     command = f"CUDA_VISIBLE_DEVICES={devices} torchrun --standalone --nproc_per_node={len(devices.split(','))} generate_restart.py --outdir=./imgs " \
53 |               f"--restart_info='{kwargs['restart']}' --S_min=0.01 --S_max=1 --S_noise={s_noise} " \
54 |               f"--steps={kwargs['steps']} --seeds=00000-{num - 1} --name={name} --batch={batch_size} --pfgmpp={pfgmpp} --aug_dim={aug_dim} #generate"
55 |     print(command)
56 |     os.system(command)
57 |     if store:
58 |         fid_score = get_fid(num, name)
59 |         # NFE accounting: each Restart interval i adds 2 * K_i * (N_Restart,i - 1)
60 |         # evaluations on top of the 2 * steps - 1 evaluations of the main backward pass.
61 |         NFE = 0
62 |         print("restart:", kwargs["restart"])
63 |         dic = json.loads(kwargs["restart"])
64 |         print("dic:", dic)
65 |         for restartid in dic.keys():
66 |             info = dic[restartid]
67 |             NFE += 2 * info[1] * (info[0] - 1)
68 |         NFE += (2 * kwargs['steps'] - 1)
69 |         print(f'NFE:{NFE} FID_OF_{num}:{fid_score}')
70 |         runs_dict[name] = {"fid": fid_score, "NFE": NFE, "Args": kwargs}
71 |
72 |
73 | # Sweep all Restart configurations listed in the params file for this dataset/method
74 | with open(f"params_{args.dataset}_{args.method}.txt", "r") as f:
75 |     lines = f.readlines()
76 |     for line in lines:
77 |         sample_runs = 50000
78 |         infos = line.split(' ')
79 |
80 |         steps = int(infos[0])
81 |         restart_info = ' '.join(infos[1:])
82 |         print("restart_info:", restart_info)
83 |         cur_name = random.randint(0, 20000000)
84 |         generate(sample_runs, cur_name, True, restart=restart_info, steps=steps)
85 |
86 | with open(f"restart_runs_dict_{args.dataset}_{args.method}.json", "w") as f:
87 |     json.dump(runs_dict, f)
88 |
--------------------------------------------------------------------------------
/benchmarks/params_cifar10_edm.txt:
--------------------------------------------------------------------------------
1 | 18 {"0": [3, 2, 0.14, 0.30]}
--------------------------------------------------------------------------------
/benchmarks/params_cifar10_pfgmpp.txt:
--------------------------------------------------------------------------------
1 | 18 {"0": [3, 2, 0.14, 0.30]}
--------------------------------------------------------------------------------
/benchmarks/params_cifar10_vp.txt:
--------------------------------------------------------------------------------
1 | 18 {"0": [3, 2, 0.06, 0.30]}
2 | 18 {"0": [3, 5, 0.06, 0.30]}
3 | 18 {"0": [3, 10, 0.06, 0.30]}
4 | 18 {"0": [3, 20, 0.06, 0.30]}
5 | 20 {"0": [9, 30, 0.06, 0.20]}
--------------------------------------------------------------------------------
/benchmarks/params_imagenet_edm.txt:
--------------------------------------------------------------------------------
1 | 14 {"0": [3, 1, 19.35, 40.79], "1": [3, 1, 1.09, 1.92], "2": [3, 1, 0.06, 0.30]}
2 | 18 {"0": [5, 1, 19.35, 40.79], "1": [5, 1, 1.09, 1.92], "2": [5, 1, 0.59, 1.09], "3": [5, 1, 0.06, 0.30]}
3 | 18 {"0": [3, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 4, 0.59, 1.09], "3": [4, 1, 0.30, 0.59], "4": [4, 4, 0.06, 0.30]}
4 | 18 {"0": [3, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [4, 10, 0.06, 0.30]}
5 | 36 {"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}
6 | 36 {"0": [3, 1, 19.35, 40.79], "1": [6, 1, 1.09, 1.92], "2": [6, 5, 0.59, 1.09], "3": [6, 5, 0.30, 0.59], "4": [6, 20, 0.06, 0.30]}
7 | 36 {"0": [6, 1, 19.35, 40.79], "1": [6, 1, 1.09, 1.92], "2": [7, 6, 0.59, 1.09], "3": [7, 6, 0.30, 0.59], "4": [7, 25, 0.06, 0.30]}
8 | 36 {"0": [10, 3, 19.35, 40.79], "1": [10, 3, 1.09, 1.92], "2": [7, 6, 0.59, 1.09], "3": [7, 6, 0.30, 0.59], "4": [7, 25, 0.06, 0.30]}
--------------------------------------------------------------------------------
/benchmarks/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | for i in $(seq 5 13);
3 | do
4 | torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate_dpm_solver.py --outdir=./imgs --restart_info='' --S_min=0.01 --S_max=1 --S_noise 1.003 --steps=$i --seeds=50000-99999 --name step_dpm3_$i
5 | done
--------------------------------------------------------------------------------
/benchmarks/run_fid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | for i in $(seq 8 11);
3 | do
4 | python fid.py ./imgs/imgs_step_3_1_0.3_dpm3_3_$i cifar10-32x32.npz
5 | done
--------------------------------------------------------------------------------
/benchmarks/run_restart.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | for i in $(seq 5 7);
3 | do
4 | torchrun --rdzv_endpoint=0.0.0.0:1203 --nproc_per_node=8 generate_dpm_solver.py --outdir=./imgs --restart_info='{"0": [3, 1, 0.06, 1]}' --S_min=0.01 --S_max=1 --S_noise 1.003 --steps=$i --seeds=50000-99999 --name step_3_1_1_dpm3_3_$i
5 | done
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/distributed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/distributed.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/distributed.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/distributed.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/misc.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/misc.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/persistence.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/persistence.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/training_stats.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/training_stats.cpython-38.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/__pycache__/training_stats.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/training_stats.cpython-39.pyc
--------------------------------------------------------------------------------
/benchmarks/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
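61 | # Example usage (added for illustration; not part of the original file). A typical
62 | # launcher script initializes distributed state once, then uses the helpers above:
63 | #
64 | #   from torch_utils import distributed as dist
65 | #   dist.init()
66 | #   dist.print0(f'world size = {dist.get_world_size()}, rank = {dist.get_rank()}')
67 |
68 | #----------------------------------------------------------------------------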
--------------------------------------------------------------------------------
/benchmarks/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/diffuser/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/.DS_Store
--------------------------------------------------------------------------------
/diffuser/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/diffuser/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/diffuser/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/diffuser/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/diffuser/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/diffuser/.idea/src.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/diffuser/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/diffuser/README.md:
--------------------------------------------------------------------------------
1 | ### Stable Diffusion
2 |
3 | TODO: merge into the diffuser repo.
4 |
5 | **The working directory for the Stable Diffusion experiments is `./diffuser`**
6 |
7 | 
8 |
9 | #### 1. Data processing:
10 |
11 | - Step 1: Follow the instructions at the head of `data_process.py`
12 | - Step 2: Run `python3 data_process.py` to randomly sample 5K image-text pairs from the COCO validation set.
13 |
14 | - Step 3: Calculate the FID reference statistics: `python fid_edm.py ref --data=path-to-coco-subset --dest=./coco.npz`
15 |
16 | #### 2. Generate
17 |
18 | ```
19 | torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate.py
20 | --steps: number of sampling steps (default=50)
21 | --scheduler: baseline method (DDIM | DDPM | Heun)
22 | --save_path: path to save images
23 | --w: classifier-free guidance weight ({2,3,5,8})
24 | --name: name of experiments
25 | --restart
26 | ```
27 |
28 | If you would like to visualize the images given text prompt, run:
29 |
30 | ```python
31 | python3 visualization.py --prompt {prompt} --w {w} --steps {steps} --scheduler {scheduler} (--restart)
32 |
33 | prompt: text prompt (default='a photo of an astronaut riding a horse on mars')
34 | steps: number of sampling steps (default=50)
35 | scheduler: baseline method (DDIM | DDPM)
36 | w: classifier-free guidance weight ({2,3,5,8})
37 | ```
38 |
39 |
40 |
41 | #### 3. Evaluation
42 |
43 | - FID score
44 |
45 | ```sh
46 | python3 fid.py {path} ./coco.npz
47 |
48 | path: path to the directory of generated images
49 | ```
50 |
51 | - Aesthetic score & CLIP score
52 |
53 | ```shell
54 | python3 eval_clip_score.py --csv_path {path}/subset.csv --dir_path {path}
55 |
56 | path: path to the directory of generated images
57 | ```
58 |
59 |
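60 | The CLIP score is the cosine similarity between image and caption embeddings under a CLIP model. Below is a minimal sketch (not part of the repo; it assumes the same `ViT-g-14` open_clip model used in `clip_score.py`, and the image path and prompt are placeholders) of scoring a single generated image against its prompt:
61 |
62 | ```python
63 | import torch
64 | import open_clip
65 | from PIL import Image
66 |
67 | # Same model as in clip_score.py / coco_data_loader.py
68 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
69 | tokenizer = open_clip.get_tokenizer('ViT-g-14')
70 |
71 | image = preprocess(Image.open("imgs/0.png")).unsqueeze(0)               # placeholder path
72 | text = tokenizer(["a photo of an astronaut riding a horse on mars"])    # placeholder prompt
73 |
74 | with torch.no_grad():
75 |     image_features = model.encode_image(image)
76 |     text_features = model.encode_text(text)
77 |     image_features /= image_features.norm(dim=-1, keepdim=True)
78 |     text_features /= text_features.norm(dim=-1, keepdim=True)
79 |     clip_score = (image_features * text_features).sum(dim=-1)           # cosine similarity
80 |
81 | print("CLIP score:", clip_score.item())
82 | ```
83 |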
--------------------------------------------------------------------------------
/diffuser/aesthetic_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pytorch_lightning as pl
4 | import torch.nn as nn
5 | import clip
6 | import torch.nn.functional as F
7 | from PIL import Image, ImageFile
8 |
9 | ##### This script will predict the aesthetic score for this image file:
10 |
11 | img_path = "test.jpg"
12 |
13 |
14 | # if you changed the MLP architecture during training, change it also here:
15 | class MLP(pl.LightningModule):
16 | def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
17 | super().__init__()
18 | self.input_size = input_size
19 | self.xcol = xcol
20 | self.ycol = ycol
21 | self.layers = nn.Sequential(
22 | nn.Linear(self.input_size, 1024),
23 | # nn.ReLU(),
24 | nn.Dropout(0.2),
25 | nn.Linear(1024, 128),
26 | # nn.ReLU(),
27 | nn.Dropout(0.2),
28 | nn.Linear(128, 64),
29 | # nn.ReLU(),
30 | nn.Dropout(0.1),
31 |
32 | nn.Linear(64, 16),
33 | # nn.ReLU(),
34 |
35 | nn.Linear(16, 1)
36 | )
37 |
38 | def forward(self, x):
39 | return self.layers(x)
40 |
41 | def training_step(self, batch, batch_idx):
42 | x = batch[self.xcol]
43 | y = batch[self.ycol].reshape(-1, 1)
44 | x_hat = self.layers(x)
45 | loss = F.mse_loss(x_hat, y)
46 | return loss
47 |
48 | def validation_step(self, batch, batch_idx):
49 | x = batch[self.xcol]
50 | y = batch[self.ycol].reshape(-1, 1)
51 | x_hat = self.layers(x)
52 | loss = F.mse_loss(x_hat, y)
53 | return loss
54 |
55 | def configure_optimizers(self):
56 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
57 | return optimizer
58 |
59 |
60 | def normalized(a, axis=-1, order=2):
61 | import numpy as np # pylint: disable=import-outside-toplevel
62 |
63 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
64 | l2[l2 == 0] = 1
65 | return a / np.expand_dims(l2, axis)
66 |
67 |
68 | # model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
69 | #
70 | # s = torch.load("sac+logos+ava1-l14-linearMSE.pth") # load the model you trained previously or the model available in this repo
71 | #
72 | # model.load_state_dict(s)
73 | #
74 | # model.to("cuda")
75 | # model.eval()
76 | #
77 | #
78 | # device = "cuda" if torch.cuda.is_available() else "cpu"
79 | # model2, preprocess = clip.load("ViT-L/14", device=device) #RN50x64
80 | #
81 | #
82 | # pil_image = Image.open(img_path)
83 | #
84 | # image = preprocess(pil_image).unsqueeze(0).to(device)
85 | #
86 | #
87 | #
88 | # with torch.no_grad():
89 | # image_features = model2.encode_image(image)
90 | #
91 | # im_emb_arr = normalized(image_features.cpu().detach().numpy() )
92 | #
93 | # prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor))
94 | #
95 | # print( "Aesthetic score predicted by the model:")
96 | # print( prediction )
--------------------------------------------------------------------------------
/diffuser/clip_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import open_clip
4 |
5 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
6 | tokenizer = open_clip.get_tokenizer('ViT-g-14')
7 |
8 | image = preprocess(Image.open("astronaut_rides_horse.png")).unsqueeze(0)
9 | text = tokenizer(["a horse", "a dog", "a cat"])
10 |
11 | with torch.no_grad(), torch.cuda.amp.autocast():
12 | image_features = model.encode_image(image)
13 | text_features = model.encode_text(text)
14 | image_features /= image_features.norm(dim=-1, keepdim=True)
15 | text_features /= text_features.norm(dim=-1, keepdim=True)
16 |
17 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
18 |
19 | print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
--------------------------------------------------------------------------------
/diffuser/coco_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils
3 | import pandas as pd
4 | import os
5 | import open_clip
6 | from PIL import Image
7 | import clip
8 |
9 | class text_image_pair(torch.utils.data.Dataset):
10 | def __init__(self, dir_path, csv_path):
11 | """
12 |
13 | Args:
14 | dir_path: the path to the stored images
15 | file_path:
16 | """
17 | self.dir_path = dir_path
18 | df = pd.read_csv(csv_path)
19 | self.text_description = df['caption']
20 | _, _, self.preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
21 | _, self.preprocess2 = clip.load("ViT-L/14", device='cuda') # RN50x64
22 | # tokenizer = open_clip.get_tokenizer('ViT-g-14')
23 |
24 | def __len__(self):
25 | return len(self.text_description)
26 |
27 | def __getitem__(self, idx):
28 |
29 | img_path = os.path.join(self.dir_path, f'{idx}.png')
30 | raw_image = Image.open(img_path)
31 | image = self.preprocess(raw_image).squeeze().float()
32 | image2 = self.preprocess2(raw_image).squeeze().float()
33 | text = self.text_description[idx]
34 | return image, image2, text
35 |
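36 |
37 | # Example usage (added for illustration; not part of the original file). The
38 | # directory/CSV paths below are placeholders for the outputs of data_process.py
39 | # and the generation script:
40 | #
41 | #   dataset = text_image_pair(dir_path='./generated_imgs', csv_path='./coco/subset.csv')
42 | #   loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4)
43 | #   for image, image2, text in loader:
44 | #       ...  # image: open_clip preprocessing, image2: CLIP ViT-L/14 preprocessing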
--------------------------------------------------------------------------------
/diffuser/data_process.py:
--------------------------------------------------------------------------------
1 | # # bash commands to download the data
2 | # !wget http://images.cocodataset.org/zips/val2014.zip
3 | # !wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
4 | # !unzip annotations_trainval2014.zip -d coco/
5 | # !unzip val2014.zip -d coco/
6 |
7 |
8 |
9 | # load data
10 | import json
11 | import numpy as np
12 |
13 |
14 | dir_path = './coco/'
15 | data_file = dir_path + 'coco/annotations/captions_val2014.json'
16 | data = json.load(open(data_file))
17 |
18 | np.random.seed(123)
19 |
20 | # merge images and annotations
21 | import pandas as pd
22 | images = data['images']
23 | annotations = data['annotations']
24 | df = pd.DataFrame(images)
25 | df_annotations = pd.DataFrame(annotations)
26 | df = df.merge(pd.DataFrame(annotations), how='left', left_on='id', right_on='image_id')
27 |
28 |
29 | # keep only the relevant columns
30 | df = df[['file_name', 'caption']]
31 | print(df)
32 | print("length:", len(df['file_name']))
33 | # shuffle the dataset
34 | df = df.sample(frac=1)
35 |
36 |
37 | # remove duplicate images
38 | # df = df.drop_duplicates(subset='file_name')
39 |
40 | # create a random subset of n_samples
41 | n_samples = 5000
42 | df_sample = df.sample(n_samples)
43 | print(df_sample)
44 |
45 | # save the sample to a parquet file
46 | df_sample.to_csv(dir_path + 'coco/subset.csv')
47 |
48 | # copy the images to reference folder
49 | from pathlib import Path
50 | import shutil
51 | subset_path = Path(dir_path + 'coco/subset')
52 | subset_path.mkdir(exist_ok=True)
53 | for i, row in df_sample.iterrows():
54 | path = dir_path + 'coco/val2014/' + row['file_name']
55 | shutil.copy(path, dir_path + 'coco/subset/')
56 |
--------------------------------------------------------------------------------
/diffuser/diffusers/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/diffusers/.DS_Store
--------------------------------------------------------------------------------
/diffuser/diffusers/commands/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 | from argparse import ArgumentParser
17 |
18 |
19 | class BaseDiffusersCLICommand(ABC):
20 | @staticmethod
21 | @abstractmethod
22 | def register_subcommand(parser: ArgumentParser):
23 | raise NotImplementedError()
24 |
25 | @abstractmethod
26 | def run(self):
27 | raise NotImplementedError()
28 |
--------------------------------------------------------------------------------
/diffuser/diffusers/commands/diffusers_cli.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from argparse import ArgumentParser
17 |
18 | from .env import EnvironmentCommand
19 |
20 |
21 | def main():
22 | parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []")
23 | commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24 |
25 | # Register commands
26 | EnvironmentCommand.register_subcommand(commands_parser)
27 |
28 | # Let's go
29 | args = parser.parse_args()
30 |
31 | if not hasattr(args, "func"):
32 | parser.print_help()
33 | exit(1)
34 |
35 | # Run
36 | service = args.func(args)
37 | service.run()
38 |
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/diffuser/diffusers/commands/env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import platform
16 | from argparse import ArgumentParser
17 |
18 | import huggingface_hub
19 |
20 | from .. import __version__ as version
21 | from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22 | from . import BaseDiffusersCLICommand
23 |
24 |
25 | def info_command_factory(_):
26 | return EnvironmentCommand()
27 |
28 |
29 | class EnvironmentCommand(BaseDiffusersCLICommand):
30 | @staticmethod
31 | def register_subcommand(parser: ArgumentParser):
32 | download_parser = parser.add_parser("env")
33 | download_parser.set_defaults(func=info_command_factory)
34 |
35 | def run(self):
36 | hub_version = huggingface_hub.__version__
37 |
38 | pt_version = "not installed"
39 | pt_cuda_available = "NA"
40 | if is_torch_available():
41 | import torch
42 |
43 | pt_version = torch.__version__
44 | pt_cuda_available = torch.cuda.is_available()
45 |
46 | transformers_version = "not installed"
47 | if is_transformers_available():
48 | import transformers
49 |
50 | transformers_version = transformers.__version__
51 |
52 | accelerate_version = "not installed"
53 | if is_accelerate_available():
54 | import accelerate
55 |
56 | accelerate_version = accelerate.__version__
57 |
58 | xformers_version = "not installed"
59 | if is_xformers_available():
60 | import xformers
61 |
62 | xformers_version = xformers.__version__
63 |
64 | info = {
65 | "`diffusers` version": version,
66 | "Platform": platform.platform(),
67 | "Python version": platform.python_version(),
68 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69 | "Huggingface_hub version": hub_version,
70 | "Transformers version": transformers_version,
71 | "Accelerate version": accelerate_version,
72 | "xFormers version": xformers_version,
73 | "Using GPU in script?": "",
74 | "Using distributed or parallel set-up in script?": "",
75 | }
76 |
77 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78 | print(self.format_dict(info))
79 |
80 | return info
81 |
82 | @staticmethod
83 | def format_dict(d):
84 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
85 |
--------------------------------------------------------------------------------
/diffuser/diffusers/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import sys
15 |
16 | from .dependency_versions_table import deps
17 | from .utils.versions import require_version, require_version_core
18 |
19 |
20 | # define which module versions we always want to check at run time
21 | # (usually the ones defined in `install_requires` in setup.py)
22 | #
23 | # order specific notes:
24 | # - tqdm must be checked before tokenizers
25 |
26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27 | if sys.version_info < (3, 7):
28 | pkgs_to_check_at_runtime.append("dataclasses")
29 | if sys.version_info < (3, 8):
30 | pkgs_to_check_at_runtime.append("importlib_metadata")
31 |
32 | for pkg in pkgs_to_check_at_runtime:
33 | if pkg in deps:
34 | if pkg == "tokenizers":
35 | # must be loaded here, or else tqdm check may fail
36 | from .utils import is_tokenizers_available
37 |
38 | if not is_tokenizers_available():
39 | continue # not required, check version only if installed
40 |
41 | require_version_core(deps[pkg])
42 | else:
43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44 |
45 |
46 | def dep_version_check(pkg, hint=None):
47 | require_version(deps[pkg], hint)
48 |
--------------------------------------------------------------------------------
/diffuser/diffusers/dependency_versions_table.py:
--------------------------------------------------------------------------------
1 | # THIS FILE HAS BEEN AUTOGENERATED. To update:
2 | # 1. modify the `_deps` dict in setup.py
3 | # 2. run `make deps_table_update``
4 | deps = {
5 | "Pillow": "Pillow",
6 | "accelerate": "accelerate>=0.11.0",
7 | "compel": "compel==0.1.8",
8 | "black": "black~=23.1",
9 | "datasets": "datasets",
10 | "filelock": "filelock",
11 | "flax": "flax>=0.4.1",
12 | "hf-doc-builder": "hf-doc-builder>=0.3.0",
13 | "huggingface-hub": "huggingface-hub>=0.13.2",
14 | "requests-mock": "requests-mock==1.10.0",
15 | "importlib_metadata": "importlib_metadata",
16 | "isort": "isort>=5.5.4",
17 | "jax": "jax>=0.2.8,!=0.3.2",
18 | "jaxlib": "jaxlib>=0.1.65",
19 | "Jinja2": "Jinja2",
20 | "k-diffusion": "k-diffusion>=0.0.12",
21 | "librosa": "librosa",
22 | "note-seq": "note-seq",
23 | "numpy": "numpy",
24 | "parameterized": "parameterized",
25 | "protobuf": "protobuf>=3.20.3,<4",
26 | "pytest": "pytest",
27 | "pytest-timeout": "pytest-timeout",
28 | "pytest-xdist": "pytest-xdist",
29 | "ruff": "ruff>=0.0.241",
30 | "safetensors": "safetensors",
31 | "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
32 | "scipy": "scipy",
33 | "regex": "regex!=2019.12.17",
34 | "requests": "requests",
35 | "tensorboard": "tensorboard",
36 | "torch": "torch>=1.4",
37 | "torchvision": "torchvision",
38 | "transformers": "transformers>=4.25.1",
39 | }
40 |
--------------------------------------------------------------------------------
/diffuser/diffusers/experimental/README.md:
--------------------------------------------------------------------------------
1 | # 🧨 Diffusers Experimental
2 |
3 | We are adding experimental code to support novel applications and usages of the Diffusers library.
4 | Currently, the following experiments are supported:
5 | * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
--------------------------------------------------------------------------------
/diffuser/diffusers/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | from .rl import ValueGuidedRLPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/experimental/rl/__init__.py:
--------------------------------------------------------------------------------
1 | from .value_guided_sampling import ValueGuidedRLPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/models/README.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models).
--------------------------------------------------------------------------------
/diffuser/diffusers/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ..utils import is_flax_available, is_torch_available
16 |
17 |
18 | if is_torch_available():
19 | from .autoencoder_kl import AutoencoderKL
20 | from .controlnet import ControlNetModel
21 | from .dual_transformer_2d import DualTransformer2DModel
22 | from .modeling_utils import ModelMixin
23 | from .prior_transformer import PriorTransformer
24 | from .t5_film_transformer import T5FilmDecoder
25 | from .transformer_2d import Transformer2DModel
26 | from .unet_1d import UNet1DModel
27 | from .unet_2d import UNet2DModel
28 | from .unet_2d_condition import UNet2DConditionModel
29 | from .unet_3d_condition import UNet3DConditionModel
30 | from .vq_model import VQModel
31 |
32 | if is_flax_available():
33 | from .controlnet_flax import FlaxControlNetModel
34 | from .unet_2d_condition_flax import FlaxUNet2DConditionModel
35 | from .vae_flax import FlaxAutoencoderKL
36 |
--------------------------------------------------------------------------------
/diffuser/diffusers/models/cross_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from ..utils import deprecate
15 | from .attention_processor import ( # noqa: F401
16 | Attention,
17 | AttentionProcessor,
18 | AttnAddedKVProcessor,
19 | AttnProcessor2_0,
20 | LoRAAttnProcessor,
21 | LoRALinearLayer,
22 | LoRAXFormersAttnProcessor,
23 | SlicedAttnAddedKVProcessor,
24 | SlicedAttnProcessor,
25 | XFormersAttnProcessor,
26 | )
27 | from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
28 |
29 |
30 | deprecate(
31 | "cross_attention",
32 | "0.18.0",
33 | "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34 | standard_warn=False,
35 | )
36 |
37 |
38 | AttnProcessor = AttentionProcessor
39 |
40 |
41 | class CrossAttention(Attention):
42 | def __init__(self, *args, **kwargs):
43 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
45 | super().__init__(*args, **kwargs)
46 |
47 |
48 | class CrossAttnProcessor(AttnProcessorRename):
49 | def __init__(self, *args, **kwargs):
50 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
52 | super().__init__(*args, **kwargs)
53 |
54 |
55 | class LoRACrossAttnProcessor(LoRAAttnProcessor):
56 | def __init__(self, *args, **kwargs):
57 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
58 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
59 | super().__init__(*args, **kwargs)
60 |
61 |
62 | class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63 | def __init__(self, *args, **kwargs):
64 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
65 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
66 | super().__init__(*args, **kwargs)
67 |
68 |
69 | class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70 | def __init__(self, *args, **kwargs):
71 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
72 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
73 | super().__init__(*args, **kwargs)
74 |
75 |
76 | class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77 | def __init__(self, *args, **kwargs):
78 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
79 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
80 | super().__init__(*args, **kwargs)
81 |
82 |
83 | class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84 | def __init__(self, *args, **kwargs):
85 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
86 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
87 | super().__init__(*args, **kwargs)
88 |
89 |
90 | class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91 | def __init__(self, *args, **kwargs):
92 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
93 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
94 | super().__init__(*args, **kwargs)
95 |
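A minimal sketch of the migration this shim supports: the old import path keeps resolving until 0.18.0 (emitting the deprecation notice above), while new code should import from attention_processor directly.

    # deprecated path, still resolves through this shim until 0.18.0
    from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

    # preferred path going forward
    from diffusers.models.attention_processor import Attention, AttnProcessor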
--------------------------------------------------------------------------------
/diffuser/diffusers/models/embeddings_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 |
16 | import flax.linen as nn
17 | import jax.numpy as jnp
18 |
19 |
20 | def get_sinusoidal_embeddings(
21 | timesteps: jnp.ndarray,
22 | embedding_dim: int,
23 | freq_shift: float = 1,
24 | min_timescale: float = 1,
25 | max_timescale: float = 1.0e4,
26 | flip_sin_to_cos: bool = False,
27 | scale: float = 1.0,
28 | ) -> jnp.ndarray:
29 | """Returns the positional encoding (same as Tensor2Tensor).
30 |
31 | Args:
32 | timesteps: a 1-D Tensor of N indices, one per batch element.
33 | These may be fractional.
34 | embedding_dim: The number of output channels.
35 | min_timescale: The smallest time unit (should probably be 0.0).
36 | max_timescale: The largest time unit.
37 | Returns:
38 | a Tensor of timing signals [N, num_channels]
39 | """
40 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42 | num_timescales = float(embedding_dim // 2)
43 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46 |
47 | # scale embeddings
48 | scaled_time = scale * emb
49 |
50 | if flip_sin_to_cos:
51 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52 | else:
53 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55 | return signal
56 |
57 |
58 | class FlaxTimestepEmbedding(nn.Module):
59 | r"""
60 | Time step Embedding Module. Learns embeddings for input time steps.
61 |
62 | Args:
63 | time_embed_dim (`int`, *optional*, defaults to `32`):
64 | Time step embedding dimension
65 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66 | Parameters `dtype`
67 | """
68 | time_embed_dim: int = 32
69 | dtype: jnp.dtype = jnp.float32
70 |
71 | @nn.compact
72 | def __call__(self, temb):
73 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74 | temb = nn.silu(temb)
75 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76 | return temb
77 |
78 |
79 | class FlaxTimesteps(nn.Module):
80 | r"""
81 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82 |
83 | Args:
84 | dim (`int`, *optional*, defaults to `32`):
85 | Time step embedding dimension
86 | """
87 | dim: int = 32
88 | flip_sin_to_cos: bool = False
89 | freq_shift: float = 1
90 |
91 | @nn.compact
92 | def __call__(self, timesteps):
93 | return get_sinusoidal_embeddings(
94 | timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95 | )
96 |
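A minimal usage sketch of the helpers above (assumes jax and flax are installed; the shapes in the comments follow from the code in this file):

    import jax
    import jax.numpy as jnp

    timesteps = jnp.array([0.0, 1.0, 10.0])
    emb = get_sinusoidal_embeddings(timesteps, embedding_dim=8)  # (3, 8): sin half, then cos half

    mlp = FlaxTimestepEmbedding(time_embed_dim=32)
    params = mlp.init(jax.random.PRNGKey(0), emb)
    temb = mlp.apply(params, emb)                                # (3, 32)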
--------------------------------------------------------------------------------
/diffuser/diffusers/models/modeling_flax_pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch - Flax general utilities."""
16 | import re
17 |
18 | import jax.numpy as jnp
19 | from flax.traverse_util import flatten_dict, unflatten_dict
20 | from jax.random import PRNGKey
21 |
22 | from ..utils import logging
23 |
24 |
25 | logger = logging.get_logger(__name__)
26 |
27 |
28 | def rename_key(key):
29 | regex = r"\w+[.]\d+"
30 | pats = re.findall(regex, key)
31 | for pat in pats:
32 | key = key.replace(pat, "_".join(pat.split(".")))
33 | return key
34 |
35 |
36 | #####################
37 | # PyTorch => Flax #
38 | #####################
39 |
40 |
41 | # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42 | # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43 | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44 | """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45 |
46 | # conv norm or layer norm
47 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
48 | if (
49 | any("norm" in str_ for str_ in pt_tuple_key)
50 | and (pt_tuple_key[-1] == "bias")
51 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
52 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
53 | ):
54 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
55 | return renamed_pt_tuple_key, pt_tensor
56 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
57 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
58 | return renamed_pt_tuple_key, pt_tensor
59 |
60 | # embedding
61 | if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
62 | pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
63 | return pt_tuple_key, pt_tensor  # return the remapped "embedding" key, not the earlier "scale" rename
64 |
65 | # conv layer
66 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
67 | if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
68 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
69 | return renamed_pt_tuple_key, pt_tensor
70 |
71 | # linear layer
72 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
73 | if pt_tuple_key[-1] == "weight":
74 | pt_tensor = pt_tensor.T
75 | return renamed_pt_tuple_key, pt_tensor
76 |
77 | # old PyTorch layer norm weight
78 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
79 | if pt_tuple_key[-1] == "gamma":
80 | return renamed_pt_tuple_key, pt_tensor
81 |
82 | # old PyTorch layer norm bias
83 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
84 | if pt_tuple_key[-1] == "beta":
85 | return renamed_pt_tuple_key, pt_tensor
86 |
87 | return pt_tuple_key, pt_tensor
88 |
89 |
90 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
91 | # Step 1: Convert pytorch tensor to numpy
92 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
93 |
94 | # Step 2: Since the model is stateless, get random Flax params
95 | random_flax_params = flax_model.init_weights(PRNGKey(init_key))
96 |
97 | random_flax_state_dict = flatten_dict(random_flax_params)
98 | flax_state_dict = {}
99 |
100 | # Need to change some parameters name to match Flax names
101 | for pt_key, pt_tensor in pt_state_dict.items():
102 | renamed_pt_key = rename_key(pt_key)
103 | pt_tuple_key = tuple(renamed_pt_key.split("."))
104 |
105 | # Correctly rename weight parameters
106 | flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
107 |
108 | if flax_key in random_flax_state_dict:
109 | if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
110 | raise ValueError(
111 | f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
112 | f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
113 | )
114 |
115 | # also add unexpected weight so that warning is thrown
116 | flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
117 |
118 | return unflatten_dict(flax_state_dict)
119 |
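To make the renaming rule above concrete, a small illustration (the key below is a made-up example, not taken from a specific checkpoint):

    print(rename_key("down_blocks.0.resnets.1.conv1.weight"))
    # -> "down_blocks_0.resnets_1.conv1.weight"
    # numeric sub-module indices are folded into the preceding name, matching how
    # the corresponding Flax modules are registered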
--------------------------------------------------------------------------------
/diffuser/diffusers/models/resnet_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import flax.linen as nn
15 | import jax
16 | import jax.numpy as jnp
17 |
18 |
19 | class FlaxUpsample2D(nn.Module):
20 | out_channels: int
21 | dtype: jnp.dtype = jnp.float32
22 |
23 | def setup(self):
24 | self.conv = nn.Conv(
25 | self.out_channels,
26 | kernel_size=(3, 3),
27 | strides=(1, 1),
28 | padding=((1, 1), (1, 1)),
29 | dtype=self.dtype,
30 | )
31 |
32 | def __call__(self, hidden_states):
33 | batch, height, width, channels = hidden_states.shape
34 | hidden_states = jax.image.resize(
35 | hidden_states,
36 | shape=(batch, height * 2, width * 2, channels),
37 | method="nearest",
38 | )
39 | hidden_states = self.conv(hidden_states)
40 | return hidden_states
41 |
42 |
43 | class FlaxDownsample2D(nn.Module):
44 | out_channels: int
45 | dtype: jnp.dtype = jnp.float32
46 |
47 | def setup(self):
48 | self.conv = nn.Conv(
49 | self.out_channels,
50 | kernel_size=(3, 3),
51 | strides=(2, 2),
52 | padding=((1, 1), (1, 1)), # padding="VALID",
53 | dtype=self.dtype,
54 | )
55 |
56 | def __call__(self, hidden_states):
57 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
58 | # hidden_states = jnp.pad(hidden_states, pad_width=pad)
59 | hidden_states = self.conv(hidden_states)
60 | return hidden_states
61 |
62 |
63 | class FlaxResnetBlock2D(nn.Module):
64 | in_channels: int
65 | out_channels: int = None
66 | dropout_prob: float = 0.0
67 | use_nin_shortcut: bool = None
68 | dtype: jnp.dtype = jnp.float32
69 |
70 | def setup(self):
71 | out_channels = self.in_channels if self.out_channels is None else self.out_channels
72 |
73 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
74 | self.conv1 = nn.Conv(
75 | out_channels,
76 | kernel_size=(3, 3),
77 | strides=(1, 1),
78 | padding=((1, 1), (1, 1)),
79 | dtype=self.dtype,
80 | )
81 |
82 | self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
83 |
84 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
85 | self.dropout = nn.Dropout(self.dropout_prob)
86 | self.conv2 = nn.Conv(
87 | out_channels,
88 | kernel_size=(3, 3),
89 | strides=(1, 1),
90 | padding=((1, 1), (1, 1)),
91 | dtype=self.dtype,
92 | )
93 |
94 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
95 |
96 | self.conv_shortcut = None
97 | if use_nin_shortcut:
98 | self.conv_shortcut = nn.Conv(
99 | out_channels,
100 | kernel_size=(1, 1),
101 | strides=(1, 1),
102 | padding="VALID",
103 | dtype=self.dtype,
104 | )
105 |
106 | def __call__(self, hidden_states, temb, deterministic=True):
107 | residual = hidden_states
108 | hidden_states = self.norm1(hidden_states)
109 | hidden_states = nn.swish(hidden_states)
110 | hidden_states = self.conv1(hidden_states)
111 |
112 | temb = self.time_emb_proj(nn.swish(temb))
113 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
114 | hidden_states = hidden_states + temb
115 |
116 | hidden_states = self.norm2(hidden_states)
117 | hidden_states = nn.swish(hidden_states)
118 | hidden_states = self.dropout(hidden_states, deterministic)
119 | hidden_states = self.conv2(hidden_states)
120 |
121 | if self.conv_shortcut is not None:
122 | residual = self.conv_shortcut(residual)
123 |
124 | return hidden_states + residual
125 |
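A minimal sketch exercising FlaxResnetBlock2D with freshly initialized parameters (assumes jax and flax; note the NHWC layout used by these Flax blocks):

    import jax
    import jax.numpy as jnp

    block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
    x = jnp.zeros((1, 16, 16, 32))    # NHWC feature map
    temb = jnp.zeros((1, 128))        # time embedding, projected by time_emb_proj
    params = block.init(jax.random.PRNGKey(0), x, temb)
    y = block.apply(params, x, temb)  # (1, 16, 16, 64); the residual passes through the 1x1 shortcut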
--------------------------------------------------------------------------------
/diffuser/diffusers/pipeline_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # NOTE: This file is deprecated and will be removed in a future version.
17 | # It only exists so that `from diffusers.pipelines import DiffusionPipeline` temporarily keeps working
18 |
19 | from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
20 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import (
2 | OptionalDependencyNotAvailable,
3 | is_flax_available,
4 | is_k_diffusion_available,
5 | is_librosa_available,
6 | is_note_seq_available,
7 | is_onnx_available,
8 | is_torch_available,
9 | is_transformers_available,
10 | )
11 |
12 |
13 | try:
14 | if not is_torch_available():
15 | raise OptionalDependencyNotAvailable()
16 | except OptionalDependencyNotAvailable:
17 | from ..utils.dummy_pt_objects import * # noqa F403
18 | else:
19 | from .dance_diffusion import DanceDiffusionPipeline
20 | from .ddim import DDIMPipeline
21 | from .ddpm import DDPMPipeline
22 | from .dit import DiTPipeline
23 | from .latent_diffusion import LDMSuperResolutionPipeline
24 | from .latent_diffusion_uncond import LDMPipeline
25 | from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
26 | from .pndm import PNDMPipeline
27 | from .repaint import RePaintPipeline
28 | from .score_sde_ve import ScoreSdeVePipeline
29 | from .stochastic_karras_ve import KarrasVePipeline
30 |
31 | try:
32 | if not (is_torch_available() and is_librosa_available()):
33 | raise OptionalDependencyNotAvailable()
34 | except OptionalDependencyNotAvailable:
35 | from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
36 | else:
37 | from .audio_diffusion import AudioDiffusionPipeline, Mel
38 |
39 | try:
40 | if not (is_torch_available() and is_transformers_available()):
41 | raise OptionalDependencyNotAvailable()
42 | except OptionalDependencyNotAvailable:
43 | from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
44 | else:
45 | from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
46 | from .audioldm import AudioLDMPipeline
47 | from .latent_diffusion import LDMTextToImagePipeline
48 | from .paint_by_example import PaintByExamplePipeline
49 | from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
50 | from .stable_diffusion import (
51 | CycleDiffusionPipeline,
52 | StableDiffusionAttendAndExcitePipeline,
53 | StableDiffusionControlNetPipeline,
54 | StableDiffusionDepth2ImgPipeline,
55 | StableDiffusionImageVariationPipeline,
56 | StableDiffusionImg2ImgPipeline,
57 | StableDiffusionInpaintPipeline,
58 | StableDiffusionInpaintPipelineLegacy,
59 | StableDiffusionInstructPix2PixPipeline,
60 | StableDiffusionLatentUpscalePipeline,
61 | StableDiffusionModelEditingPipeline,
62 | StableDiffusionPanoramaPipeline,
63 | StableDiffusionPipeline,
64 | StableDiffusionPix2PixZeroPipeline,
65 | StableDiffusionSAGPipeline,
66 | StableDiffusionUpscalePipeline,
67 | StableUnCLIPImg2ImgPipeline,
68 | StableUnCLIPPipeline,
69 | )
70 | from .stable_diffusion_safe import StableDiffusionPipelineSafe
71 | from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
72 | from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
73 | from .versatile_diffusion import (
74 | VersatileDiffusionDualGuidedPipeline,
75 | VersatileDiffusionImageVariationPipeline,
76 | VersatileDiffusionPipeline,
77 | VersatileDiffusionTextToImagePipeline,
78 | )
79 | from .vq_diffusion import VQDiffusionPipeline
80 |
81 | try:
82 | if not is_onnx_available():
83 | raise OptionalDependencyNotAvailable()
84 | except OptionalDependencyNotAvailable:
85 | from ..utils.dummy_onnx_objects import * # noqa F403
86 | else:
87 | from .onnx_utils import OnnxRuntimeModel
88 |
89 | try:
90 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
91 | raise OptionalDependencyNotAvailable()
92 | except OptionalDependencyNotAvailable:
93 | from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
94 | else:
95 | from .stable_diffusion import (
96 | OnnxStableDiffusionImg2ImgPipeline,
97 | OnnxStableDiffusionInpaintPipeline,
98 | OnnxStableDiffusionInpaintPipelineLegacy,
99 | OnnxStableDiffusionPipeline,
100 | OnnxStableDiffusionUpscalePipeline,
101 | StableDiffusionOnnxPipeline,
102 | )
103 |
104 | try:
105 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
106 | raise OptionalDependencyNotAvailable()
107 | except OptionalDependencyNotAvailable:
108 | from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
109 | else:
110 | from .stable_diffusion import StableDiffusionKDiffusionPipeline
111 |
112 | try:
113 | if not is_flax_available():
114 | raise OptionalDependencyNotAvailable()
115 | except OptionalDependencyNotAvailable:
116 | from ..utils.dummy_flax_objects import * # noqa F403
117 | else:
118 | from .pipeline_flax_utils import FlaxDiffusionPipeline
119 |
120 |
121 | try:
122 | if not (is_flax_available() and is_transformers_available()):
123 | raise OptionalDependencyNotAvailable()
124 | except OptionalDependencyNotAvailable:
125 | from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
126 | else:
127 | from .stable_diffusion import (
128 | FlaxStableDiffusionControlNetPipeline,
129 | FlaxStableDiffusionImg2ImgPipeline,
130 | FlaxStableDiffusionInpaintPipeline,
131 | FlaxStableDiffusionPipeline,
132 | )
133 | try:
134 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
135 | raise OptionalDependencyNotAvailable()
136 | except OptionalDependencyNotAvailable:
137 | from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
138 | else:
139 | from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
140 |
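Seen from the user side, the guarded imports above mean that which of these names resolve to real pipelines (rather than the dummy placeholder objects) depends on the optional dependencies installed, e.g.:

    from diffusers import DDPMPipeline             # torch alone is enough
    from diffusers import AudioDiffusionPipeline   # additionally requires librosa
    from diffusers import StableDiffusionPipeline  # additionally requires transformers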
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/alt_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 |
4 | import numpy as np
5 | import PIL
6 | from PIL import Image
7 |
8 | from ...utils import BaseOutput, is_torch_available, is_transformers_available
9 |
10 |
11 | @dataclass
12 | # Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
13 | class AltDiffusionPipelineOutput(BaseOutput):
14 | """
15 | Output class for Alt Diffusion pipelines.
16 |
17 | Args:
18 | images (`List[PIL.Image.Image]` or `np.ndarray`)
19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
20 | num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
21 | nsfw_content_detected (`List[bool]`)
22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
23 | (nsfw) content, or `None` if safety checking could not be performed.
24 | """
25 |
26 | images: Union[List[PIL.Image.Image], np.ndarray]
27 | nsfw_content_detected: Optional[List[bool]]
28 |
29 |
30 | if is_transformers_available() and is_torch_available():
31 | from .modeling_roberta_series import RobertaSeriesModelWithTransformation
32 | from .pipeline_alt_diffusion import AltDiffusionPipeline
33 | from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
34 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 | from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
7 | from transformers.utils import ModelOutput
8 |
9 |
10 | @dataclass
11 | class TransformationModelOutput(ModelOutput):
12 | """
13 | Base class for text model's outputs that also contains a pooling of the last hidden states.
14 |
15 | Args:
16 | projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`):
17 | The projected text embeddings obtained by applying the transformation layer to the last hidden state.
18 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
19 | Sequence of hidden-states at the output of the last layer of the model.
20 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
21 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
22 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
23 |
24 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
25 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
26 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
27 | sequence_length)`.
28 |
29 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
30 | heads.
31 | """
32 |
33 | projection_state: Optional[torch.FloatTensor] = None
34 | last_hidden_state: torch.FloatTensor = None
35 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
36 | attentions: Optional[Tuple[torch.FloatTensor]] = None
37 |
38 |
39 | class RobertaSeriesConfig(XLMRobertaConfig):
40 | def __init__(
41 | self,
42 | pad_token_id=1,
43 | bos_token_id=0,
44 | eos_token_id=2,
45 | project_dim=512,
46 | pooler_fn="cls",
47 | learn_encoder=False,
48 | use_attention_mask=True,
49 | **kwargs,
50 | ):
51 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
52 | self.project_dim = project_dim
53 | self.pooler_fn = pooler_fn
54 | self.learn_encoder = learn_encoder
55 | self.use_attention_mask = use_attention_mask
56 |
57 |
58 | class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
59 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
60 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
61 | base_model_prefix = "roberta"
62 | config_class = RobertaSeriesConfig
63 |
64 | def __init__(self, config):
65 | super().__init__(config)
66 | self.roberta = XLMRobertaModel(config)
67 | self.transformation = nn.Linear(config.hidden_size, config.project_dim)
68 | self.post_init()
69 |
70 | def forward(
71 | self,
72 | input_ids: Optional[torch.Tensor] = None,
73 | attention_mask: Optional[torch.Tensor] = None,
74 | token_type_ids: Optional[torch.Tensor] = None,
75 | position_ids: Optional[torch.Tensor] = None,
76 | head_mask: Optional[torch.Tensor] = None,
77 | inputs_embeds: Optional[torch.Tensor] = None,
78 | encoder_hidden_states: Optional[torch.Tensor] = None,
79 | encoder_attention_mask: Optional[torch.Tensor] = None,
80 | output_attentions: Optional[bool] = None,
81 | return_dict: Optional[bool] = None,
82 | output_hidden_states: Optional[bool] = None,
83 | ):
84 | r""" """
85 |
86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
87 |
88 | outputs = self.base_model(
89 | input_ids=input_ids,
90 | attention_mask=attention_mask,
91 | token_type_ids=token_type_ids,
92 | position_ids=position_ids,
93 | head_mask=head_mask,
94 | inputs_embeds=inputs_embeds,
95 | encoder_hidden_states=encoder_hidden_states,
96 | encoder_attention_mask=encoder_attention_mask,
97 | output_attentions=output_attentions,
98 | output_hidden_states=output_hidden_states,
99 | return_dict=return_dict,
100 | )
101 |
102 | projection_state = self.transformation(outputs.last_hidden_state)
103 |
104 | return TransformationModelOutput(
105 | projection_state=projection_state,
106 | last_hidden_state=outputs.last_hidden_state,
107 | hidden_states=outputs.hidden_states,
108 | attentions=outputs.attentions,
109 | )
110 |
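A minimal sketch of the transformation head with randomly initialized weights (the token ids below are placeholders, not a real tokenization):

    import torch

    config = RobertaSeriesConfig(project_dim=512)
    model = RobertaSeriesModelWithTransformation(config).eval()
    input_ids = torch.tensor([[0, 5, 6, 2]])  # placeholder ids within the default vocab
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    print(out.projection_state.shape)         # torch.Size([1, 4, 512])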
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/audio_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .mel import Mel
2 | from .pipeline_audio_diffusion import AudioDiffusionPipeline
3 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/audio_diffusion/mel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import numpy as np # noqa: E402
17 |
18 | from ...configuration_utils import ConfigMixin, register_to_config
19 | from ...schedulers.scheduling_utils import SchedulerMixin
20 |
21 |
22 | try:
23 | import librosa # noqa: E402
24 |
25 | _librosa_can_be_imported = True
26 | _import_error = ""
27 | except Exception as e:
28 | _librosa_can_be_imported = False
29 | _import_error = (
30 | f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it."
31 | )
32 |
33 |
34 | from PIL import Image # noqa: E402
35 |
36 |
37 | class Mel(ConfigMixin, SchedulerMixin):
38 | """
39 | Parameters:
40 | x_res (`int`): x resolution of spectrogram (time)
41 | y_res (`int`): y resolution of spectrogram (frequency bins)
42 | sample_rate (`int`): sample rate of audio
43 | n_fft (`int`): number of Fast Fourier Transforms
44 | hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
45 | top_db (`int`): loudest value in decibels (used as the spectrogram's dynamic range)
46 | n_iter (`int`): number of iterations for Griffin-Lim mel inversion
47 | """
48 |
49 | config_name = "mel_config.json"
50 |
51 | @register_to_config
52 | def __init__(
53 | self,
54 | x_res: int = 256,
55 | y_res: int = 256,
56 | sample_rate: int = 22050,
57 | n_fft: int = 2048,
58 | hop_length: int = 512,
59 | top_db: int = 80,
60 | n_iter: int = 32,
61 | ):
62 | self.hop_length = hop_length
63 | self.sr = sample_rate
64 | self.n_fft = n_fft
65 | self.top_db = top_db
66 | self.n_iter = n_iter
67 | self.set_resolution(x_res, y_res)
68 | self.audio = None
69 |
70 | if not _librosa_can_be_imported:
71 | raise ValueError(_import_error)
72 |
73 | def set_resolution(self, x_res: int, y_res: int):
74 | """Set resolution.
75 |
76 | Args:
77 | x_res (`int`): x resolution of spectrogram (time)
78 | y_res (`int`): y resolution of spectrogram (frequency bins)
79 | """
80 | self.x_res = x_res
81 | self.y_res = y_res
82 | self.n_mels = self.y_res
83 | self.slice_size = self.x_res * self.hop_length - 1
84 |
85 | def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
86 | """Load audio.
87 |
88 | Args:
89 | audio_file (`str`): must be a file on disk due to Librosa limitation or
90 | raw_audio (`np.ndarray`): audio as numpy array
91 | """
92 | if audio_file is not None:
93 | self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
94 | else:
95 | self.audio = raw_audio
96 |
97 | # Pad with silence if necessary.
98 | if len(self.audio) < self.x_res * self.hop_length:
99 | self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
100 |
101 | def get_number_of_slices(self) -> int:
102 | """Get number of slices in audio.
103 |
104 | Returns:
105 | `int`: number of spectrograms the audio can be sliced into
106 | """
107 | return len(self.audio) // self.slice_size
108 |
109 | def get_audio_slice(self, slice: int = 0) -> np.ndarray:
110 | """Get slice of audio.
111 |
112 | Args:
113 | slice (`int`): slice number of audio (out of get_number_of_slices())
114 |
115 | Returns:
116 | `np.ndarray`: audio as numpy array
117 | """
118 | return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
119 |
120 | def get_sample_rate(self) -> int:
121 | """Get sample rate:
122 |
123 | Returns:
124 | `int`: sample rate of audio
125 | """
126 | return self.sr
127 |
128 | def audio_slice_to_image(self, slice: int) -> Image.Image:
129 | """Convert slice of audio to spectrogram.
130 |
131 | Args:
132 | slice (`int`): slice number of audio to convert (out of get_number_of_slices())
133 |
134 | Returns:
135 | `PIL Image`: grayscale image of x_res x y_res
136 | """
137 | S = librosa.feature.melspectrogram(
138 | y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
139 | )
140 | log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
141 | bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
142 | image = Image.fromarray(bytedata)
143 | return image
144 |
145 | def image_to_audio(self, image: Image.Image) -> np.ndarray:
146 | """Converts spectrogram to audio.
147 |
148 | Args:
149 | image (`PIL Image`): x_res x y_res grayscale image
150 |
151 | Returns:
152 | audio (`np.ndarray`): raw audio
153 | """
154 | bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
155 | log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
156 | S = librosa.db_to_power(log_S)
157 | audio = librosa.feature.inverse.mel_to_audio(
158 | S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
159 | )
160 | return audio
161 |
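A short round-trip sketch with the class above (requires librosa; one second of silence stands in for real audio):

    import numpy as np

    mel = Mel(x_res=64, y_res=64)
    mel.load_audio(raw_audio=np.zeros(22050, dtype=np.float32))  # padded up to one slice internally
    image = mel.audio_slice_to_image(0)                          # 64x64 grayscale spectrogram
    audio = mel.image_to_audio(image)                            # Griffin-Lim reconstruction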
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/audioldm/__init__.py:
--------------------------------------------------------------------------------
1 | from ...utils import (
2 | OptionalDependencyNotAvailable,
3 | is_torch_available,
4 | is_transformers_available,
5 | is_transformers_version,
6 | )
7 |
8 |
9 | try:
10 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
11 | raise OptionalDependencyNotAvailable()
12 | except OptionalDependencyNotAvailable:
13 | from ...utils.dummy_torch_and_transformers_objects import (
14 | AudioLDMPipeline,
15 | )
16 | else:
17 | from .pipeline_audioldm import AudioLDMPipeline
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/dance_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_dance_diffusion import DanceDiffusionPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 |
20 | from ...utils import logging, randn_tensor
21 | from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
22 |
23 |
24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25 |
26 |
27 | class DanceDiffusionPipeline(DiffusionPipeline):
28 | r"""
29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
30 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
31 |
32 | Parameters:
33 | unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image.
34 | scheduler ([`SchedulerMixin`]):
35 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
36 | [`IPNDMScheduler`].
37 | """
38 |
39 | def __init__(self, unet, scheduler):
40 | super().__init__()
41 | self.register_modules(unet=unet, scheduler=scheduler)
42 |
43 | @torch.no_grad()
44 | def __call__(
45 | self,
46 | batch_size: int = 1,
47 | num_inference_steps: int = 100,
48 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
49 | audio_length_in_s: Optional[float] = None,
50 | return_dict: bool = True,
51 | ) -> Union[AudioPipelineOutput, Tuple]:
52 | r"""
53 | Args:
54 | batch_size (`int`, *optional*, defaults to 1):
55 | The number of audio samples to generate.
56 | num_inference_steps (`int`, *optional*, defaults to 100):
57 | The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
58 | the expense of slower inference.
59 | generator (`torch.Generator`, *optional*):
60 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
61 | to make generation deterministic.
62 | audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
63 | The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
64 | `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
65 | return_dict (`bool`, *optional*, defaults to `True`):
66 | Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
67 |
68 | Returns:
69 | [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is
70 | True, otherwise a `tuple`. When returning a tuple, the first element is the generated audio.
71 | """
72 |
73 | if audio_length_in_s is None:
74 | audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
75 |
76 | sample_size = audio_length_in_s * self.unet.config.sample_rate
77 |
78 | down_scale_factor = 2 ** len(self.unet.up_blocks)
79 | if sample_size < 3 * down_scale_factor:
80 | raise ValueError(
81 | f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
82 | f" {3 * down_scale_factor / self.unet.config.sample_rate}."
83 | )
84 |
85 | original_sample_size = int(sample_size)
86 | if sample_size % down_scale_factor != 0:
87 | sample_size = (
88 | (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
89 | ) * down_scale_factor
90 | logger.info(
91 | f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
92 | f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
93 | " process."
94 | )
95 | sample_size = int(sample_size)
96 |
97 | dtype = next(iter(self.unet.parameters())).dtype
98 | shape = (batch_size, self.unet.config.in_channels, sample_size)
99 | if isinstance(generator, list) and len(generator) != batch_size:
100 | raise ValueError(
101 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
102 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
103 | )
104 |
105 | audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
106 |
107 | # set step values
108 | self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
109 | self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)
110 |
111 | for t in self.progress_bar(self.scheduler.timesteps):
112 | # 1. predict noise model_output
113 | model_output = self.unet(audio, t).sample
114 |
115 | # 2. compute previous sample: x_t -> x_t-1
116 | audio = self.scheduler.step(model_output, t, audio).prev_sample
117 |
118 | audio = audio.clamp(-1, 1).float().cpu().numpy()
119 |
120 | audio = audio[:, :, :original_sample_size]
121 |
122 | if not return_dict:
123 | return (audio,)
124 |
125 | return AudioPipelineOutput(audios=audio)
126 |
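A usage sketch for the pipeline above; the checkpoint id is an assumption (any Dance Diffusion checkpoint on the Hub works the same way):

    import torch
    from diffusers import DanceDiffusionPipeline

    pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")  # assumed checkpoint id
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    output = pipe(num_inference_steps=100, audio_length_in_s=4.0)
    audio = output.audios[0]  # numpy array of shape (channels, samples)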
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/ddim/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_ddim import DDIMPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/ddim/pipeline_ddim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Optional, Tuple, Union
16 |
17 | import torch
18 |
19 | from ...schedulers import DDIMScheduler
20 | from ...utils import randn_tensor
21 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
22 |
23 |
24 | class DDIMPipeline(DiffusionPipeline):
25 | r"""
26 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
27 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
28 |
29 | Parameters:
30 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
31 | scheduler ([`SchedulerMixin`]):
32 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
33 | [`DDPMScheduler`], or [`DDIMScheduler`].
34 | """
35 |
36 | def __init__(self, unet, scheduler):
37 | super().__init__()
38 |
39 | # make sure scheduler can always be converted to DDIM
40 | scheduler = DDIMScheduler.from_config(scheduler.config)
41 |
42 | self.register_modules(unet=unet, scheduler=scheduler)
43 |
44 | @torch.no_grad()
45 | def __call__(
46 | self,
47 | batch_size: int = 1,
48 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
49 | eta: float = 0.0,
50 | num_inference_steps: int = 50,
51 | use_clipped_model_output: Optional[bool] = None,
52 | output_type: Optional[str] = "pil",
53 | return_dict: bool = True,
54 | ) -> Union[ImagePipelineOutput, Tuple]:
55 | r"""
56 | Args:
57 | batch_size (`int`, *optional*, defaults to 1):
58 | The number of images to generate.
59 | generator (`torch.Generator`, *optional*):
60 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
61 | to make generation deterministic.
62 | eta (`float`, *optional*, defaults to 0.0):
63 | The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
64 | num_inference_steps (`int`, *optional*, defaults to 50):
65 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
66 | expense of slower inference.
67 | use_clipped_model_output (`bool`, *optional*, defaults to `None`):
68 | if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
69 | downstream to the scheduler. So use `None` for schedulers which don't support this argument.
70 | output_type (`str`, *optional*, defaults to `"pil"`):
71 | The output format of the generated image. Choose between
72 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
73 | return_dict (`bool`, *optional*, defaults to `True`):
74 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
75 |
76 | Returns:
77 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
78 | True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
79 | """
80 |
81 | # Sample gaussian noise to begin loop
82 | if isinstance(self.unet.config.sample_size, int):
83 | image_shape = (
84 | batch_size,
85 | self.unet.config.in_channels,
86 | self.unet.config.sample_size,
87 | self.unet.config.sample_size,
88 | )
89 | else:
90 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
91 |
92 | if isinstance(generator, list) and len(generator) != batch_size:
93 | raise ValueError(
94 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
95 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
96 | )
97 |
98 | image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
99 |
100 | # set step values
101 | self.scheduler.set_timesteps(num_inference_steps)
102 |
103 | for t in self.progress_bar(self.scheduler.timesteps):
104 | # 1. predict noise model_output
105 | model_output = self.unet(image, t).sample
106 |
107 | # 2. predict previous mean of image x_t-1 and add variance depending on eta
108 | # eta corresponds to η in paper and should be between [0, 1]
109 | # do x_t -> x_t-1
110 | image = self.scheduler.step(
111 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
112 | ).prev_sample
113 |
114 | image = (image / 2 + 0.5).clamp(0, 1)
115 | image = image.cpu().permute(0, 2, 3, 1).numpy()
116 | if output_type == "pil":
117 | image = self.numpy_to_pil(image)
118 |
119 | if not return_dict:
120 | return (image,)
121 |
122 | return ImagePipelineOutput(images=image)
123 |
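A usage sketch for DDIMPipeline; the checkpoint id is an assumption (any UNet2DModel checkpoint trained for DDPM/DDIM sampling works, since the constructor converts the scheduler config to DDIM):

    from diffusers import DDIMPipeline

    pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")  # assumed checkpoint id
    result = pipe(batch_size=4, num_inference_steps=50, eta=0.0)   # eta=0.0 gives deterministic DDIM
    result.images[0].save("ddim_sample.png")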
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/ddpm/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_ddpm import DDPMPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/ddpm/pipeline_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 |
20 | from ...utils import randn_tensor
21 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
22 |
23 |
24 | class DDPMPipeline(DiffusionPipeline):
25 | r"""
26 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
27 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
28 |
29 | Parameters:
30 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
31 | scheduler ([`SchedulerMixin`]):
32 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
33 | [`DDPMScheduler`], or [`DDIMScheduler`].
34 | """
35 |
36 | def __init__(self, unet, scheduler):
37 | super().__init__()
38 | self.register_modules(unet=unet, scheduler=scheduler)
39 |
40 | @torch.no_grad()
41 | def __call__(
42 | self,
43 | batch_size: int = 1,
44 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
45 | num_inference_steps: int = 1000,
46 | output_type: Optional[str] = "pil",
47 | return_dict: bool = True,
48 | ) -> Union[ImagePipelineOutput, Tuple]:
49 | r"""
50 | Args:
51 | batch_size (`int`, *optional*, defaults to 1):
52 | The number of images to generate.
53 | generator (`torch.Generator`, *optional*):
54 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
55 | to make generation deterministic.
56 | num_inference_steps (`int`, *optional*, defaults to 1000):
57 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
58 | expense of slower inference.
59 | output_type (`str`, *optional*, defaults to `"pil"`):
60 | The output format of the generated image. Choose between
61 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
62 | return_dict (`bool`, *optional*, defaults to `True`):
63 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
64 |
65 | Returns:
66 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
67 | True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
68 | """
69 | # Sample gaussian noise to begin loop
70 | if isinstance(self.unet.config.sample_size, int):
71 | image_shape = (
72 | batch_size,
73 | self.unet.config.in_channels,
74 | self.unet.config.sample_size,
75 | self.unet.config.sample_size,
76 | )
77 | else:
78 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
79 |
80 | if self.device.type == "mps":
81 | # randn does not work reproducibly on mps
82 | image = randn_tensor(image_shape, generator=generator)
83 | image = image.to(self.device)
84 | else:
85 | image = randn_tensor(image_shape, generator=generator, device=self.device)
86 |
87 | # set step values
88 | self.scheduler.set_timesteps(num_inference_steps)
89 |
90 | for t in self.progress_bar(self.scheduler.timesteps):
91 | # 1. predict noise model_output
92 | model_output = self.unet(image, t).sample
93 |
94 | # 2. compute previous image: x_t -> x_t-1
95 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
96 |
97 | image = (image / 2 + 0.5).clamp(0, 1)
98 | image = image.cpu().permute(0, 2, 3, 1).numpy()
99 | if output_type == "pil":
100 | image = self.numpy_to_pil(image)
101 |
102 | if not return_dict:
103 | return (image,)
104 |
105 | return ImagePipelineOutput(images=image)
106 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/dit/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_dit import DiTPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/latent_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from ...utils import is_transformers_available
2 | from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
3 |
4 |
5 | if is_transformers_available():
6 | from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
7 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/latent_diffusion_uncond/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_latent_diffusion_uncond import LDMPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 |
20 | from ...models import UNet2DModel, VQModel
21 | from ...schedulers import DDIMScheduler
22 | from ...utils import randn_tensor
23 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24 |
25 |
26 | class LDMPipeline(DiffusionPipeline):
27 | r"""
28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
29 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
30 |
31 | Parameters:
32 | vqvae ([`VQModel`]):
33 | Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
34 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
35 | scheduler ([`SchedulerMixin`]):
36 | [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
37 | """
38 |
39 | def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
40 | super().__init__()
41 | self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
42 |
43 | @torch.no_grad()
44 | def __call__(
45 | self,
46 | batch_size: int = 1,
47 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
48 | eta: float = 0.0,
49 | num_inference_steps: int = 50,
50 | output_type: Optional[str] = "pil",
51 | return_dict: bool = True,
52 | **kwargs,
53 | ) -> Union[Tuple, ImagePipelineOutput]:
54 | r"""
55 | Args:
56 | batch_size (`int`, *optional*, defaults to 1):
57 | Number of images to generate.
58 | generator (`torch.Generator`, *optional*):
59 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
60 | to make generation deterministic.
61 | num_inference_steps (`int`, *optional*, defaults to 50):
62 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
63 | expense of slower inference.
64 | output_type (`str`, *optional*, defaults to `"pil"`):
65 | The output format of the generated image. Choose between
66 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
67 | return_dict (`bool`, *optional*, defaults to `True`):
68 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
69 |
70 | Returns:
71 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
72 | True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
73 | """
74 |
75 | latents = randn_tensor(
76 | (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
77 | generator=generator,
78 | )
79 | latents = latents.to(self.device)
80 |
81 | # scale the initial noise by the standard deviation required by the scheduler
82 | latents = latents * self.scheduler.init_noise_sigma
83 |
84 | self.scheduler.set_timesteps(num_inference_steps)
85 |
86 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
87 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
88 |
89 | extra_kwargs = {}
90 | if accepts_eta:
91 | extra_kwargs["eta"] = eta
92 |
93 | for t in self.progress_bar(self.scheduler.timesteps):
94 | latent_model_input = self.scheduler.scale_model_input(latents, t)
95 | # predict the noise residual
96 | noise_prediction = self.unet(latent_model_input, t).sample
97 | # compute the previous noisy sample x_t -> x_t-1
98 | latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
99 |
100 | # decode the image latents with the VAE
101 | image = self.vqvae.decode(latents).sample
102 |
103 | image = (image / 2 + 0.5).clamp(0, 1)
104 | image = image.cpu().permute(0, 2, 3, 1).numpy()
105 | if output_type == "pil":
106 | image = self.numpy_to_pil(image)
107 |
108 | if not return_dict:
109 | return (image,)
110 |
111 | return ImagePipelineOutput(images=image)
112 |
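A minimal usage sketch for LDMPipeline above (illustrative, not part of the library file): it assumes an unconditional latent-diffusion checkpoint with vqvae/unet/scheduler components; "CompVis/ldm-celebahq-256" is used here only as an example model id.

import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint id; any unconditional LDM checkpoint with vqvae/unet/scheduler works.
pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Fewer steps trade quality for speed; eta > 0 makes the DDIM steps stochastic.
image = pipe(batch_size=1, num_inference_steps=50, eta=0.0).images[0]
image.save("ldm_generated.png")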
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/paint_by_example/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 |
4 | import numpy as np
5 | import PIL
6 | from PIL import Image
7 |
8 | from ...utils import is_torch_available, is_transformers_available
9 |
10 |
11 | if is_transformers_available() and is_torch_available():
12 | from .image_encoder import PaintByExampleImageEncoder
13 | from .pipeline_paint_by_example import PaintByExamplePipeline
14 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/paint_by_example/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from torch import nn
16 | from transformers import CLIPPreTrainedModel, CLIPVisionModel
17 |
18 | from ...models.attention import BasicTransformerBlock
19 | from ...utils import logging
20 |
21 |
22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23 |
24 |
25 | class PaintByExampleImageEncoder(CLIPPreTrainedModel):
26 | def __init__(self, config, proj_size=768):
27 | super().__init__(config)
28 | self.proj_size = proj_size
29 |
30 | self.model = CLIPVisionModel(config)
31 | self.mapper = PaintByExampleMapper(config)
32 | self.final_layer_norm = nn.LayerNorm(config.hidden_size)
33 | self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
34 |
35 | # uncondition for scaling
36 | self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
37 |
38 | def forward(self, pixel_values, return_uncond_vector=False):
39 | clip_output = self.model(pixel_values=pixel_values)
40 | latent_states = clip_output.pooler_output
41 | latent_states = self.mapper(latent_states[:, None])
42 | latent_states = self.final_layer_norm(latent_states)
43 | latent_states = self.proj_out(latent_states)
44 | if return_uncond_vector:
45 | return latent_states, self.uncond_vector
46 |
47 | return latent_states
48 |
49 |
50 | class PaintByExampleMapper(nn.Module):
51 | def __init__(self, config):
52 | super().__init__()
53 | num_layers = (config.num_hidden_layers + 1) // 5
54 | hid_size = config.hidden_size
55 | num_heads = 1
56 | self.blocks = nn.ModuleList(
57 | [
58 | BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)
59 | for _ in range(num_layers)
60 | ]
61 | )
62 |
63 | def forward(self, hidden_states):
64 | for block in self.blocks:
65 | hidden_states = block(hidden_states)
66 |
67 | return hidden_states
68 |
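A shape-check sketch for the encoder above (not from the library): it builds the module from a default `CLIPVisionConfig` with random weights, so the outputs are meaningless, but the tensor shapes illustrate the pooler -> mapper -> projection path.

import torch
from transformers import CLIPVisionConfig
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder

# Hypothetical, randomly initialized instance; real checkpoints ship their own config and weights.
config = CLIPVisionConfig()                      # 224x224 inputs, hidden_size=768 by default
encoder = PaintByExampleImageEncoder(config)

pixel_values = torch.randn(1, 3, 224, 224)       # one example (reference) image
cond, uncond = encoder(pixel_values, return_uncond_vector=True)
print(cond.shape, uncond.shape)                  # torch.Size([1, 1, 768]) twice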
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/pndm/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_pndm import PNDMPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/pndm/pipeline_pndm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 |
20 | from ...models import UNet2DModel
21 | from ...schedulers import PNDMScheduler
22 | from ...utils import randn_tensor
23 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24 |
25 |
26 | class PNDMPipeline(DiffusionPipeline):
27 | r"""
28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
29 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
30 |
31 | Parameters:
32 | unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
33 | scheduler ([`SchedulerMixin`]):
34 | The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
35 | """
36 |
37 | unet: UNet2DModel
38 | scheduler: PNDMScheduler
39 |
40 | def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
41 | super().__init__()
42 |
43 | scheduler = PNDMScheduler.from_config(scheduler.config)
44 |
45 | self.register_modules(unet=unet, scheduler=scheduler)
46 |
47 | @torch.no_grad()
48 | def __call__(
49 | self,
50 | batch_size: int = 1,
51 | num_inference_steps: int = 50,
52 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53 | output_type: Optional[str] = "pil",
54 | return_dict: bool = True,
55 | **kwargs,
56 | ) -> Union[ImagePipelineOutput, Tuple]:
57 | r"""
58 | Args:
59 | batch_size (`int`, `optional`, defaults to 1): The number of images to generate.
60 | num_inference_steps (`int`, `optional`, defaults to 50):
61 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
62 | expense of slower inference.
63 | generator (`torch.Generator`, `optional`): A [torch
64 | generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
65 | deterministic.
 66 |             output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generated image. Choose
67 | between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
68 | return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
69 | [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
70 |
71 | Returns:
72 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
 73 |             True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
74 | """
75 | # For more information on the sampling method you can take a look at Algorithm 2 of
76 | # the official paper: https://arxiv.org/pdf/2202.09778.pdf
77 |
78 | # Sample gaussian noise to begin loop
79 | image = randn_tensor(
80 | (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
81 | generator=generator,
82 | device=self.device,
83 | )
84 |
85 | self.scheduler.set_timesteps(num_inference_steps)
86 | for t in self.progress_bar(self.scheduler.timesteps):
87 | model_output = self.unet(image, t).sample
88 |
89 | image = self.scheduler.step(model_output, t, image).prev_sample
90 |
91 | image = (image / 2 + 0.5).clamp(0, 1)
92 | image = image.cpu().permute(0, 2, 3, 1).numpy()
93 | if output_type == "pil":
94 | image = self.numpy_to_pil(image)
95 |
96 | if not return_dict:
97 | return (image,)
98 |
99 | return ImagePipelineOutput(images=image)
100 |
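An illustrative sampling sketch for PNDMPipeline above (not part of the file). Because `__init__` rebuilds the scheduler via `PNDMScheduler.from_config`, the pipeline can be pointed at a DDPM-style unconditional checkpoint; "google/ddpm-cifar10-32" is an assumed example id.

from diffusers import PNDMPipeline

# Assumed unconditional checkpoint; its DDPM scheduler config is converted to a
# PNDMScheduler inside PNDMPipeline.__init__.
pipe = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")
image = pipe(batch_size=1, num_inference_steps=50).images[0]
image.save("pndm_sample.png")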
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/repaint/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_repaint import RePaintPipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/score_sde_ve/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_score_sde_ve import ScoreSdeVePipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/semantic_stable_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from typing import List, Optional, Union
4 |
5 | import numpy as np
6 | import PIL
7 | from PIL import Image
8 |
9 | from ...utils import BaseOutput, is_torch_available, is_transformers_available
10 |
11 |
12 | @dataclass
13 | class SemanticStableDiffusionPipelineOutput(BaseOutput):
14 | """
15 | Output class for Stable Diffusion pipelines.
16 |
17 | Args:
18 | images (`List[PIL.Image.Image]` or `np.ndarray`)
19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
 20 |             num_channels)`. PIL images or NumPy arrays representing the denoised images of the diffusion pipeline.
21 | nsfw_content_detected (`List[bool]`)
22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
23 | (nsfw) content, or `None` if safety checking could not be performed.
24 | """
25 |
26 | images: Union[List[PIL.Image.Image], np.ndarray]
27 | nsfw_content_detected: Optional[List[bool]]
28 |
29 |
30 | if is_transformers_available() and is_torch_available():
31 | from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
32 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/spectrogram_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from ...utils import is_note_seq_available, is_transformers_available, is_torch_available
3 | from ...utils import OptionalDependencyNotAvailable
4 |
5 |
6 | try:
7 | if not (is_transformers_available() and is_torch_available()):
8 | raise OptionalDependencyNotAvailable()
9 | except OptionalDependencyNotAvailable:
10 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
11 | else:
12 | from .notes_encoder import SpectrogramNotesEncoder
13 | from .continous_encoder import SpectrogramContEncoder
14 | from .pipeline_spectrogram_diffusion import (
15 | SpectrogramContEncoder,
16 | SpectrogramDiffusionPipeline,
17 | T5FilmDecoder,
18 | )
19 |
20 | try:
21 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
22 | raise OptionalDependencyNotAvailable()
23 | except OptionalDependencyNotAvailable:
24 | from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
25 | else:
26 | from .midi_utils import MidiProcessor
27 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Music Spectrogram Diffusion Authors.
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn as nn
18 | from transformers.modeling_utils import ModuleUtilsMixin
19 | from transformers.models.t5.modeling_t5 import (
20 | T5Block,
21 | T5Config,
22 | T5LayerNorm,
23 | )
24 |
25 | from ...configuration_utils import ConfigMixin, register_to_config
26 | from ...models import ModelMixin
27 |
28 |
29 | class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
30 | @register_to_config
31 | def __init__(
32 | self,
33 | input_dims: int,
34 | targets_context_length: int,
35 | d_model: int,
36 | dropout_rate: float,
37 | num_layers: int,
38 | num_heads: int,
39 | d_kv: int,
40 | d_ff: int,
41 | feed_forward_proj: str,
42 | is_decoder: bool = False,
43 | ):
44 | super().__init__()
45 |
46 | self.input_proj = nn.Linear(input_dims, d_model, bias=False)
47 |
48 | self.position_encoding = nn.Embedding(targets_context_length, d_model)
49 | self.position_encoding.weight.requires_grad = False
50 |
51 | self.dropout_pre = nn.Dropout(p=dropout_rate)
52 |
53 | t5config = T5Config(
54 | d_model=d_model,
55 | num_heads=num_heads,
56 | d_kv=d_kv,
57 | d_ff=d_ff,
58 | feed_forward_proj=feed_forward_proj,
59 | dropout_rate=dropout_rate,
60 | is_decoder=is_decoder,
61 | is_encoder_decoder=False,
62 | )
63 | self.encoders = nn.ModuleList()
64 | for lyr_num in range(num_layers):
65 | lyr = T5Block(t5config)
66 | self.encoders.append(lyr)
67 |
68 | self.layer_norm = T5LayerNorm(d_model)
69 | self.dropout_post = nn.Dropout(p=dropout_rate)
70 |
71 | def forward(self, encoder_inputs, encoder_inputs_mask):
72 | x = self.input_proj(encoder_inputs)
73 |
74 | # terminal relative positional encodings
75 | max_positions = encoder_inputs.shape[1]
76 | input_positions = torch.arange(max_positions, device=encoder_inputs.device)
77 |
78 | seq_lens = encoder_inputs_mask.sum(-1)
79 | input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0)
80 | x += self.position_encoding(input_positions)
81 |
82 | x = self.dropout_pre(x)
83 |
 84 |         # invert the attention mask
85 | input_shape = encoder_inputs.size()
86 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)
87 |
88 | for lyr in self.encoders:
89 | x = lyr(x, extended_attention_mask)[0]
90 | x = self.layer_norm(x)
91 |
92 | return self.dropout_post(x), encoder_inputs_mask
93 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Music Spectrogram Diffusion Authors.
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn as nn
18 | from transformers.modeling_utils import ModuleUtilsMixin
19 | from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm
20 |
21 | from ...configuration_utils import ConfigMixin, register_to_config
22 | from ...models import ModelMixin
23 |
24 |
25 | class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
26 | @register_to_config
27 | def __init__(
28 | self,
29 | max_length: int,
30 | vocab_size: int,
31 | d_model: int,
32 | dropout_rate: float,
33 | num_layers: int,
34 | num_heads: int,
35 | d_kv: int,
36 | d_ff: int,
37 | feed_forward_proj: str,
38 | is_decoder: bool = False,
39 | ):
40 | super().__init__()
41 |
42 | self.token_embedder = nn.Embedding(vocab_size, d_model)
43 |
44 | self.position_encoding = nn.Embedding(max_length, d_model)
45 | self.position_encoding.weight.requires_grad = False
46 |
47 | self.dropout_pre = nn.Dropout(p=dropout_rate)
48 |
49 | t5config = T5Config(
50 | vocab_size=vocab_size,
51 | d_model=d_model,
52 | num_heads=num_heads,
53 | d_kv=d_kv,
54 | d_ff=d_ff,
55 | dropout_rate=dropout_rate,
56 | feed_forward_proj=feed_forward_proj,
57 | is_decoder=is_decoder,
58 | is_encoder_decoder=False,
59 | )
60 |
61 | self.encoders = nn.ModuleList()
62 | for lyr_num in range(num_layers):
63 | lyr = T5Block(t5config)
64 | self.encoders.append(lyr)
65 |
66 | self.layer_norm = T5LayerNorm(d_model)
67 | self.dropout_post = nn.Dropout(p=dropout_rate)
68 |
69 | def forward(self, encoder_input_tokens, encoder_inputs_mask):
70 | x = self.token_embedder(encoder_input_tokens)
71 |
72 | seq_length = encoder_input_tokens.shape[1]
73 | inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device)
74 | x += self.position_encoding(inputs_positions)
75 |
76 | x = self.dropout_pre(x)
77 |
 78 |         # invert the attention mask
79 | input_shape = encoder_input_tokens.size()
80 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)
81 |
82 | for lyr in self.encoders:
83 | x = lyr(x, extended_attention_mask)[0]
84 | x = self.layer_norm(x)
85 |
86 | return self.dropout_post(x), encoder_inputs_mask
87 |
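A shape-only sketch of the notes encoder above (not from the library): the hyperparameters below are small, made-up values for a quick forward pass with random weights, not the ones used by the released spectrogram-diffusion checkpoint.

import torch
from diffusers.pipelines.spectrogram_diffusion.notes_encoder import SpectrogramNotesEncoder

# Small illustrative hyperparameters; real checkpoints define their own config.
encoder = SpectrogramNotesEncoder(
    max_length=64, vocab_size=1536, d_model=64, dropout_rate=0.1,
    num_layers=2, num_heads=4, d_kv=16, d_ff=128, feed_forward_proj="gated-gelu",
)
tokens = torch.randint(0, 1536, (1, 64))          # dummy note tokens
mask = torch.ones(1, 64, dtype=torch.long)        # all positions attended
hidden, out_mask = encoder(tokens, mask)
print(hidden.shape)                               # torch.Size([1, 64, 64])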
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stable_diffusion/safety_checker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
19 |
20 | from ...utils import logging
21 |
22 |
23 | logger = logging.get_logger(__name__)
24 |
25 |
26 | def cosine_distance(image_embeds, text_embeds):
27 | normalized_image_embeds = nn.functional.normalize(image_embeds)
28 | normalized_text_embeds = nn.functional.normalize(text_embeds)
29 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
30 |
31 |
32 | class StableDiffusionSafetyChecker(PreTrainedModel):
33 | config_class = CLIPConfig
34 |
35 | _no_split_modules = ["CLIPEncoderLayer"]
36 |
37 | def __init__(self, config: CLIPConfig):
38 | super().__init__(config)
39 |
40 | self.vision_model = CLIPVisionModel(config.vision_config)
41 | self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
42 |
43 | self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
44 | self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
45 |
46 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
47 | self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
48 |
49 | @torch.no_grad()
50 | def forward(self, clip_input, images):
51 | pooled_output = self.vision_model(clip_input)[1] # pooled_output
52 | image_embeds = self.visual_projection(pooled_output)
53 |
54 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
55 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
56 | cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
57 |
58 | result = []
59 | batch_size = image_embeds.shape[0]
60 | for i in range(batch_size):
61 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
62 |
 63 |             # increase this value to create a stronger `nsfw` filter
64 | # at the cost of increasing the possibility of filtering benign images
65 | adjustment = 0.0
66 |
67 | for concept_idx in range(len(special_cos_dist[0])):
68 | concept_cos = special_cos_dist[i][concept_idx]
69 | concept_threshold = self.special_care_embeds_weights[concept_idx].item()
70 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
71 | if result_img["special_scores"][concept_idx] > 0:
72 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
73 | adjustment = 0.01
74 |
75 | for concept_idx in range(len(cos_dist[0])):
76 | concept_cos = cos_dist[i][concept_idx]
77 | concept_threshold = self.concept_embeds_weights[concept_idx].item()
78 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
79 | if result_img["concept_scores"][concept_idx] > 0:
80 | result_img["bad_concepts"].append(concept_idx)
81 |
82 | result.append(result_img)
83 |
84 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
85 |
86 | for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
87 | if has_nsfw_concept:
88 | images[idx] = np.zeros(images[idx].shape) # black image
89 |
90 | if any(has_nsfw_concepts):
91 | logger.warning(
92 | "Potential NSFW content was detected in one or more images. A black image will be returned instead."
93 | " Try again with a different prompt and/or seed."
94 | )
95 |
96 | return images, has_nsfw_concepts
97 |
98 | @torch.no_grad()
99 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
100 | pooled_output = self.vision_model(clip_input)[1] # pooled_output
101 | image_embeds = self.visual_projection(pooled_output)
102 |
103 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
104 | cos_dist = cosine_distance(image_embeds, self.concept_embeds)
105 |
106 | # increase this value to create a stronger `nsfw` filter
107 | # at the cost of increasing the possibility of filtering benign images
108 | adjustment = 0.0
109 |
110 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
111 | # special_scores = special_scores.round(decimals=3)
112 | special_care = torch.any(special_scores > 0, dim=1)
113 | special_adjustment = special_care * 0.01
114 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
115 |
116 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
117 | # concept_scores = concept_scores.round(decimals=3)
118 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
119 |
120 | images[has_nsfw_concepts] = 0.0 # black image
121 |
122 | return images, has_nsfw_concepts
123 |
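A self-contained sketch of how the checker above is typically wired (illustrative only, not from the library): it uses a randomly initialized checker built from a default `CLIPConfig`, so the scores are meaningless; real pipelines load pretrained safety-checker weights shipped alongside the Stable Diffusion checkpoint.

import numpy as np
import torch
from PIL import Image
from transformers import CLIPConfig, CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Randomly initialized checker, for wiring illustration only; do not use for real filtering.
checker = StableDiffusionSafetyChecker(CLIPConfig())
processor = CLIPImageProcessor()                       # default CLIP preprocessing (224x224)

images = np.random.rand(1, 64, 64, 3).astype(np.float32)                     # NHWC floats in [0, 1]
pil_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
clip_input = processor(images=pil_images, return_tensors="pt").pixel_values  # CLIP pixel values

checked_images, has_nsfw = checker(images=images, clip_input=clip_input)
print(has_nsfw)                                        # list of booleans, one per image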
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stable_diffusion/safety_checker_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional, Tuple
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | from flax import linen as nn
20 | from flax.core.frozen_dict import FrozenDict
21 | from transformers import CLIPConfig, FlaxPreTrainedModel
22 | from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
23 |
24 |
25 | def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
26 | norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
27 | norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
28 | return jnp.matmul(norm_emb_1, norm_emb_2.T)
29 |
30 |
31 | class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
32 | config: CLIPConfig
33 | dtype: jnp.dtype = jnp.float32
34 |
35 | def setup(self):
36 | self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
37 | self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
38 |
39 | self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
40 | self.special_care_embeds = self.param(
41 | "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
42 | )
43 |
44 | self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
45 | self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))
46 |
47 | def __call__(self, clip_input):
48 | pooled_output = self.vision_model(clip_input)[1]
49 | image_embeds = self.visual_projection(pooled_output)
50 |
51 | special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
52 | cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
53 |
 54 |         # increase this value to create a stronger `nsfw` filter
55 | # at the cost of increasing the possibility of filtering benign image inputs
56 | adjustment = 0.0
57 |
58 | special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
59 | special_scores = jnp.round(special_scores, 3)
60 | is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
61 | # Use a lower threshold if an image has any special care concept
62 | special_adjustment = is_special_care * 0.01
63 |
64 | concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
65 | concept_scores = jnp.round(concept_scores, 3)
66 | has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
67 |
68 | return has_nsfw_concepts
69 |
70 |
71 | class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
72 | config_class = CLIPConfig
73 | main_input_name = "clip_input"
74 | module_class = FlaxStableDiffusionSafetyCheckerModule
75 |
76 | def __init__(
77 | self,
78 | config: CLIPConfig,
79 | input_shape: Optional[Tuple] = None,
80 | seed: int = 0,
81 | dtype: jnp.dtype = jnp.float32,
82 | _do_init: bool = True,
83 | **kwargs,
84 | ):
85 | if input_shape is None:
86 | input_shape = (1, 224, 224, 3)
87 | module = self.module_class(config=config, dtype=dtype, **kwargs)
88 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
89 |
90 | def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
91 | # init input tensor
92 | clip_input = jax.random.normal(rng, input_shape)
93 |
94 | params_rng, dropout_rng = jax.random.split(rng)
95 | rngs = {"params": params_rng, "dropout": dropout_rng}
96 |
97 | random_params = self.module.init(rngs, clip_input)["params"]
98 |
99 | return random_params
100 |
101 | def __call__(
102 | self,
103 | clip_input,
104 | params: dict = None,
105 | ):
106 | clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
107 |
108 | return self.module.apply(
109 | {"params": params or self.params},
110 | jnp.array(clip_input, dtype=jnp.float32),
111 | rngs={},
112 | )
113 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional, Union
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from ...configuration_utils import ConfigMixin, register_to_config
21 | from ...models.modeling_utils import ModelMixin
22 |
23 |
24 | class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
25 | """
26 | This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP.
27 |
28 | It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image
29 | embeddings.
30 | """
31 |
32 | @register_to_config
33 | def __init__(
34 | self,
35 | embedding_dim: int = 768,
36 | ):
37 | super().__init__()
38 |
39 | self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
40 | self.std = nn.Parameter(torch.ones(1, embedding_dim))
41 |
42 | def to(
43 | self,
44 | torch_device: Optional[Union[str, torch.device]] = None,
45 | torch_dtype: Optional[torch.dtype] = None,
46 | ):
47 | self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
48 | self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
49 | return self
50 |
51 | def scale(self, embeds):
52 | embeds = (embeds - self.mean) * 1.0 / self.std
53 | return embeds
54 |
55 | def unscale(self, embeds):
56 | embeds = (embeds * self.std) + self.mean
57 | return embeds
58 |
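A tiny round-trip sketch for the normalizer above (illustrative only; with its default zero mean and unit std, `scale` and `unscale` are identity transforms until pretrained statistics are loaded).

import torch
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

normalizer = StableUnCLIPImageNormalizer(embedding_dim=768)
embeds = torch.randn(2, 768)                       # dummy CLIP image embeddings

scaled = normalizer.scale(embeds)                  # (embeds - mean) / std
restored = normalizer.unscale(scaled)              # scaled * std + mean
print(torch.allclose(restored, embeds, atol=1e-6)) # True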
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stable_diffusion_safe/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from typing import List, Optional, Union
4 |
5 | import numpy as np
6 | import PIL
7 | from PIL import Image
8 |
9 | from ...utils import BaseOutput, is_torch_available, is_transformers_available
10 |
11 |
12 | @dataclass
13 | class SafetyConfig(object):
14 | WEAK = {
15 | "sld_warmup_steps": 15,
16 | "sld_guidance_scale": 20,
17 | "sld_threshold": 0.0,
18 | "sld_momentum_scale": 0.0,
19 | "sld_mom_beta": 0.0,
20 | }
21 | MEDIUM = {
22 | "sld_warmup_steps": 10,
23 | "sld_guidance_scale": 1000,
24 | "sld_threshold": 0.01,
25 | "sld_momentum_scale": 0.3,
26 | "sld_mom_beta": 0.4,
27 | }
28 | STRONG = {
29 | "sld_warmup_steps": 7,
30 | "sld_guidance_scale": 2000,
31 | "sld_threshold": 0.025,
32 | "sld_momentum_scale": 0.5,
33 | "sld_mom_beta": 0.7,
34 | }
35 | MAX = {
36 | "sld_warmup_steps": 0,
37 | "sld_guidance_scale": 5000,
38 | "sld_threshold": 1.0,
39 | "sld_momentum_scale": 0.5,
40 | "sld_mom_beta": 0.7,
41 | }
42 |
43 |
44 | @dataclass
45 | class StableDiffusionSafePipelineOutput(BaseOutput):
46 | """
47 | Output class for Safe Stable Diffusion pipelines.
48 |
49 | Args:
50 | images (`List[PIL.Image.Image]` or `np.ndarray`)
51 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
 52 |             num_channels)`. PIL images or NumPy arrays representing the denoised images of the diffusion pipeline.
53 | nsfw_content_detected (`List[bool]`)
54 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
55 | (nsfw) content, or `None` if safety checking could not be performed.
 56 |         unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
 57 |             List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
 58 |             (nsfw) content, or `None` if no safety check was performed or no images were flagged.
59 | applied_safety_concept (`str`)
60 | The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
61 | """
62 |
63 | images: Union[List[PIL.Image.Image], np.ndarray]
64 | nsfw_content_detected: Optional[List[bool]]
65 | unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
66 | applied_safety_concept: Optional[str]
67 |
68 |
69 | if is_transformers_available() and is_torch_available():
70 | from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
71 | from .safety_checker import SafeStableDiffusionSafetyChecker
72 |
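A usage sketch showing how the `SafetyConfig` presets above are typically unpacked into a Safe Stable Diffusion call (not part of the file; the checkpoint id is an assumed example).

from diffusers import StableDiffusionPipelineSafe
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig

# Assumed/illustrative checkpoint id for a Safe Latent Diffusion model.
pipe = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe").to("cuda")

prompt = "a portrait photo of an astronaut"
# The preset dicts map directly onto the pipeline's sld_* keyword arguments.
out = pipe(prompt=prompt, **SafetyConfig.MEDIUM)
out.images[0].save("safe_sd_sample.png")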
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stable_diffusion_safe/safety_checker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
18 |
19 | from ...utils import logging
20 |
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | def cosine_distance(image_embeds, text_embeds):
26 | normalized_image_embeds = nn.functional.normalize(image_embeds)
27 | normalized_text_embeds = nn.functional.normalize(text_embeds)
28 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
29 |
30 |
31 | class SafeStableDiffusionSafetyChecker(PreTrainedModel):
32 | config_class = CLIPConfig
33 |
34 | _no_split_modules = ["CLIPEncoderLayer"]
35 |
36 | def __init__(self, config: CLIPConfig):
37 | super().__init__(config)
38 |
39 | self.vision_model = CLIPVisionModel(config.vision_config)
40 | self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
41 |
42 | self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
43 | self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
44 |
45 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
46 | self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
47 |
48 | @torch.no_grad()
49 | def forward(self, clip_input, images):
50 | pooled_output = self.vision_model(clip_input)[1] # pooled_output
51 | image_embeds = self.visual_projection(pooled_output)
52 |
53 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
54 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
55 | cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
56 |
57 | result = []
58 | batch_size = image_embeds.shape[0]
59 | for i in range(batch_size):
60 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
61 |
 62 |             # increase this value to create a stronger `nsfw` filter
63 | # at the cost of increasing the possibility of filtering benign images
64 | adjustment = 0.0
65 |
66 | for concept_idx in range(len(special_cos_dist[0])):
67 | concept_cos = special_cos_dist[i][concept_idx]
68 | concept_threshold = self.special_care_embeds_weights[concept_idx].item()
69 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
70 | if result_img["special_scores"][concept_idx] > 0:
71 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
72 | adjustment = 0.01
73 |
74 | for concept_idx in range(len(cos_dist[0])):
75 | concept_cos = cos_dist[i][concept_idx]
76 | concept_threshold = self.concept_embeds_weights[concept_idx].item()
77 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
78 | if result_img["concept_scores"][concept_idx] > 0:
79 | result_img["bad_concepts"].append(concept_idx)
80 |
81 | result.append(result_img)
82 |
83 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
84 |
85 | return images, has_nsfw_concepts
86 |
87 | @torch.no_grad()
88 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
89 | pooled_output = self.vision_model(clip_input)[1] # pooled_output
90 | image_embeds = self.visual_projection(pooled_output)
91 |
92 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
93 | cos_dist = cosine_distance(image_embeds, self.concept_embeds)
94 |
95 | # increase this value to create a stronger `nsfw` filter
96 | # at the cost of increasing the possibility of filtering benign images
97 | adjustment = 0.0
98 |
99 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
100 | # special_scores = special_scores.round(decimals=3)
101 | special_care = torch.any(special_scores > 0, dim=1)
102 | special_adjustment = special_care * 0.01
103 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
104 |
105 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
106 | # concept_scores = concept_scores.round(decimals=3)
107 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
108 |
109 | return images, has_nsfw_concepts
110 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stochastic_karras_ve/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_stochastic_karras_ve import KarrasVePipeline
2 |
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Optional, Tuple, Union
16 |
17 | import torch
18 |
19 | from ...models import UNet2DModel
20 | from ...schedulers import KarrasVeScheduler
21 | from ...utils import randn_tensor
22 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
23 |
24 |
25 | class KarrasVePipeline(DiffusionPipeline):
26 | r"""
 27 |     Stochastic sampling from Karras et al. [1] tailored to the Variance Exploding (VE) models [2]. Use Algorithm 2 and
28 | the VE column of Table 1 from [1] for reference.
29 |
30 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
31 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
32 | differential equations." https://arxiv.org/abs/2011.13456
33 |
34 | Parameters:
35 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
36 | scheduler ([`KarrasVeScheduler`]):
37 | Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
38 | """
39 |
40 | # add type hints for linting
41 | unet: UNet2DModel
42 | scheduler: KarrasVeScheduler
43 |
44 | def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
45 | super().__init__()
46 | self.register_modules(unet=unet, scheduler=scheduler)
47 |
48 | @torch.no_grad()
49 | def __call__(
50 | self,
51 | batch_size: int = 1,
52 | num_inference_steps: int = 50,
53 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
54 | output_type: Optional[str] = "pil",
55 | return_dict: bool = True,
56 | **kwargs,
57 | ) -> Union[Tuple, ImagePipelineOutput]:
58 | r"""
59 | Args:
60 | batch_size (`int`, *optional*, defaults to 1):
61 | The number of images to generate.
62 | generator (`torch.Generator`, *optional*):
63 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
64 | to make generation deterministic.
65 | num_inference_steps (`int`, *optional*, defaults to 50):
66 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
67 | expense of slower inference.
68 | output_type (`str`, *optional*, defaults to `"pil"`):
 69 |                 The output format of the generated image. Choose between
70 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
71 | return_dict (`bool`, *optional*, defaults to `True`):
72 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
73 |
74 | Returns:
75 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
 76 |             True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
77 | """
78 |
79 | img_size = self.unet.config.sample_size
80 | shape = (batch_size, 3, img_size, img_size)
81 |
82 | model = self.unet
83 |
84 | # sample x_0 ~ N(0, sigma_0^2 * I)
85 | sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma
86 |
87 | self.scheduler.set_timesteps(num_inference_steps)
88 |
89 | for t in self.progress_bar(self.scheduler.timesteps):
90 | # here sigma_t == t_i from the paper
91 | sigma = self.scheduler.schedule[t]
92 | sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
93 |
94 | # 1. Select temporarily increased noise level sigma_hat
95 | # 2. Add new noise to move from sample_i to sample_hat
96 | sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
97 |
98 | # 3. Predict the noise residual given the noise magnitude `sigma_hat`
99 | # The model inputs and output are adjusted by following eq. (213) in [1].
100 | model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
101 |
102 | # 4. Evaluate dx/dt at sigma_hat
103 | # 5. Take Euler step from sigma to sigma_prev
104 | step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
105 |
106 | if sigma_prev != 0:
107 | # 6. Apply 2nd order correction
108 | # The model inputs and output are adjusted by following eq. (213) in [1].
109 | model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
110 | step_output = self.scheduler.step_correct(
111 | model_output,
112 | sigma_hat,
113 | sigma_prev,
114 | sample_hat,
115 | step_output.prev_sample,
116 | step_output["derivative"],
117 | )
118 | sample = step_output.prev_sample
119 |
120 | sample = (sample / 2 + 0.5).clamp(0, 1)
121 | image = sample.cpu().permute(0, 2, 3, 1).numpy()
122 | if output_type == "pil":
123 | image = self.numpy_to_pil(image)
124 |
125 | if not return_dict:
126 | return (image,)
127 |
128 | return ImagePipelineOutput(images=image)
129 |
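A minimal end-to-end sketch for the pipeline above (not from the library): it uses a small, randomly initialized `UNet2DModel`, so the resulting sample is noise-like; it only demonstrates how the unet and `KarrasVeScheduler` are wired together.

from diffusers import UNet2DModel, KarrasVeScheduler, KarrasVePipeline

# Tiny random-weight UNet for illustration; the pipeline hard-codes 3 image channels.
unet = UNet2DModel(
    sample_size=32, in_channels=3, out_channels=3,
    layers_per_block=1, block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
pipe = KarrasVePipeline(unet=unet, scheduler=KarrasVeScheduler())
image = pipe(batch_size=1, num_inference_steps=10).images[0]
image.save("karras_ve_sketch.png")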
--------------------------------------------------------------------------------
/diffuser/diffusers/pipelines/text_to_video_synthesis/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
8 |
9 |
10 | @dataclass
11 | class TextToVideoSDPipelineOutput(BaseOutput):
12 | """
13 | Output class for text to video pipelines.
14 |
15 | Args:
16 | frames (`List[np.ndarray]` or `torch.FloatTensor`)
17 | List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
 18 |             a `torch` tensor. NumPy arrays represent the denoised images of the diffusion pipeline. The length of the
 19 |             list denotes the video length, i.e., the number of frames.
20 | """
21 |
22 | frames: Union[List[np.ndarray], torch.FloatTensor]
23 |
24 |
25 | try:
26 | if not (is_transformers_available() and is_torch_available()):
27 | raise OptionalDependencyNotAvailable()
28 | except OptionalDependencyNotAvailable:
29 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
30 | else:
31 | from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401
32 | from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
33 |
--------------------------------------------------------------------------------
/diffuser/diffusers/schedulers/README.md:
--------------------------------------------------------------------------------
1 | # Schedulers
2 |
3 | For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview).
--------------------------------------------------------------------------------
/diffuser/diffusers/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available
17 |
18 |
19 | try:
20 | if not is_torch_available():
21 | raise OptionalDependencyNotAvailable()
22 | except OptionalDependencyNotAvailable:
23 | from ..utils.dummy_pt_objects import * # noqa F403
24 | else:
25 | from .scheduling_ddim import DDIMScheduler
26 | from .scheduling_sde import SDEScheduler
27 | from .scheduling_ddim_inverse import DDIMInverseScheduler
28 | from .scheduling_ddpm import DDPMScheduler
29 | from .scheduling_deis_multistep import DEISMultistepScheduler
30 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
31 | from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
32 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
33 | from .scheduling_euler_discrete import EulerDiscreteScheduler
34 | from .scheduling_heun_discrete import HeunDiscreteScheduler
35 | from .scheduling_ipndm import IPNDMScheduler
36 | from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
37 | from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
38 | from .scheduling_karras_ve import KarrasVeScheduler
39 | from .scheduling_pndm import PNDMScheduler
40 | from .scheduling_repaint import RePaintScheduler
41 | from .scheduling_sde_ve import ScoreSdeVeScheduler
42 | from .scheduling_sde_vp import ScoreSdeVpScheduler
43 | from .scheduling_unclip import UnCLIPScheduler
44 | from .scheduling_unipc_multistep import UniPCMultistepScheduler
45 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
46 | from .scheduling_vq_diffusion import VQDiffusionScheduler
47 |
48 | try:
49 | if not is_flax_available():
50 | raise OptionalDependencyNotAvailable()
51 | except OptionalDependencyNotAvailable:
52 | from ..utils.dummy_flax_objects import * # noqa F403
53 | else:
54 | from .scheduling_ddim_flax import FlaxDDIMScheduler
55 | from .scheduling_ddpm_flax import FlaxDDPMScheduler
56 | from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
57 | from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
58 | from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
59 | from .scheduling_pndm_flax import FlaxPNDMScheduler
60 | from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
61 | from .scheduling_utils_flax import (
62 | FlaxKarrasDiffusionSchedulers,
63 | FlaxSchedulerMixin,
64 | FlaxSchedulerOutput,
65 | broadcast_to_shape_from_left,
66 | )
67 |
68 |
69 | try:
70 | if not (is_torch_available() and is_scipy_available()):
71 | raise OptionalDependencyNotAvailable()
72 | except OptionalDependencyNotAvailable:
73 | from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
74 | else:
75 | from .scheduling_lms_discrete import LMSDiscreteScheduler
76 |
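Because all of these schedulers share the `SchedulerMixin`/`ConfigMixin` interface, a pipeline's scheduler can be swapped by rebuilding a new scheduler from the old one's config. A short sketch (the checkpoint id is illustrative):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")    # illustrative id
# Reuse the trained noise-schedule configuration; only the solver changes.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of an astronaut riding a horse on mars", num_inference_steps=20).images[0]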
--------------------------------------------------------------------------------
/diffuser/diffusers/schedulers/scheduling_sde_vp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
16 |
17 | import math
18 | from typing import Union
19 |
20 | import torch
21 |
22 | from ..configuration_utils import ConfigMixin, register_to_config
23 | from ..utils import randn_tensor
24 | from .scheduling_utils import SchedulerMixin
25 |
26 |
27 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
28 | """
29 | The variance preserving stochastic differential equation (SDE) scheduler.
30 |
31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
33 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
34 | [`~SchedulerMixin.from_pretrained`] functions.
35 |
36 | For more information, see the original paper: https://arxiv.org/abs/2011.13456
37 |
38 | UNDER CONSTRUCTION
39 |
40 | """
41 |
42 | order = 1
43 |
44 | @register_to_config
45 | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
46 | self.sigmas = None
47 | self.discrete_sigmas = None
48 | self.timesteps = None
49 |
50 | def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
51 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
52 |
53 | def step_pred(self, score, x, t, generator=None):
54 | if self.timesteps is None:
55 | raise ValueError(
56 | "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
57 | )
58 |
59 | # TODO(Patrick) better comments + non-PyTorch
60 | # postprocess model score
61 | log_mean_coeff = (
62 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
63 | )
64 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
65 | std = std.flatten()
66 | while len(std.shape) < len(score.shape):
67 | std = std.unsqueeze(-1)
68 | score = -score / std
69 |
70 | # compute
71 | dt = -1.0 / len(self.timesteps)
72 |
73 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
74 | beta_t = beta_t.flatten()
75 | while len(beta_t.shape) < len(x.shape):
76 | beta_t = beta_t.unsqueeze(-1)
77 | drift = -0.5 * beta_t * x
78 |
79 | diffusion = torch.sqrt(beta_t)
80 | drift = drift - diffusion**2 * score
81 | x_mean = x + drift * dt
82 |
83 | # add noise
84 | noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
85 | x = x_mean + diffusion * math.sqrt(-dt) * noise
86 |
87 | return x, x_mean
88 |
89 | def __len__(self):
90 | return self.config.num_train_timesteps
91 |
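A smoke-test sketch for the (still under-construction) scheduler above, using random tensors in place of a real score model; it only exercises `set_timesteps` and `step_pred`, and the shapes and values are made up.

import torch
from diffusers.schedulers.scheduling_sde_vp import ScoreSdeVpScheduler

scheduler = ScoreSdeVpScheduler()
scheduler.set_timesteps(num_inference_steps=10)

x = torch.randn(1, 3, 8, 8)                  # dummy sample
for t in scheduler.timesteps:
    score = torch.randn_like(x)              # stand-in for model(x, t); not a real score estimate
    x, x_mean = scheduler.step_pred(score, x, t)
print(x_mean.shape)                          # torch.Size([1, 3, 8, 8])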
--------------------------------------------------------------------------------
/diffuser/diffusers/test.py:
--------------------------------------------------------------------------------
1 | # make sure you're logged in with \`huggingface-cli login\`
2 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
3 |
4 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
5 | pipe = pipe.to("cuda")
6 |
7 | prompt = "a photo of an astronaut riding a horse on mars"
8 | image = pipe(prompt).images[0]
9 |
10 | image.save("astronaut_rides_horse.png")
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | from packaging import version
19 |
20 | from .. import __version__
21 | from .accelerate_utils import apply_forward_hook
22 | from .constants import (
23 | CONFIG_NAME,
24 | DEPRECATED_REVISION_ARGS,
25 | DIFFUSERS_CACHE,
26 | DIFFUSERS_DYNAMIC_MODULE_NAME,
27 | FLAX_WEIGHTS_NAME,
28 | HF_MODULES_CACHE,
29 | HUGGINGFACE_CO_RESOLVE_ENDPOINT,
30 | ONNX_EXTERNAL_WEIGHTS_NAME,
31 | ONNX_WEIGHTS_NAME,
32 | SAFETENSORS_WEIGHTS_NAME,
33 | TEXT_ENCODER_TARGET_MODULES,
34 | WEIGHTS_NAME,
35 | )
36 | from .deprecation_utils import deprecate
37 | from .doc_utils import replace_example_docstring
38 | from .dynamic_modules_utils import get_class_from_dynamic_module
39 | from .hub_utils import (
40 | HF_HUB_OFFLINE,
41 | _add_variant,
42 | _get_model_file,
43 | extract_commit_hash,
44 | http_user_agent,
45 | )
46 | from .import_utils import (
47 | ENV_VARS_TRUE_AND_AUTO_VALUES,
48 | ENV_VARS_TRUE_VALUES,
49 | USE_JAX,
50 | USE_TF,
51 | USE_TORCH,
52 | DummyObject,
53 | OptionalDependencyNotAvailable,
54 | is_accelerate_available,
55 | is_accelerate_version,
56 | is_flax_available,
57 | is_inflect_available,
58 | is_k_diffusion_available,
59 | is_k_diffusion_version,
60 | is_librosa_available,
61 | is_note_seq_available,
62 | is_omegaconf_available,
63 | is_onnx_available,
64 | is_safetensors_available,
65 | is_scipy_available,
66 | is_tensorboard_available,
67 | is_tf_available,
68 | is_torch_available,
69 | is_torch_version,
70 | is_transformers_available,
71 | is_transformers_version,
72 | is_unidecode_available,
73 | is_wandb_available,
74 | is_xformers_available,
75 | requires_backends,
76 | )
77 | from .logging import get_logger
78 | from .outputs import BaseOutput
79 | from .pil_utils import PIL_INTERPOLATION
80 | from .torch_utils import is_compiled_module, randn_tensor
81 |
82 |
83 | if is_torch_available():
84 | from .testing_utils import (
85 | floats_tensor,
86 | load_hf_numpy,
87 | load_image,
88 | load_numpy,
89 | load_pt,
90 | nightly,
91 | parse_flag_from_env,
92 | print_tensor_test,
93 | require_torch_2,
94 | require_torch_gpu,
95 | skip_mps,
96 | slow,
97 | torch_all_close,
98 | torch_device,
99 | )
100 |
101 | from .testing_utils import export_to_video
102 |
103 |
104 | logger = get_logger(__name__)
105 |
106 |
107 | def check_min_version(min_version):
108 | if version.parse(__version__) < version.parse(min_version):
109 | if "dev" in min_version:
110 | error_message = (
111 | "This example requires a source install from HuggingFace diffusers (see "
112 | "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
113 | )
114 | else:
115 | error_message = f"This example requires a minimum version of {min_version},"
116 | error_message += f" but the version found is {__version__}.\n"
117 | raise ImportError(error_message)
118 |
--------------------------------------------------------------------------------
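A quick illustrative call to `check_min_version` defined above (the version string here is hypothetical, and the import path assumes this vendored tree is importable as `diffusers`):

```python
from diffusers.utils import check_min_version

# Raises ImportError if the installed diffusers is older than the requested version;
# for ".dev" versions the error additionally points at a source install.
check_min_version("0.16.0")  # hypothetical minimum version
```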
/diffuser/diffusers/utils/accelerate_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Accelerate utilities: Utilities related to accelerate
16 | """
17 |
18 | from packaging import version
19 |
20 | from .import_utils import is_accelerate_available
21 |
22 |
23 | if is_accelerate_available():
24 | import accelerate
25 |
26 |
27 | def apply_forward_hook(method):
28 | """
29 | Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
30 | for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
31 | appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].
32 |
33 | This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.
34 |
35 | :param method: The method to decorate. This method should be a method of a PyTorch module.
36 | """
37 | if not is_accelerate_available():
38 | return method
39 | accelerate_version = version.parse(accelerate.__version__).base_version
40 | if version.parse(accelerate_version) < version.parse("0.17.0"):
41 | return method
42 |
43 | def wrapper(self, *args, **kwargs):
44 | if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
45 | self._hf_hook.pre_forward(self)
46 | return method(self, *args, **kwargs)
47 |
48 | return wrapper
49 |
--------------------------------------------------------------------------------
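A minimal sketch of how `apply_forward_hook` is used: it decorates a method other than `forward` (here a made-up `TinyAutoencoder.encode`, for illustration only) so that an accelerate CPU-offload hook, if one is registered on the module, fires before that method runs.

```python
import torch
import torch.nn as nn

from diffusers.utils import apply_forward_hook


class TinyAutoencoder(nn.Module):  # hypothetical module, not part of the library
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(8, 4)

    @apply_forward_hook
    def encode(self, x):
        # If accelerate >= 0.17.0 attached a CpuOffload hook (`_hf_hook`) to this
        # module, the decorator calls its `pre_forward` before running `encode`.
        return self.enc(x)


model = TinyAutoencoder()
print(model.encode(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```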
/diffuser/diffusers/utils/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
17 |
18 |
19 | default_cache_path = HUGGINGFACE_HUB_CACHE
20 |
21 |
22 | CONFIG_NAME = "config.json"
23 | WEIGHTS_NAME = "diffusion_pytorch_model.bin"
24 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
25 | ONNX_WEIGHTS_NAME = "model.onnx"
26 | SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
27 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
28 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
29 | DIFFUSERS_CACHE = default_cache_path
30 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
31 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
32 | DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
33 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
34 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/deprecation_utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from typing import Any, Dict, Optional, Union
4 |
5 | from packaging import version
6 |
7 |
8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
9 | from .. import __version__
10 |
11 | deprecated_kwargs = take_from
12 | values = ()
13 | if not isinstance(args[0], tuple):
14 | args = (args,)
15 |
16 | for attribute, version_name, message in args:
17 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
18 | raise ValueError(
19 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
20 | f" version {__version__} is >= {version_name}"
21 | )
22 |
23 | warning = None
24 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
25 | values += (deprecated_kwargs.pop(attribute),)
26 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
27 | elif hasattr(deprecated_kwargs, attribute):
28 | values += (getattr(deprecated_kwargs, attribute),)
29 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
30 | elif deprecated_kwargs is None:
31 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
32 |
33 | if warning is not None:
34 | warning = warning + " " if standard_warn else ""
35 | warnings.warn(warning + message, FutureWarning, stacklevel=2)
36 |
37 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
38 | call_frame = inspect.getouterframes(inspect.currentframe())[1]
39 | filename = call_frame.filename
40 | line_number = call_frame.lineno
41 | function = call_frame.function
42 | key, value = next(iter(deprecated_kwargs.items()))
43 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
44 |
45 | if len(values) == 0:
46 | return
47 | elif len(values) == 1:
48 | return values[0]
49 | return values
50 |
--------------------------------------------------------------------------------
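A small sketch of the `deprecate` helper above in action, using a made-up keyword name and removal version purely for illustration: the deprecated kwarg is popped from `kwargs`, a `FutureWarning` is emitted, and its value is returned so the caller can fall back to it.

```python
from diffusers.utils import deprecate


def sample(steps=50, **kwargs):
    # `num_steps` and the "2.0.0" removal version are hypothetical.
    old_steps = deprecate("num_steps", "2.0.0", "Use `steps` instead.", take_from=kwargs)
    if old_steps is not None:
        steps = old_steps
    return steps


print(sample(num_steps=25))  # warns with a FutureWarning, returns 25
print(sample(steps=30))      # no warning, returns 30
```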
/diffuser/diffusers/utils/doc_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Doc utilities: Utilities related to documentation
16 | """
17 | import re
18 |
19 |
20 | def replace_example_docstring(example_docstring):
21 | def docstring_decorator(fn):
22 | func_doc = fn.__doc__
23 | lines = func_doc.split("\n")
24 | i = 0
25 | while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None:
26 | i += 1
27 | if i < len(lines):
28 | lines[i] = example_docstring
29 | func_doc = "\n".join(lines)
30 | else:
31 | raise ValueError(
32 | f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
33 | f"current docstring is:\n{func_doc}"
34 | )
35 | fn.__doc__ = func_doc
36 | return fn
37 |
38 | return docstring_decorator
39 |
--------------------------------------------------------------------------------
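A minimal sketch of `replace_example_docstring`: the decorated function carries an empty `Examples:` placeholder line, which the decorator replaces with the supplied snippet (the function and example text here are made up):

```python
from diffusers.utils import replace_example_docstring

EXAMPLE_DOC_STRING = """
    Examples:
        >>> run()  # hypothetical usage shown in the rendered docs
        'ok'
"""


@replace_example_docstring(EXAMPLE_DOC_STRING)
def run():
    """Runs a toy pipeline.

    Examples:
    """
    return "ok"


print(run.__doc__)  # the placeholder line has been swapped for EXAMPLE_DOC_STRING
```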
/diffuser/diffusers/utils/dummy_flax_and_transformers_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
6 | _backends = ["flax", "transformers"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["flax", "transformers"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["flax", "transformers"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["flax", "transformers"])
18 |
19 |
20 | class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
21 | _backends = ["flax", "transformers"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["flax", "transformers"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["flax", "transformers"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["flax", "transformers"])
33 |
34 |
35 | class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject):
36 | _backends = ["flax", "transformers"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["flax", "transformers"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["flax", "transformers"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["flax", "transformers"])
48 |
49 |
50 | class FlaxStableDiffusionPipeline(metaclass=DummyObject):
51 | _backends = ["flax", "transformers"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["flax", "transformers"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["flax", "transformers"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["flax", "transformers"])
63 |
--------------------------------------------------------------------------------
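The `dummy_*` modules that follow are all instances of the same pattern: autogenerated stand-ins that keep imports resolvable when an optional backend is absent, and that fail loudly only when actually used. A sketch of that behaviour, assuming Flax is not installed in the environment:

```python
from diffusers.utils import is_flax_available

if not is_flax_available():
    from diffusers.utils.dummy_flax_and_transformers_objects import FlaxStableDiffusionPipeline

    try:
        FlaxStableDiffusionPipeline()  # requires_backends raises: "flax"/"transformers" missing
    except Exception as err:
        print("optional backend missing:", err)
```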
/diffuser/diffusers/utils/dummy_flax_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class FlaxControlNetModel(metaclass=DummyObject):
6 | _backends = ["flax"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["flax"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["flax"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["flax"])
18 |
19 |
20 | class FlaxModelMixin(metaclass=DummyObject):
21 | _backends = ["flax"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["flax"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["flax"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["flax"])
33 |
34 |
35 | class FlaxUNet2DConditionModel(metaclass=DummyObject):
36 | _backends = ["flax"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["flax"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["flax"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["flax"])
48 |
49 |
50 | class FlaxAutoencoderKL(metaclass=DummyObject):
51 | _backends = ["flax"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["flax"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["flax"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["flax"])
63 |
64 |
65 | class FlaxDiffusionPipeline(metaclass=DummyObject):
66 | _backends = ["flax"]
67 |
68 | def __init__(self, *args, **kwargs):
69 | requires_backends(self, ["flax"])
70 |
71 | @classmethod
72 | def from_config(cls, *args, **kwargs):
73 | requires_backends(cls, ["flax"])
74 |
75 | @classmethod
76 | def from_pretrained(cls, *args, **kwargs):
77 | requires_backends(cls, ["flax"])
78 |
79 |
80 | class FlaxDDIMScheduler(metaclass=DummyObject):
81 | _backends = ["flax"]
82 |
83 | def __init__(self, *args, **kwargs):
84 | requires_backends(self, ["flax"])
85 |
86 | @classmethod
87 | def from_config(cls, *args, **kwargs):
88 | requires_backends(cls, ["flax"])
89 |
90 | @classmethod
91 | def from_pretrained(cls, *args, **kwargs):
92 | requires_backends(cls, ["flax"])
93 |
94 |
95 | class FlaxDDPMScheduler(metaclass=DummyObject):
96 | _backends = ["flax"]
97 |
98 | def __init__(self, *args, **kwargs):
99 | requires_backends(self, ["flax"])
100 |
101 | @classmethod
102 | def from_config(cls, *args, **kwargs):
103 | requires_backends(cls, ["flax"])
104 |
105 | @classmethod
106 | def from_pretrained(cls, *args, **kwargs):
107 | requires_backends(cls, ["flax"])
108 |
109 |
110 | class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
111 | _backends = ["flax"]
112 |
113 | def __init__(self, *args, **kwargs):
114 | requires_backends(self, ["flax"])
115 |
116 | @classmethod
117 | def from_config(cls, *args, **kwargs):
118 | requires_backends(cls, ["flax"])
119 |
120 | @classmethod
121 | def from_pretrained(cls, *args, **kwargs):
122 | requires_backends(cls, ["flax"])
123 |
124 |
125 | class FlaxKarrasVeScheduler(metaclass=DummyObject):
126 | _backends = ["flax"]
127 |
128 | def __init__(self, *args, **kwargs):
129 | requires_backends(self, ["flax"])
130 |
131 | @classmethod
132 | def from_config(cls, *args, **kwargs):
133 | requires_backends(cls, ["flax"])
134 |
135 | @classmethod
136 | def from_pretrained(cls, *args, **kwargs):
137 | requires_backends(cls, ["flax"])
138 |
139 |
140 | class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
141 | _backends = ["flax"]
142 |
143 | def __init__(self, *args, **kwargs):
144 | requires_backends(self, ["flax"])
145 |
146 | @classmethod
147 | def from_config(cls, *args, **kwargs):
148 | requires_backends(cls, ["flax"])
149 |
150 | @classmethod
151 | def from_pretrained(cls, *args, **kwargs):
152 | requires_backends(cls, ["flax"])
153 |
154 |
155 | class FlaxPNDMScheduler(metaclass=DummyObject):
156 | _backends = ["flax"]
157 |
158 | def __init__(self, *args, **kwargs):
159 | requires_backends(self, ["flax"])
160 |
161 | @classmethod
162 | def from_config(cls, *args, **kwargs):
163 | requires_backends(cls, ["flax"])
164 |
165 | @classmethod
166 | def from_pretrained(cls, *args, **kwargs):
167 | requires_backends(cls, ["flax"])
168 |
169 |
170 | class FlaxSchedulerMixin(metaclass=DummyObject):
171 | _backends = ["flax"]
172 |
173 | def __init__(self, *args, **kwargs):
174 | requires_backends(self, ["flax"])
175 |
176 | @classmethod
177 | def from_config(cls, *args, **kwargs):
178 | requires_backends(cls, ["flax"])
179 |
180 | @classmethod
181 | def from_pretrained(cls, *args, **kwargs):
182 | requires_backends(cls, ["flax"])
183 |
184 |
185 | class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
186 | _backends = ["flax"]
187 |
188 | def __init__(self, *args, **kwargs):
189 | requires_backends(self, ["flax"])
190 |
191 | @classmethod
192 | def from_config(cls, *args, **kwargs):
193 | requires_backends(cls, ["flax"])
194 |
195 | @classmethod
196 | def from_pretrained(cls, *args, **kwargs):
197 | requires_backends(cls, ["flax"])
198 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_note_seq_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class MidiProcessor(metaclass=DummyObject):
6 | _backends = ["note_seq"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["note_seq"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["note_seq"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["note_seq"])
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_onnx_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class OnnxRuntimeModel(metaclass=DummyObject):
6 | _backends = ["onnx"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["onnx"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["onnx"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["onnx"])
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_torch_and_librosa_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class AudioDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "librosa"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "librosa"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "librosa"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "librosa"])
18 |
19 |
20 | class Mel(metaclass=DummyObject):
21 | _backends = ["torch", "librosa"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["torch", "librosa"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["torch", "librosa"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["torch", "librosa"])
33 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_torch_and_scipy_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class LMSDiscreteScheduler(metaclass=DummyObject):
6 | _backends = ["torch", "scipy"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "scipy"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "scipy"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "scipy"])
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "transformers", "k_diffusion"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "transformers", "k_diffusion"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "transformers", "k_diffusion"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "transformers", "k_diffusion"])
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "transformers", "onnx"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "transformers", "onnx"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "transformers", "onnx"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "transformers", "onnx"])
18 |
19 |
20 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
21 | _backends = ["torch", "transformers", "onnx"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["torch", "transformers", "onnx"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["torch", "transformers", "onnx"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["torch", "transformers", "onnx"])
33 |
34 |
35 | class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
36 | _backends = ["torch", "transformers", "onnx"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["torch", "transformers", "onnx"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["torch", "transformers", "onnx"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["torch", "transformers", "onnx"])
48 |
49 |
50 | class OnnxStableDiffusionPipeline(metaclass=DummyObject):
51 | _backends = ["torch", "transformers", "onnx"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["torch", "transformers", "onnx"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["torch", "transformers", "onnx"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["torch", "transformers", "onnx"])
63 |
64 |
65 | class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
66 | _backends = ["torch", "transformers", "onnx"]
67 |
68 | def __init__(self, *args, **kwargs):
69 | requires_backends(self, ["torch", "transformers", "onnx"])
70 |
71 | @classmethod
72 | def from_config(cls, *args, **kwargs):
73 | requires_backends(cls, ["torch", "transformers", "onnx"])
74 |
75 | @classmethod
76 | def from_pretrained(cls, *args, **kwargs):
77 | requires_backends(cls, ["torch", "transformers", "onnx"])
78 |
79 |
80 | class StableDiffusionOnnxPipeline(metaclass=DummyObject):
81 | _backends = ["torch", "transformers", "onnx"]
82 |
83 | def __init__(self, *args, **kwargs):
84 | requires_backends(self, ["torch", "transformers", "onnx"])
85 |
86 | @classmethod
87 | def from_config(cls, *args, **kwargs):
88 | requires_backends(cls, ["torch", "transformers", "onnx"])
89 |
90 | @classmethod
91 | def from_pretrained(cls, *args, **kwargs):
92 | requires_backends(cls, ["torch", "transformers", "onnx"])
93 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class SpectrogramDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["transformers", "torch", "note_seq"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["transformers", "torch", "note_seq"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["transformers", "torch", "note_seq"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["transformers", "torch", "note_seq"])
18 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/model_card_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 |
7 |
8 | # {{ model_name | default("Diffusion Model") }}
9 |
10 | ## Model description
11 |
12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
13 | on the `{{ dataset_name }}` dataset.
14 |
15 | ## Intended uses & limitations
16 |
17 | #### How to use
18 |
19 | ```python
20 | # TODO: add an example code snippet for running this diffusion pipeline
21 | ```
22 |
23 | #### Limitations and bias
24 |
25 | [TODO: provide examples of latent issues and potential remediations]
26 |
27 | ## Training data
28 |
29 | [TODO: describe the data used to train the model]
30 |
31 | ### Training hyperparameters
32 |
33 | The following hyperparameters were used during training:
34 | - learning_rate: {{ learning_rate }}
35 | - train_batch_size: {{ train_batch_size }}
36 | - eval_batch_size: {{ eval_batch_size }}
37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }}
38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
39 | - lr_scheduler: {{ lr_scheduler }}
40 | - lr_warmup_steps: {{ lr_warmup_steps }}
41 | - ema_inv_gamma: {{ ema_inv_gamma }}
42 | - ema_power: {{ ema_power }}
43 | - ema_max_decay: {{ ema_max_decay }}
44 | - mixed_precision: {{ mixed_precision }}
45 |
46 | ### Training results
47 |
48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
49 |
50 |
51 |
--------------------------------------------------------------------------------
/diffuser/diffusers/utils/outputs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Generic utilities
16 | """
17 |
18 | from collections import OrderedDict
19 | from dataclasses import fields
20 | from typing import Any, Tuple
21 |
22 | import numpy as np
23 |
24 | from .import_utils import is_torch_available
25 |
26 |
27 | def is_tensor(x):
28 | """
29 | Tests if `x` is a `torch.Tensor` or `np.ndarray`.
30 | """
31 | if is_torch_available():
32 | import torch
33 |
34 | if isinstance(x, torch.Tensor):
35 | return True
36 |
37 | return isinstance(x, np.ndarray)
38 |
39 |
40 | class BaseOutput(OrderedDict):
41 | """
42 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
43 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
44 | python dictionary.
45 |
46 |
47 |
48 | You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
49 | before.
50 |
51 |
52 | """
53 |
54 | def __post_init__(self):
55 | class_fields = fields(self)
56 |
57 | # Safety and consistency checks
58 | if not len(class_fields):
59 | raise ValueError(f"{self.__class__.__name__} has no fields.")
60 |
61 | first_field = getattr(self, class_fields[0].name)
62 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
63 |
64 | if other_fields_are_none and isinstance(first_field, dict):
65 | for key, value in first_field.items():
66 | self[key] = value
67 | else:
68 | for field in class_fields:
69 | v = getattr(self, field.name)
70 | if v is not None:
71 | self[field.name] = v
72 |
73 | def __delitem__(self, *args, **kwargs):
74 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
75 |
76 | def setdefault(self, *args, **kwargs):
77 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
78 |
79 | def pop(self, *args, **kwargs):
80 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
81 |
82 | def update(self, *args, **kwargs):
83 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
84 |
85 | def __getitem__(self, k):
86 | if isinstance(k, str):
87 | inner_dict = dict(self.items())
88 | return inner_dict[k]
89 | else:
90 | return self.to_tuple()[k]
91 |
92 | def __setattr__(self, name, value):
93 | if name in self.keys() and value is not None:
94 | # Don't call self.__setitem__ to avoid recursion errors
95 | super().__setitem__(name, value)
96 | super().__setattr__(name, value)
97 |
98 | def __setitem__(self, key, value):
99 |         # Will raise a KeyError if needed
100 | super().__setitem__(key, value)
101 | # Don't call self.__setattr__ to avoid recursion errors
102 | super().__setattr__(key, value)
103 |
104 | def to_tuple(self) -> Tuple[Any]:
105 | """
106 | Convert self to a tuple containing all the attributes/keys that are not `None`.
107 | """
108 | return tuple(self[k] for k in self.keys())
109 |
--------------------------------------------------------------------------------
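A minimal sketch of how `BaseOutput` is meant to be subclassed, assuming (as the `__post_init__`/`fields` usage implies) that concrete outputs are declared as dataclasses; `ToyOutput` is a made-up name for illustration:

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

from diffusers.utils import BaseOutput


@dataclass
class ToyOutput(BaseOutput):  # hypothetical output class
    images: np.ndarray
    nsfw_content_detected: Optional[list] = None


out = ToyOutput(images=np.zeros((1, 8, 8, 3)))
print(out.images.shape)     # attribute access
print(out["images"].shape)  # dict-style access
print(len(out.to_tuple()))  # 1 -- fields that are None are dropped
```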
/diffuser/diffusers/utils/pil_utils.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | import PIL.ImageOps
3 | from packaging import version
4 |
5 |
6 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
7 | PIL_INTERPOLATION = {
8 | "linear": PIL.Image.Resampling.BILINEAR,
9 | "bilinear": PIL.Image.Resampling.BILINEAR,
10 | "bicubic": PIL.Image.Resampling.BICUBIC,
11 | "lanczos": PIL.Image.Resampling.LANCZOS,
12 | "nearest": PIL.Image.Resampling.NEAREST,
13 | }
14 | else:
15 | PIL_INTERPOLATION = {
16 | "linear": PIL.Image.LINEAR,
17 | "bilinear": PIL.Image.BILINEAR,
18 | "bicubic": PIL.Image.BICUBIC,
19 | "lanczos": PIL.Image.LANCZOS,
20 | "nearest": PIL.Image.NEAREST,
21 | }
22 |
--------------------------------------------------------------------------------
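`PIL_INTERPOLATION` simply papers over the Pillow 9.1 rename of the resampling constants; usage is an ordinary `Image.resize` call. A small sketch:

```python
import PIL.Image

from diffusers.utils import PIL_INTERPOLATION

img = PIL.Image.new("RGB", (512, 512))
resized = img.resize((256, 256), resample=PIL_INTERPOLATION["bicubic"])
print(resized.size)  # (256, 256)
```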
/diffuser/diffusers/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | PyTorch utilities: Utilities related to PyTorch
16 | """
17 | from typing import List, Optional, Tuple, Union
18 |
19 | from . import logging
20 | from .import_utils import is_torch_available, is_torch_version
21 |
22 |
23 | if is_torch_available():
24 | import torch
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 |
29 | def randn_tensor(
30 | shape: Union[Tuple, List],
31 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
32 | device: Optional["torch.device"] = None,
33 | dtype: Optional["torch.dtype"] = None,
34 | layout: Optional["torch.layout"] = None,
35 | ):
36 | """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
37 | passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
38 | will always be created on CPU.
39 | """
40 | # device on which tensor is created defaults to device
41 | rand_device = device
42 | batch_size = shape[0]
43 |
44 | layout = layout or torch.strided
45 | device = device or torch.device("cpu")
46 |
47 | if generator is not None:
48 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
49 | if gen_device_type != device.type and gen_device_type == "cpu":
50 | rand_device = "cpu"
51 | if device != "mps":
52 | logger.info(
53 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
54 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
55 | f" slighly speed up this function by passing a generator that was created on the {device} device."
56 | )
57 | elif gen_device_type != device.type and gen_device_type == "cuda":
58 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
59 |
60 | if isinstance(generator, list):
61 | shape = (1,) + shape[1:]
62 | latents = [
63 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
64 | for i in range(batch_size)
65 | ]
66 | latents = torch.cat(latents, dim=0).to(device)
67 | else:
68 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
69 |
70 | return latents
71 |
72 |
73 | def is_compiled_module(module):
74 | """Check whether the module was compiled with torch.compile()"""
75 | if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
76 | return False
77 | return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
78 |
--------------------------------------------------------------------------------
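A short sketch of `randn_tensor`: passing a list of generators seeds each batch element independently (shown on CPU so it runs anywhere; pass `device=` to place the result elsewhere):

```python
import torch

from diffusers.utils import randn_tensor

# One generator per batch element -> each sample is seeded independently.
gens = [torch.Generator().manual_seed(i) for i in range(4)]
latents = randn_tensor((4, 3, 64, 64), generator=gens, dtype=torch.float32)
print(latents.shape)  # torch.Size([4, 3, 64, 64])
```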
/diffuser/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/diffuser/dnnlib/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/dnnlib/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/dnnlib/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/dnnlib/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/eval_clip_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import open_clip
4 | from coco_data_loader import text_image_pair
5 | import argparse
6 | from tqdm import tqdm
7 | import clip
8 | import aesthetic_score
9 |
10 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion')
11 | parser.add_argument('--steps', type=int, default=50, help='number of inference steps during sampling')
12 | parser.add_argument('--generate_seed', type=int, default=6)
13 | parser.add_argument('--bs', type=int, default=16)
14 | parser.add_argument('--max_cnt', type=int, default=100, help='maximum number of generated samples')
15 | parser.add_argument('--csv_path', type=str, default='./generated_images/subset.csv')
16 | parser.add_argument('--dir_path', type=str, default='./generated_images/subset')
17 | parser.add_argument('--scheduler', type=str, default='DDPM')
18 | args = parser.parse_args()
19 |
20 | # define dataset / data_loader
21 | text2img_dataset = text_image_pair(dir_path=args.dir_path, csv_path=args.csv_path)
22 | text2img_loader = torch.utils.data.DataLoader(dataset=text2img_dataset, batch_size=args.bs, shuffle=False)
23 |
24 | print("total length:", len(text2img_dataset))
25 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
26 | model2, _ = clip.load("ViT-L/14", device='cuda') #RN50x64
27 | model = model.cuda().eval()
28 | model2 = model2.eval()
29 | tokenizer = open_clip.get_tokenizer('ViT-g-14')
30 |
31 |
32 | model_aes = aesthetic_score.MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
33 | s = torch.load("./clip-refs/sac+logos+ava1-l14-linearMSE.pth") # load the model you trained previously or the model available in this repo
34 | model_aes.load_state_dict(s)
35 | model_aes.to("cuda")
36 | model_aes.eval()
37 |
38 | # text = tokenizer(["a horse", "a dog", "a cat"])
39 | cnt = 0.
40 | total_clip_score = 0.
41 | total_aesthetic_score = 0.
42 |
43 |
44 | with torch.no_grad(), torch.cuda.amp.autocast():
45 | for idx, (image, image2, text) in tqdm(enumerate(text2img_loader)):
46 | image = image.cuda().float()
47 | image2 = image2.cuda().float()
48 | text = list(text)
49 | text = tokenizer(text).cuda()
50 | # print('text:')
51 | # print(text.shape)
52 | image_features = model.encode_image(image).float()
53 | text_features = model.encode_text(text).float()
54 | # (bs, 1024)
55 | image_features /= image_features.norm(dim=-1, keepdim=True)
56 | text_features /= text_features.norm(dim=-1, keepdim=True)
57 |
58 | total_clip_score += (image_features * text_features).sum()
59 |
60 | image_features = model2.encode_image(image)
61 | im_emb_arr = aesthetic_score.normalized(image_features.cpu().detach().numpy())
62 | total_aesthetic_score += model_aes(torch.from_numpy(im_emb_arr).to(image.device).type(torch.cuda.FloatTensor)).sum()
63 |
64 | cnt += len(image)
65 |
66 |
67 | print("Average ClIP score :", total_clip_score.item() / cnt)
68 | print("Average Aesthetic score :", total_aesthetic_score.item() / cnt)
--------------------------------------------------------------------------------
/diffuser/ffhq.py:
--------------------------------------------------------------------------------
1 |
2 | from diffusers import DiffusionPipeline
3 | import torch
4 | import argparse
5 | from torchvision.utils import make_grid, save_image
6 | import os
7 | import numpy as np
8 | # setup args
9 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion')
10 | parser.add_argument('--name', type=str, default=None)
11 | args = parser.parse_args()
12 |
13 | model_id = "google/ncsnpp-ffhq-1024"
14 |
15 | # load model and scheduler
16 | sde_ve = DiffusionPipeline.from_pretrained(model_id).to("cuda")
17 |
18 | # run pipeline in inference (sample random noise and denoise)
19 | #gen = torch.Generator(torch.device('cpu')).manual_seed(int(123))
20 | image_list = []
21 | seed_list = [i for i in range(1)]
22 | for seed in seed_list:
23 | image = sde_ve(batch_size=1, seed=seed, output_type='tensor')
24 | print(image.shape)
25 | image_list.append(image)
26 |
27 | images = torch.cat(image_list, dim=0)
28 | #images_ = (images + 1) / 2.
29 | images_ = images
30 | print("len:", len(images))
31 | image_grid = make_grid(images_, nrow=int(np.sqrt(len(images))))
32 | save_image(image_grid, os.path.join(f'{args.name}.png'))
33 |
34 | # # save image
35 | # if args.name is not None:
36 | # image.save(f"{args.name}.png")
37 | # else:
38 | # image.save("sde_ve_generated_image.png")
39 |
40 |
--------------------------------------------------------------------------------
/diffuser/generate.py:
--------------------------------------------------------------------------------
1 | # make sure you're logged in with `huggingface-cli login`
2 | from diffusers import StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, SDEScheduler, EulerDiscreteScheduler
3 | import torch
4 | from coco_data_loader import text_image_pair
5 | from PIL import Image
6 | import os
7 | import pandas as pd
8 | import argparse
9 | import torch.nn as nn
10 | from torch_utils import distributed as dist
11 | import numpy as np
12 | import tqdm
13 |
14 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion')
15 | parser.add_argument('--steps', type=int, default=30, help='number of inference steps during sampling')
16 | parser.add_argument('--generate_seed', type=int, default=6)
17 | parser.add_argument('--w', type=float, default=7.5)
18 | parser.add_argument('--s_noise', type=float, default=1.)
19 | parser.add_argument('--bs', type=int, default=16)
20 | parser.add_argument('--max_cnt', type=int, default=5000, help='maximum number of generated samples')
21 | parser.add_argument('--save_path', type=str, default='./generated_images')
22 | parser.add_argument('--scheduler', type=str, default='DDPM')
23 | parser.add_argument('--name', type=str, default=None)
24 | parser.add_argument('--restart', action='store_true', default=False)
25 | parser.add_argument('--second', action='store_true', default=False, help='second order ODE')
26 | parser.add_argument('--sigma', action='store_true', default=False, help='use sigma')
27 | args = parser.parse_args()
28 |
29 |
30 | dist.init()
31 |
32 | dist.print0('Args:')
33 | for k, v in sorted(vars(args).items()):
34 | dist.print0('\t{}: {}'.format(k, v))
35 | # define dataset / data_loader
36 |
37 | df = pd.read_csv('./coco/coco/subset.csv')
38 | all_text = list(df['caption'])
39 | all_text = all_text[: args.max_cnt]
40 |
41 | num_batches = ((len(all_text) - 1) // (args.bs * dist.get_world_size()) + 1) * dist.get_world_size()
42 | all_batches = np.array_split(np.array(all_text), num_batches)
43 | rank_batches = all_batches[dist.get_rank():: dist.get_world_size()]
44 |
45 |
46 | index_list = np.arange(len(all_text))
47 | all_batches_index = np.array_split(index_list, num_batches)
48 | rank_batches_index = all_batches_index[dist.get_rank():: dist.get_world_size()]
49 |
50 |
51 | ##### load stable diffusion models #####
52 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
53 | dist.print0("default scheduler config:")
54 | dist.print0(pipe.scheduler.config)
55 | pipe = pipe.to("cuda")
56 |
57 | if args.scheduler == 'DDPM':
58 | pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
59 | elif args.scheduler == 'DDIM':
60 | # recommend using DDIM with Restart
61 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
62 | pipe.scheduler.use_sigma = args.sigma
63 | elif args.scheduler == 'SDE':
64 | pipe.scheduler = SDEScheduler.from_config(pipe.scheduler.config)
65 | elif args.scheduler == 'ODE':
66 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
67 | pipe.scheduler.use_karras_sigmas = False
68 | else:
69 | raise NotImplementedError
70 |
71 | generator = torch.Generator(device="cuda").manual_seed(args.generate_seed)
72 |
73 | ##### setup save configuration #######
74 | if args.name is None:
75 | save_dir = os.path.join(args.save_path,
76 | f'scheduler_{args.scheduler}_steps_{args.steps}_restart_{args.restart}_w_{args.w}_second_{args.second}_seed_{args.generate_seed}_sigma_{args.sigma}')
77 | else:
78 | save_dir = os.path.join(args.save_path,
79 | f'scheduler_{args.scheduler}_steps_{args.steps}_restart_{args.restart}_w_{args.w}_second_{args.second}_seed_{args.generate_seed}_sigma_{args.sigma}_name_{args.name}')
80 |
81 | dist.print0("save images to {}".format(save_dir))
82 |
83 | if dist.get_rank() == 0 and not os.path.exists(save_dir):
84 | os.mkdir(save_dir)
85 |
86 | ## generate images ##
87 | for cnt, mini_batch in enumerate(tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0))):
88 | torch.distributed.barrier()
89 | text = list(mini_batch)
90 | image = pipe(text, generator=generator, num_inference_steps=args.steps, guidance_scale=args.w, restart=args.restart, second_order=args.second, dist=dist, S_noise=args.s_noise).images
91 |
92 | for text_idx, global_idx in enumerate(rank_batches_index[cnt]):
93 | image[text_idx].save(os.path.join(save_dir, f'{global_idx}.png'))
94 |
95 | # Done.
96 | torch.distributed.barrier()
97 | if dist.get_rank() == 0:
98 | d = {'caption': all_text}
99 | df = pd.DataFrame(data=d)
100 | df.to_csv(os.path.join(save_dir, 'subset.csv'))
--------------------------------------------------------------------------------
/diffuser/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | flax==0.6.9
3 | huggingface_hub==0.13.4
4 | importlib_metadata==4.11.3
5 | jax==0.4.8
6 | k_diffusion==0.0.14
7 | librosa==0.10.0.post2
8 | msgpack_python==0.5.6
9 | numpy==1.22.4
10 | omegaconf==2.3.0
11 | onnxruntime==1.14.1
12 | open_clip_torch==2.16.2
13 | opencv_python==4.7.0.72
14 | packaging==21.3
15 | pandas==1.4.2
16 | Pillow==9.5.0
17 | psutil==5.8.0
18 | pyspng==0.1.1
19 | pytest==7.1.1
20 | pytorch_fid==0.3.0
21 | pytorch_lightning==2.0.1.post0
22 | requests==2.27.1
23 | safetensors==0.3.0
24 | scipy==1.10.1
25 | torch==1.12.1
26 | torchvision==0.13.1
27 | tqdm==4.64.0
28 | transformers==4.28.1
29 |
--------------------------------------------------------------------------------
/diffuser/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/diffuser/torch_utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/torch_utils/__pycache__/distributed.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/distributed.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/torch_utils/__pycache__/misc.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/misc.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/torch_utils/__pycache__/persistence.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/persistence.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/torch_utils/__pycache__/training_stats.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/training_stats.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuser/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
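`init()` above fills in single-process defaults for the usual `torchrun` environment variables, so the same script works both under `torchrun` and as a plain `python script.py`. A minimal sketch of the pattern `generate.py` relies on (assumes a CUDA device, since `init()` calls `torch.cuda.set_device`):

```python
from torch_utils import distributed as dist

dist.init()  # falls back to a 1-process group when no launcher env vars are set
dist.print0("world size:", dist.get_world_size())  # printed on rank 0 only

# Split a list of work items across ranks, as generate.py does with caption batches.
work = list(range(10))
my_work = work[dist.get_rank()::dist.get_world_size()]
dist.print0("rank 0 items:", my_work)
```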
/diffuser/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/diffuser/visualization.py:
--------------------------------------------------------------------------------
1 | # make sure you're logged in with `huggingface-cli login`
2 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler, DDIMScheduler, DDPMScheduler, SDEScheduler, EulerDiscreteScheduler
3 | import torch
4 | import argparse
5 | import os
6 | from torchvision.utils import make_grid, save_image
7 | import numpy as np
8 | from PIL import Image
9 |
10 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion')
11 | parser.add_argument('--steps', type=int, default=30, help='number of inference steps during sampling')
12 | parser.add_argument('--generate_seed', type=int, default=6)
13 | parser.add_argument('--w', type=float, default=8)
14 | parser.add_argument('--bs', type=int, default=16)
15 | parser.add_argument('--max_cnt', type=int, default=5000, help='maximum number of generated samples')
16 | parser.add_argument('--save_path', type=str, default='./generated_images')
17 | parser.add_argument('--scheduler', type=str, default='DDIM')
18 | parser.add_argument('--restart', action='store_true', default=False)
19 | parser.add_argument('--second', action='store_true', default=False, help='second order ODE')
20 | parser.add_argument('--sigma', action='store_true', default=False, help='use sigma')
21 | parser.add_argument('--prompt', type=str, default='a photo of an astronaut riding a horse on mars')
22 |
23 | args = parser.parse_args()
24 |
25 | print('Args:')
26 | for k, v in sorted(vars(args).items()):
27 | print('\t{}: {}'.format(k, v))
28 |
29 | os.makedirs('./vis', exist_ok=True)
30 |
31 | # prompt_list = ["a photo of an astronaut riding a horse on mars", "a raccoon playing table tennis",
32 | # "Intricate origami of a fox in a snowy forest", "A transparent sculpture of a duck made out of glass"]
33 | prompt_list = [args.prompt]
34 |
35 |
36 | for prompt_ in prompt_list:
37 |
38 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
39 | print("default scheduler config:")
40 | print(pipe.scheduler.config)
41 |
42 | pipe = pipe.to("cuda")
43 |
44 | if args.scheduler == 'DDPM':
45 | pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
46 | elif args.scheduler == 'DDIM':
47 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
48 | pipe.scheduler.use_sigma = args.sigma
49 | else:
50 | raise NotImplementedError
51 |
52 | prompt = [prompt_] * 16
53 |
54 | # Restart
55 | generator = torch.Generator(device="cuda").manual_seed(args.generate_seed)
56 | out, _ = pipe(prompt, generator=generator, num_inference_steps=args.steps, guidance_scale=args.w,
57 | restart=args.restart, second_order=args.second, output_type='tensor')
58 | image = out.images
59 |
60 | print(f'image saving to ./vis/{prompt_}_{args.scheduler}_restart_True_steps_{args.steps}_w_{args.w}_seed_{args.generate_seed}.png')
61 | image_grid = make_grid(torch.from_numpy(image).permute(0, 3, 1, 2), nrow=int(np.sqrt(len(image))))
62 | save_image(image_grid,
63 | f"./vis/{prompt_}_{args.scheduler}_restart_True_steps_{args.steps}_w_{args.w}_seed_{args.generate_seed}.png")
64 |
--------------------------------------------------------------------------------