├── README.md ├── assets ├── .DS_Store ├── fig_3.png ├── fig_5.png ├── restart.png └── vis.png ├── benchmarks ├── .DS_Store ├── README.md ├── dnnlib │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── util.cpython-38.pyc │ │ └── util.cpython-39.pyc │ └── util.py ├── fid.py ├── generate_dpm_solver.py ├── generate_restart.py ├── hyperparam.py ├── params_cifar10_edm.txt ├── params_cifar10_pfgmpp.txt ├── params_cifar10_vp.txt ├── params_imagenet_edm.txt ├── run.sh ├── run_fid.sh ├── run_restart.sh ├── torch_utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── distributed.cpython-38.pyc │ │ ├── distributed.cpython-39.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── misc.cpython-39.pyc │ │ ├── persistence.cpython-39.pyc │ │ ├── training_stats.cpython-38.pyc │ │ └── training_stats.cpython-39.pyc │ ├── distributed.py │ ├── misc.py │ ├── persistence.py │ └── training_stats.py └── training │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── loss.py │ ├── networks.py │ └── training_loop.py └── diffuser ├── .DS_Store ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── src.iml └── vcs.xml ├── README.md ├── aesthetic_score.py ├── clip_score.py ├── coco_data_loader.py ├── data_process.py ├── diffusers ├── .DS_Store ├── __init__.py ├── commands │ ├── __init__.py │ ├── diffusers_cli.py │ └── env.py ├── configuration_utils.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── experimental │ ├── README.md │ ├── __init__.py │ └── rl │ │ ├── __init__.py │ │ └── value_guided_sampling.py ├── image_processor.py ├── loaders.py ├── models │ ├── README.md │ ├── __init__.py │ ├── attention.py │ ├── attention_flax.py │ ├── attention_processor.py │ ├── autoencoder_kl.py │ ├── controlnet.py │ ├── controlnet_flax.py │ ├── cross_attention.py │ ├── dual_transformer_2d.py │ ├── embeddings.py │ ├── embeddings_flax.py │ ├── modeling_flax_pytorch_utils.py │ ├── modeling_flax_utils.py │ ├── modeling_pytorch_flax_utils.py │ ├── modeling_utils.py │ ├── prior_transformer.py │ ├── resnet.py │ ├── resnet_flax.py │ ├── t5_film_transformer.py │ ├── transformer_2d.py │ ├── transformer_temporal.py │ ├── unet_1d.py │ ├── unet_1d_blocks.py │ ├── unet_2d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_blocks_flax.py │ ├── unet_2d_condition.py │ ├── unet_2d_condition_flax.py │ ├── unet_3d_blocks.py │ ├── unet_3d_condition.py │ ├── vae.py │ ├── vae_flax.py │ └── vq_model.py ├── optimization.py ├── pipeline_utils.py ├── pipelines │ ├── .DS_Store │ ├── README.md │ ├── __init__.py │ ├── alt_diffusion │ │ ├── __init__.py │ │ ├── modeling_roberta_series.py │ │ ├── pipeline_alt_diffusion.py │ │ └── pipeline_alt_diffusion_img2img.py │ ├── audio_diffusion │ │ ├── __init__.py │ │ ├── mel.py │ │ └── pipeline_audio_diffusion.py │ ├── audioldm │ │ ├── __init__.py │ │ └── pipeline_audioldm.py │ ├── dance_diffusion │ │ ├── __init__.py │ │ └── pipeline_dance_diffusion.py │ ├── ddim │ │ ├── __init__.py │ │ └── pipeline_ddim.py │ ├── ddpm │ │ ├── __init__.py │ │ └── pipeline_ddpm.py │ ├── dit │ │ ├── __init__.py │ │ └── pipeline_dit.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ ├── pipeline_latent_diffusion.py │ │ └── pipeline_latent_diffusion_superresolution.py │ ├── latent_diffusion_uncond │ │ ├── __init__.py │ │ └── pipeline_latent_diffusion_uncond.py │ ├── onnx_utils.py │ ├── paint_by_example │ │ ├── __init__.py │ │ ├── image_encoder.py │ │ └── 
pipeline_paint_by_example.py │ ├── pipeline_flax_utils.py │ ├── pipeline_utils.py │ ├── pndm │ │ ├── __init__.py │ │ └── pipeline_pndm.py │ ├── repaint │ │ ├── __init__.py │ │ └── pipeline_repaint.py │ ├── score_sde_ve │ │ ├── __init__.py │ │ └── pipeline_score_sde_ve.py │ ├── semantic_stable_diffusion │ │ ├── __init__.py │ │ └── pipeline_semantic_stable_diffusion.py │ ├── spectrogram_diffusion │ │ ├── __init__.py │ │ ├── continous_encoder.py │ │ ├── midi_utils.py │ │ ├── notes_encoder.py │ │ └── pipeline_spectrogram_diffusion.py │ ├── stable_diffusion │ │ ├── README.md │ │ ├── __init__.py │ │ ├── convert_from_ckpt.py │ │ ├── pipeline_cycle_diffusion.py │ │ ├── pipeline_flax_stable_diffusion.py │ │ ├── pipeline_flax_stable_diffusion_controlnet.py │ │ ├── pipeline_flax_stable_diffusion_img2img.py │ │ ├── pipeline_flax_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion.py │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py │ │ ├── pipeline_onnx_stable_diffusion_upscale.py │ │ ├── pipeline_stable_diffusion.py │ │ ├── pipeline_stable_diffusion_attend_and_excite.py │ │ ├── pipeline_stable_diffusion_controlnet.py │ │ ├── pipeline_stable_diffusion_depth2img.py │ │ ├── pipeline_stable_diffusion_image_variation.py │ │ ├── pipeline_stable_diffusion_img2img.py │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py │ │ ├── pipeline_stable_diffusion_k_diffusion.py │ │ ├── pipeline_stable_diffusion_latent_upscale.py │ │ ├── pipeline_stable_diffusion_model_editing.py │ │ ├── pipeline_stable_diffusion_panorama.py │ │ ├── pipeline_stable_diffusion_pix2pix_zero.py │ │ ├── pipeline_stable_diffusion_sag.py │ │ ├── pipeline_stable_diffusion_upscale.py │ │ ├── pipeline_stable_unclip.py │ │ ├── pipeline_stable_unclip_img2img.py │ │ ├── safety_checker.py │ │ ├── safety_checker_flax.py │ │ └── stable_unclip_image_normalizer.py │ ├── stable_diffusion_safe │ │ ├── __init__.py │ │ ├── pipeline_stable_diffusion_safe.py │ │ └── safety_checker.py │ ├── stochastic_karras_ve │ │ ├── __init__.py │ │ └── pipeline_stochastic_karras_ve.py │ └── text_to_video_synthesis │ │ ├── __init__.py │ │ ├── pipeline_text_to_video_synth.py │ │ └── pipeline_text_to_video_zero.py ├── schedulers │ ├── README.md │ ├── __init__.py │ ├── scheduling_ddim.py │ ├── scheduling_ddim_flax.py │ ├── scheduling_ddim_inverse.py │ ├── scheduling_ddpm.py │ ├── scheduling_ddpm_flax.py │ ├── scheduling_deis_multistep.py │ ├── scheduling_dpmsolver_multistep.py │ ├── scheduling_dpmsolver_multistep_flax.py │ ├── scheduling_dpmsolver_singlestep.py │ ├── scheduling_euler_ancestral_discrete.py │ ├── scheduling_euler_discrete.py │ ├── scheduling_heun_discrete.py │ ├── scheduling_ipndm.py │ ├── scheduling_k_dpm_2_ancestral_discrete.py │ ├── scheduling_k_dpm_2_discrete.py │ ├── scheduling_karras_ve.py │ ├── scheduling_karras_ve_flax.py │ ├── scheduling_lms_discrete.py │ ├── scheduling_lms_discrete_flax.py │ ├── scheduling_pndm.py │ ├── scheduling_pndm_flax.py │ ├── scheduling_repaint.py │ ├── scheduling_sde.py │ ├── scheduling_sde_ve.py │ ├── scheduling_sde_ve_flax.py │ ├── scheduling_sde_vp.py │ ├── scheduling_unclip.py │ ├── scheduling_unipc_multistep.py │ ├── scheduling_utils.py │ ├── scheduling_utils_flax.py │ └── scheduling_vq_diffusion.py ├── test.py ├── training_utils.py └── utils │ ├── __init__.py │ ├── accelerate_utils.py │ ├── constants.py │ 
├── deprecation_utils.py │ ├── doc_utils.py │ ├── dummy_flax_and_transformers_objects.py │ ├── dummy_flax_objects.py │ ├── dummy_note_seq_objects.py │ ├── dummy_onnx_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_torch_and_librosa_objects.py │ ├── dummy_torch_and_scipy_objects.py │ ├── dummy_torch_and_transformers_and_k_diffusion_objects.py │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ ├── dummy_torch_and_transformers_objects.py │ ├── dummy_transformers_and_torch_and_note_seq_objects.py │ ├── dynamic_modules_utils.py │ ├── hub_utils.py │ ├── import_utils.py │ ├── logging.py │ ├── model_card_template.md │ ├── outputs.py │ ├── pil_utils.py │ ├── testing_utils.py │ └── torch_utils.py ├── dnnlib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── util.cpython-39.pyc └── util.py ├── eval_clip_score.py ├── ffhq.py ├── fid.py ├── generate.py ├── requirements.txt ├── torch_utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── distributed.cpython-39.pyc │ ├── misc.cpython-39.pyc │ ├── persistence.cpython-39.pyc │ └── training_stats.cpython-39.pyc ├── distributed.py ├── misc.py ├── persistence.py └── training_stats.py ├── training ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py ├── networks.py └── training_loop.py └── visualization.py /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/.DS_Store -------------------------------------------------------------------------------- /assets/fig_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/fig_3.png -------------------------------------------------------------------------------- /assets/fig_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/fig_5.png -------------------------------------------------------------------------------- /assets/restart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/restart.png -------------------------------------------------------------------------------- /assets/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/assets/vis.png -------------------------------------------------------------------------------- /benchmarks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/.DS_Store -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | ### Standard Benchmarks (CIFAR-10, ImageNet-64) 2 | 3 | **The working directory for standard benchmarks is under `./benchmarks`** 4 | 5 | #### 1. 
Preparing datasets and checkpoints 6 | 7 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert it to a ZIP archive: 8 | 9 | ```.bash 10 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \ 11 | --dest=datasets/cifar10-32x32.zip 12 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz 13 | ``` 14 | 15 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert it to a ZIP archive at 64x64 resolution: 16 | 17 | ```.bash 18 | python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \ 19 | --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop 20 | python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz 21 | ``` 22 | 23 | Alternatively, you can download the precomputed FID statistics from [CIFAR-10-FID](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz) and [ImageNet-$64\times 64$-FID](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz). 24 | 25 | 26 | 27 | Please download the CIFAR-10 or ImageNet checkpoints from the [EDM](https://github.com/NVlabs/edm) repo or the [PFGM++](https://github.com/Newbeeer/pfgmpp) repo. For example, 28 | 29 | | Dataset | Method | Path | 30 | | ---------------------- | ------------------------------- | ------------------------------------------------------------ | 31 | | CIFAR-10 | VP unconditional | [path](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl) | 32 | | CIFAR-10 | PFGM++ ($D=2048$) unconditional | [path](https://drive.google.com/drive/folders/1sZ7vh7o8kuXfFjK8ROWXxtEZi8Srewgo) | 33 | | ImageNet $64\times 64$ | EDM conditional | [path](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl) | 34 | 35 | 36 | 37 | #### 2. Generate 38 | 39 | Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the command below with `torchrun`. Before generation, please make sure the checkpoint has been downloaded to the `./benchmarks/imgs` folder. 40 | 41 | ```shell 42 | torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \ 43 | --restart_info='{restart_config}' --S_min=0.01 --S_max=1 --S_noise 1.003 \ 44 | --steps={steps} --seeds=00000-49999 --name={name} (--pfgmpp=1) (--aug_dim={D}) 45 | 46 | 47 | restart_config: configuration for Restart (details below) 48 | steps: number of steps in the main backward process, default=18 49 | name: name of experiments (for FID evaluation) 50 | pfgmpp: flag for using PFGM++ 51 | D: augmented dimension in PFGM++ 52 | ``` 53 | 54 | The `restart_info` argument has the format $\lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}] \rbrace_{i=0}^{l-1}$, such as `{"0": [3, 2, 0.06, 0.30]}`. Please refer to Table 3 (CIFAR-10) and Table 5 (ImageNet-64) for detailed configurations.
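The total NFE of a configuration can be estimated directly from these entries, following the counting used in `hyperparam.py`: the main backward pass with Heun's method costs $2 N_{\textrm{main}} - 1$ network evaluations, and each Restart interval adds $2 K_i (N_{\textrm{Restart},i} - 1)$. A minimal sketch of this counting (the helper below is illustrative, not part of the repo):

```python
import json

def restart_nfe(steps, restart_info):
    # main backward pass with Heun's method: 2 * steps - 1 evaluations
    nfe = 2 * steps - 1
    # each Restart interval with N steps repeated K times adds 2 * K * (N - 1)
    for n_restart, k, *_ in json.loads(restart_info).values():
        nfe += 2 * k * (n_restart - 1)
    return nfe

# the ImageNet-64 configuration shown below: 71 + 132 = 203 evaluations
print(restart_nfe(36, '{"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], '
                      '"2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}'))
```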
For example, on conditional EDM ImageNet-64, with NFE=203 and FID=1.41, the command is: 55 | 56 | ```shell 57 | torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \ 58 | --restart_info='{"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}' --S_min=0.01 --S_max=1 --S_noise 1.003 \ 59 | --steps=36 --seeds=00000-49999 --name=imagenet_edm 60 | ``` 61 | 62 | We also provide extensive Restart configurations in `params_cifar10_vp.txt` and `params_imagenet_edm.txt`, corresponding to Table 3 (CIFAR-10) and Table 5 (ImageNet-64), respectively. Each line in these `txt` files has the form $N_{\textrm{main}} \quad \lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}]\rbrace_{i=0}^{l-1}$. To sweep the Restart configurations in the `txt` files, please run 63 | 64 | ```shell 65 | python3 hyperparam.py --dataset {dataset} --method {method} 66 | 67 | dataset: cifar10 | imagenet 68 | method: edm | vp | pfgmpp 69 | ``` 70 | 71 | This sweep reproduces the results in the following figure (Fig. 3 in the paper): 72 | 73 | ![schematic](../assets/fig_3.png) 74 | 75 | #### 3. Evaluation 76 | 77 | For FID evaluation, please run: 78 | 79 | ```shell 80 | python fid.py ./imgs/imgs_{name} stats_path 81 | 82 | name: name of experiments (specified in the generation command line) 83 | stats_path: path to FID statistics, such as ./cifar10-32x32.npz or ./imagenet-64x64.npz 84 | ``` 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /benchmarks/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work.
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /benchmarks/dnnlib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/dnnlib/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/dnnlib/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/dnnlib/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/dnnlib/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/hyperparam.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from subprocess import PIPE, run 4 | import argparse 5 | import gdown 6 | 7 | 8 | # setup arguments 9 | parser = argparse.ArgumentParser(description='Restart Sampling') 10 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet']) 11 | parser.add_argument('--method', type=str, default='vp', choices=['edm', 'vp', 'pfgmpp']) 12 | args = parser.parse_args() 13 | 14 | 15 | devices='0,1,2,3,4,5,6,7' 16 | os.environ['CUDA_VISIBLE_DEVICES'] = devices 17 | batch_size = 192 18 | 19 | def get_fid(num, name): 20 | 21 | if args.dataset == 'cifar10': 22 | fid_stats_path = './cifar10-32x32.npz' 23 | else: 24 | fid_stats_path = './imagenet-64x64.npz' 25 | 26 | for tries in range(2): 27 | fid_command = f'CUDA_VISIBLE_DEVICES=0 python fid.py ./imgs/imgs_{name} {fid_stats_path} --num={num} > tmp_fid.txt' 28 | print('----------------------------') 29 | print(fid_command) 30 | os.system(fid_command) 31 | with open("tmp_fid.txt", "r") as f: 32 | output = f.read() 33 | print(output) 34 | try: 35 | fid_score = float(output.split()[-1]) 36 | return fid_score 37 | except: 38 | print("FID computation failed, trying again") 39 | print('----------------------------') 40 | return 1e9 41 | 42 | 43 | runs_dict = dict() 44 | 45 | def generate(num, name, store = True, **kwargs): 46 | 47 | # If using PFGM++, setting up the augmentation dimension and the method flag 48 | pfgmpp = 1 if args.method == 'pfgmpp' else 0 49 | aug_dim = 2048 if args.method == 'pfgmpp' else -1 50 | s_noise = 1.003 if args.dataset == 'imagenet' else 1.0 51 | 52 | command = f"CUDA_VISIBLE_DEVICES={devices} torchrun --standalone --nproc_per_node={len(devices.split(','))} generate_restart.py 
--outdir=./imgs " \ 53 | f"--restart_info='{kwargs['restart']}' --S_min=0.01 --S_max=1 --S_noise={s_noise} " \ 54 | f"--steps={kwargs['steps']} --seeds=00000-{00000+num-1} --name={name} --batch={batch_size} --pfgmpp={pfgmpp} --aug_dim={aug_dim} #generate" 55 | print(command) 56 | os.system(command) 57 | if store: 58 | fid_score = get_fid(num, name) 59 | NFE = 0 60 | print("restart:", kwargs["restart"]) 61 | dic = json.loads(kwargs["restart"]) 62 | print("dic:", dic) 63 | for restartid in dic.keys(): 64 | info = dic[restartid] 65 | NFE += 2 * info[1] * (info[0] - 1) 66 | NFE += (2 * kwargs['steps'] - 1) 67 | print(f'NFE:{NFE} FID_OF_{num}:{fid_score}') 68 | runs_dict[name] = {"fid": fid_score, "NFE": NFE, "Args": kwargs} 69 | 70 | 71 | import random 72 | import json 73 | 74 | runs_dict = dict() 75 | with open(f"params_{args.dataset}_{args.method}.txt", "r") as f: 76 | lines = f.readlines() 77 | for line in lines: 78 | sample_runs = 50000 79 | infos = line.split(' ') 80 | 81 | steps = int(infos[0]) 82 | restart_info = ' '.join(infos[1:]) 83 | print("restart_info:", restart_info) 84 | cur_name = random.randint(0, 20000000) 85 | generate(sample_runs, cur_name, True, restart=restart_info, steps=steps) 86 | 87 | with open(f"restart_runs_dict_{args.dataset}_{args.method}.json", "w") as f: 88 | json.dump(runs_dict, f) 89 | 90 | -------------------------------------------------------------------------------- /benchmarks/params_cifar10_edm.txt: -------------------------------------------------------------------------------- 1 | 18 {"0": [3, 2, 0.14, 0.30]} -------------------------------------------------------------------------------- /benchmarks/params_cifar10_pfgmpp.txt: -------------------------------------------------------------------------------- 1 | 18 {"0": [3, 2, 0.14, 0.30]} -------------------------------------------------------------------------------- /benchmarks/params_cifar10_vp.txt: -------------------------------------------------------------------------------- 1 | 18 {"0": [3, 2, 0.06, 0.30]} 2 | 18 {"0": [3, 5, 0.06, 0.30]} 3 | 18 {"0": [3, 10, 0.06, 0.30]} 4 | 18 {"0": [3, 20, 0.06, 0.30]} 5 | 20 {"0": [9, 30, 0.06, 0.20]} -------------------------------------------------------------------------------- /benchmarks/params_imagenet_edm.txt: -------------------------------------------------------------------------------- 1 | 14 {"0": [3, 1, 19.35, 40.79], "1": [3, 1, 1.09, 1.92], "2": [3, 1, 0.06, 0.30]} 2 | 18 {"0": [5, 1, 19.35, 40.79], "1": [5, 1, 1.09, 1.92], "2": [5, 1, 0.59, 1.09], "3": [5, 1, 0.06, 0.30]} 3 | 18 {"0": [3, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 4, 0.59, 1.09], "3": [4, 1, 0.30, 0.59], "4": [4, 4, 0.06, 0.30]} 4 | 18 {"0": [3, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [4, 10, 0.06, 0.30]} 5 | 36 {"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]} 6 | 36 {"0": [3, 1, 19.35, 40.79], "1": [6, 1, 1.09, 1.92], "2": [6, 5, 0.59, 1.09], "3": [6, 5, 0.30, 0.59], "4": [6, 20, 0.06, 0.30]} 7 | 36 {"0": [6, 1, 19.35, 40.79], "1": [6, 1, 1.09, 1.92], "2": [7, 6, 0.59, 1.09], "3": [7, 6, 0.30, 0.59], "4": [7, 25, 0.06, 0.30]} 8 | 36 {"0": [10, 3, 19.35, 40.79], "1": [10, 3, 1.09, 1.92], "2": [7, 6, 0.59, 1.09], "3": [7, 6, 0.30, 0.59], "4": [7, 25, 0.06, 0.30]} -------------------------------------------------------------------------------- /benchmarks/run.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in $(seq 5 13); 3 | do 4 | torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate_dpm_solver.py --outdir=./imgs --restart_info='' --S_min=0.01 --S_max=1 --S_noise 1.003 --steps=$i --seeds=50000-99999 --name step_dpm3_$i 5 | done -------------------------------------------------------------------------------- /benchmarks/run_fid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in $(seq 8 11); 3 | do 4 | python fid.py ./imgs/imgs_step_3_1_0.3_dpm3_3_$i cifar10-32x32.npz 5 | done -------------------------------------------------------------------------------- /benchmarks/run_restart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in $(seq 5 7); 3 | do 4 | torchrun --rdzv_endpoint=0.0.0.0:1203 --nproc_per_node=8 generate_dpm_solver.py --outdir=./imgs --restart_info='{"0": [3, 1, 0.06, 1]}' --S_min=0.01 --S_max=1 --S_noise 1.003 --steps=$i --seeds=50000-99999 --name step_3_1_1_dpm3_3_$i 5 | done -------------------------------------------------------------------------------- /benchmarks/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/distributed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/distributed.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/distributed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/distributed.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/persistence.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/persistence.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/training_stats.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/training_stats.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/__pycache__/training_stats.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/benchmarks/torch_utils/__pycache__/training_stats.cpython-39.pyc -------------------------------------------------------------------------------- /benchmarks/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /benchmarks/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /diffuser/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/.DS_Store -------------------------------------------------------------------------------- /diffuser/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /diffuser/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 26 | -------------------------------------------------------------------------------- /diffuser/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /diffuser/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /diffuser/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /diffuser/.idea/src.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /diffuser/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /diffuser/README.md: -------------------------------------------------------------------------------- 1 | ### Stable Diffusion 2 | 3 | TODO: merge into the diffuser repo. 4 | 5 | **The working directory for Stable Diffusion is under `./diffuser`** 6 | 7 | ![schematic](../assets/fig_5.png) 8 | 9 | #### 1. Data processing: 10 | 11 | - Step 1: Follow the instructions at the head of `data_process.py` 12 | - Step 2: Run `python3 data_process.py` to randomly sample 5K image-text pairs from the COCO validation set. 13 | 14 | - Step 3: Calculate the FID statistics: `python fid_edm.py ref --data=path-to-coco-subset --dest=./coco.npz` 15 | 16 | #### 2. Generate 17 | 18 | ``` 19 | torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate.py 20 | --steps: number of sampling steps (default=50) 21 | --scheduler: baseline method (DDIM | DDPM | Heun) 22 | --save_path: path to save images 23 | --w: classifier-guidance weight ({2,3,5,8}) 24 | --name: name of experiments 25 | --restart: enable Restart sampling 26 | ``` 27 | 
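For instance, a typical multi-GPU run with Restart enabled could look like the following (the flag values, output path, and experiment name below are illustrative):

```shell
torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate.py \
  --steps 50 --scheduler DDIM --w 8 --save_path ./sd_imgs --name sd_ddim_restart_w8 --restart
```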
28 | If you would like to visualize images for a given text prompt, run: 29 | 30 | ```shell 31 | python3 visualization.py --prompt {prompt} --w {w} --steps {steps} --scheduler {scheduler} (--restart) 32 | 33 | prompt: text prompt, default='a photo of an astronaut riding a horse on mars' 34 | steps: number of sampling steps (default=50) 35 | scheduler: baseline method (DDIM | DDPM) 36 | w: classifier-guidance weight ({2,3,5,8}) 37 | ``` 38 | 39 | 40 | 41 | #### 3. Evaluation 42 | 43 | - FID score 44 | 45 | ```sh 46 | python3 fid.py {path} ./coco.npz 47 | 48 | path: path to the directory of generated images 49 | ``` 50 | 51 | - Aesthetic score & CLIP score 52 | 53 | ```shell 54 | python3 eval_clip_score.py --csv_path {path}/subset.csv --dir_path {path} 55 | 56 | path: path to the directory of generated images 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /diffuser/aesthetic_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytorch_lightning as pl 4 | import torch.nn as nn 5 | import clip 6 | import torch.nn.functional as F 7 | from PIL import Image, ImageFile 8 | 9 | ##### This script will predict the aesthetic score for this image file: 10 | 11 | img_path = "test.jpg" 12 | 13 | 14 | # if you changed the MLP architecture during training, change it also here: 15 | class MLP(pl.LightningModule): 16 | def __init__(self, input_size, xcol='emb', ycol='avg_rating'): 17 | super().__init__() 18 | self.input_size = input_size 19 | self.xcol = xcol 20 | self.ycol = ycol 21 | self.layers = nn.Sequential( 22 | nn.Linear(self.input_size, 1024), 23 | # nn.ReLU(), 24 | nn.Dropout(0.2), 25 | nn.Linear(1024, 128), 26 | # nn.ReLU(), 27 | nn.Dropout(0.2), 28 | nn.Linear(128, 64), 29 | # nn.ReLU(), 30 | nn.Dropout(0.1), 31 | 32 | nn.Linear(64, 16), 33 | # nn.ReLU(), 34 | 35 | nn.Linear(16, 1) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.layers(x) 40 | 41 | def training_step(self, batch, batch_idx): 42 | x = batch[self.xcol] 43 | y = batch[self.ycol].reshape(-1, 1) 44 | x_hat = self.layers(x) 45 | loss = F.mse_loss(x_hat, y) 46 | return loss 47 | 48 | def validation_step(self, batch, batch_idx): 49 | x = batch[self.xcol] 50 | y = batch[self.ycol].reshape(-1, 1) 51 | x_hat = self.layers(x) 52 | loss = F.mse_loss(x_hat, y) 53 | return loss 54 | 55 | def configure_optimizers(self): 56 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 57 | return optimizer 58 | 59 | 60 | def normalized(a, axis=-1, order=2): 61 | import numpy as np # pylint: disable=import-outside-toplevel 62 | 63 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 64 | l2[l2 == 0] = 1 65 | return a / np.expand_dims(l2, axis) 66 | 67 | 68 | # model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14 69 | # 70 | # s = torch.load("sac+logos+ava1-l14-linearMSE.pth") # load the model you trained previously or the model available in this repo 71 | # 72 | # model.load_state_dict(s) 73 | # 74 | # model.to("cuda") 75 | # model.eval() 76 | # 77 | # 78 | # device = "cuda" if torch.cuda.is_available() else "cpu" 79 | # model2, preprocess = clip.load("ViT-L/14", device=device) #RN50x64 80 | # 81 | # 82 | # pil_image = Image.open(img_path) 83 | # 84 | # image = preprocess(pil_image).unsqueeze(0).to(device) 85 | # 86 | # 87 | # 88 | # with torch.no_grad(): 89 | # image_features = model2.encode_image(image) 90 | # 91 | # im_emb_arr = normalized(image_features.cpu().detach().numpy() ) 92 | # 93 | # prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor)) 94 | # 95 | # print( "Aesthetic score predicted by the model:") 96 | # print( prediction )
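# Hedged example (not part of the original script): the commented-out steps above,
# wrapped into a single illustrative helper. It assumes the
# "sac+logos+ava1-l14-linearMSE.pth" checkpoint referenced above is available locally.
def predict_aesthetic_score(img_path, ckpt_path="sac+logos+ava1-l14-linearMSE.pth"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mlp = MLP(768)  # CLIP ViT-L/14 embedding dimension
    mlp.load_state_dict(torch.load(ckpt_path))
    mlp.to(device).eval()
    clip_model, preprocess = clip.load("ViT-L/14", device=device)
    image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        features = clip_model.encode_image(image)           # CLIP image embedding
        emb = normalized(features.cpu().detach().numpy())    # L2-normalize as above
        score = mlp(torch.from_numpy(emb).to(device).float())
    return score.item()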
-------------------------------------------------------------------------------- /diffuser/clip_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | 5 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k') 6 | tokenizer = open_clip.get_tokenizer('ViT-g-14') 7 | 8 | image = preprocess(Image.open("astronaut_rides_horse.png")).unsqueeze(0) 9 | text = tokenizer(["a horse", "a dog", "a cat"]) 10 | 11 | with torch.no_grad(), torch.cuda.amp.autocast(): 12 | image_features = model.encode_image(image) 13 | text_features = model.encode_text(text) 14 | image_features /= image_features.norm(dim=-1, keepdim=True) 15 | text_features /= text_features.norm(dim=-1, keepdim=True) 16 | 17 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 18 | 19 | print("Label probs:", text_probs) # prints: [[1., 0., 0.]] -------------------------------------------------------------------------------- /diffuser/coco_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils 3 | import pandas as pd 4 | import os 5 | import open_clip 6 | from PIL import Image 7 | import clip 8 | 9 | class text_image_pair(torch.utils.data.Dataset): 10 | def __init__(self, dir_path, csv_path): 11 | """ 12 | 13 | Args: 14 | dir_path: the path to the stored images 15 | file_path: 16 | """ 17 | self.dir_path = dir_path 18 | df = pd.read_csv(csv_path) 19 | self.text_description = df['caption'] 20 | _, _, self.preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k') 21 | _, self.preprocess2 = clip.load("ViT-L/14", device='cuda') # RN50x64 22 | # tokenizer = open_clip.get_tokenizer('ViT-g-14') 23 | 24 | def __len__(self): 25 | return len(self.text_description) 26 | 27 | def __getitem__(self, idx): 28 | 29 | img_path = os.path.join(self.dir_path, f'{idx}.png') 30 | raw_image = Image.open(img_path) 31 | image = self.preprocess(raw_image).squeeze().float() 32 | image2 = self.preprocess2(raw_image).squeeze().float() 33 | text = self.text_description[idx] 34 | return image, image2, text 35 | -------------------------------------------------------------------------------- /diffuser/data_process.py: -------------------------------------------------------------------------------- 1 | # # bash commands to download the data 2 | # !wget http://images.cocodataset.org/zips/val2014.zip 3 | # !wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 4 | # !unzip annotations_trainval2014.zip -d coco/ 5 | # !unzip val2014.zip -d coco/ 6 | 7 | 8 | 9 | # load data 10 | import json 11 | import numpy as np 12 | 13 | 14 | dir_path = './coco/' 15 | data_file = dir_path + 'coco/annotations/captions_val2014.json' 16 | data = json.load(open(data_file)) 17 | 18 | np.random.seed(123) 19 | 20 | # merge images and annotations 21 | import pandas as pd 22 | images = data['images'] 23 | annotations = data['annotations'] 24 | df = pd.DataFrame(images) 25 | df_annotations = pd.DataFrame(annotations) 26 | df = df.merge(pd.DataFrame(annotations), how='left', left_on='id', right_on='image_id') 27 | 28 | 29 | # keep only the relevant columns 30 | df = df[['file_name', 'caption']] 31 | print(df) 32 | print("length:", len(df['file_name'])) 33 | # shuffle the dataset 34 | df = df.sample(frac=1) 35 | 36 | 37 | # remove duplicate images 38 | # df = 
df.drop_duplicates(subset='file_name') 39 | 40 | # create a random subset of n_samples 41 | n_samples = 5000 42 | df_sample = df.sample(n_samples) 43 | print(df_sample) 44 | 45 | # save the sample to a parquet file 46 | df_sample.to_csv(dir_path + 'coco/subset.csv') 47 | 48 | # copy the images to reference folder 49 | from pathlib import Path 50 | import shutil 51 | subset_path = Path(dir_path + 'coco/subset') 52 | subset_path.mkdir(exist_ok=True) 53 | for i, row in df_sample.iterrows(): 54 | path = dir_path + 'coco/val2014/' + row['file_name'] 55 | shutil.copy(path, dir_path + 'coco/subset/') 56 | -------------------------------------------------------------------------------- /diffuser/diffusers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/diffusers/.DS_Store -------------------------------------------------------------------------------- /diffuser/diffusers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from argparse import ArgumentParser 17 | 18 | 19 | class BaseDiffusersCLICommand(ABC): 20 | @staticmethod 21 | @abstractmethod 22 | def register_subcommand(parser: ArgumentParser): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def run(self): 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /diffuser/diffusers/commands/diffusers_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from argparse import ArgumentParser 17 | 18 | from .env import EnvironmentCommand 19 | 20 | 21 | def main(): 22 | parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") 23 | commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") 24 | 25 | # Register commands 26 | EnvironmentCommand.register_subcommand(commands_parser) 27 | 28 | # Let's go 29 | args = parser.parse_args() 30 | 31 | if not hasattr(args, "func"): 32 | parser.print_help() 33 | exit(1) 34 | 35 | # Run 36 | service = args.func(args) 37 | service.run() 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /diffuser/diffusers/commands/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import platform 16 | from argparse import ArgumentParser 17 | 18 | import huggingface_hub 19 | 20 | from .. import __version__ as version 21 | from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available 22 | from . import BaseDiffusersCLICommand 23 | 24 | 25 | def info_command_factory(_): 26 | return EnvironmentCommand() 27 | 28 | 29 | class EnvironmentCommand(BaseDiffusersCLICommand): 30 | @staticmethod 31 | def register_subcommand(parser: ArgumentParser): 32 | download_parser = parser.add_parser("env") 33 | download_parser.set_defaults(func=info_command_factory) 34 | 35 | def run(self): 36 | hub_version = huggingface_hub.__version__ 37 | 38 | pt_version = "not installed" 39 | pt_cuda_available = "NA" 40 | if is_torch_available(): 41 | import torch 42 | 43 | pt_version = torch.__version__ 44 | pt_cuda_available = torch.cuda.is_available() 45 | 46 | transformers_version = "not installed" 47 | if is_transformers_available(): 48 | import transformers 49 | 50 | transformers_version = transformers.__version__ 51 | 52 | accelerate_version = "not installed" 53 | if is_accelerate_available(): 54 | import accelerate 55 | 56 | accelerate_version = accelerate.__version__ 57 | 58 | xformers_version = "not installed" 59 | if is_xformers_available(): 60 | import xformers 61 | 62 | xformers_version = xformers.__version__ 63 | 64 | info = { 65 | "`diffusers` version": version, 66 | "Platform": platform.platform(), 67 | "Python version": platform.python_version(), 68 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", 69 | "Huggingface_hub version": hub_version, 70 | "Transformers version": transformers_version, 71 | "Accelerate version": accelerate_version, 72 | "xFormers version": xformers_version, 73 | "Using GPU in script?": "", 74 | "Using distributed or parallel set-up in script?": "", 75 | } 76 | 77 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") 78 | print(self.format_dict(info)) 79 | 80 | return info 81 | 82 | 
@staticmethod 83 | def format_dict(d): 84 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 85 | -------------------------------------------------------------------------------- /diffuser/diffusers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /diffuser/diffusers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. 
run `make deps_table_update`` 4 | deps = { 5 | "Pillow": "Pillow", 6 | "accelerate": "accelerate>=0.11.0", 7 | "compel": "compel==0.1.8", 8 | "black": "black~=23.1", 9 | "datasets": "datasets", 10 | "filelock": "filelock", 11 | "flax": "flax>=0.4.1", 12 | "hf-doc-builder": "hf-doc-builder>=0.3.0", 13 | "huggingface-hub": "huggingface-hub>=0.13.2", 14 | "requests-mock": "requests-mock==1.10.0", 15 | "importlib_metadata": "importlib_metadata", 16 | "isort": "isort>=5.5.4", 17 | "jax": "jax>=0.2.8,!=0.3.2", 18 | "jaxlib": "jaxlib>=0.1.65", 19 | "Jinja2": "Jinja2", 20 | "k-diffusion": "k-diffusion>=0.0.12", 21 | "librosa": "librosa", 22 | "note-seq": "note-seq", 23 | "numpy": "numpy", 24 | "parameterized": "parameterized", 25 | "protobuf": "protobuf>=3.20.3,<4", 26 | "pytest": "pytest", 27 | "pytest-timeout": "pytest-timeout", 28 | "pytest-xdist": "pytest-xdist", 29 | "ruff": "ruff>=0.0.241", 30 | "safetensors": "safetensors", 31 | "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", 32 | "scipy": "scipy", 33 | "regex": "regex!=2019.12.17", 34 | "requests": "requests", 35 | "tensorboard": "tensorboard", 36 | "torch": "torch>=1.4", 37 | "torchvision": "torchvision", 38 | "transformers": "transformers>=4.25.1", 39 | } 40 | -------------------------------------------------------------------------------- /diffuser/diffusers/experimental/README.md: -------------------------------------------------------------------------------- 1 | # 🧨 Diffusers Experimental 2 | 3 | We are adding experimental code to support novel applications and usages of the Diffusers library. 4 | Currently, the following experiments are supported: 5 | * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. -------------------------------------------------------------------------------- /diffuser/diffusers/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from .rl import ValueGuidedRLPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/experimental/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from .value_guided_sampling import ValueGuidedRLPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models). -------------------------------------------------------------------------------- /diffuser/diffusers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from ..utils import is_flax_available, is_torch_available 16 | 17 | 18 | if is_torch_available(): 19 | from .autoencoder_kl import AutoencoderKL 20 | from .controlnet import ControlNetModel 21 | from .dual_transformer_2d import DualTransformer2DModel 22 | from .modeling_utils import ModelMixin 23 | from .prior_transformer import PriorTransformer 24 | from .t5_film_transformer import T5FilmDecoder 25 | from .transformer_2d import Transformer2DModel 26 | from .unet_1d import UNet1DModel 27 | from .unet_2d import UNet2DModel 28 | from .unet_2d_condition import UNet2DConditionModel 29 | from .unet_3d_condition import UNet3DConditionModel 30 | from .vq_model import VQModel 31 | 32 | if is_flax_available(): 33 | from .controlnet_flax import FlaxControlNetModel 34 | from .unet_2d_condition_flax import FlaxUNet2DConditionModel 35 | from .vae_flax import FlaxAutoencoderKL 36 | -------------------------------------------------------------------------------- /diffuser/diffusers/models/cross_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from ..utils import deprecate 15 | from .attention_processor import ( # noqa: F401 16 | Attention, 17 | AttentionProcessor, 18 | AttnAddedKVProcessor, 19 | AttnProcessor2_0, 20 | LoRAAttnProcessor, 21 | LoRALinearLayer, 22 | LoRAXFormersAttnProcessor, 23 | SlicedAttnAddedKVProcessor, 24 | SlicedAttnProcessor, 25 | XFormersAttnProcessor, 26 | ) 27 | from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401 28 | 29 | 30 | deprecate( 31 | "cross_attention", 32 | "0.18.0", 33 | "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.", 34 | standard_warn=False, 35 | ) 36 | 37 | 38 | AttnProcessor = AttentionProcessor 39 | 40 | 41 | class CrossAttention(Attention): 42 | def __init__(self, *args, **kwargs): 43 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 44 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 45 | super().__init__(*args, **kwargs) 46 | 47 | 48 | class CrossAttnProcessor(AttnProcessorRename): 49 | def __init__(self, *args, **kwargs): 50 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 
51 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 52 | super().__init__(*args, **kwargs) 53 | 54 | 55 | class LoRACrossAttnProcessor(LoRAAttnProcessor): 56 | def __init__(self, *args, **kwargs): 57 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 58 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 59 | super().__init__(*args, **kwargs) 60 | 61 | 62 | class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): 63 | def __init__(self, *args, **kwargs): 64 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 65 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 66 | super().__init__(*args, **kwargs) 67 | 68 | 69 | class XFormersCrossAttnProcessor(XFormersAttnProcessor): 70 | def __init__(self, *args, **kwargs): 71 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 72 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 73 | super().__init__(*args, **kwargs) 74 | 75 | 76 | class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor): 77 | def __init__(self, *args, **kwargs): 78 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 79 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 80 | super().__init__(*args, **kwargs) 81 | 82 | 83 | class SlicedCrossAttnProcessor(SlicedAttnProcessor): 84 | def __init__(self, *args, **kwargs): 85 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 86 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 87 | super().__init__(*args, **kwargs) 88 | 89 | 90 | class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor): 91 | def __init__(self, *args, **kwargs): 92 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 93 | deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) 94 | super().__init__(*args, **kwargs) 95 | -------------------------------------------------------------------------------- /diffuser/diffusers/models/embeddings_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import flax.linen as nn 17 | import jax.numpy as jnp 18 | 19 | 20 | def get_sinusoidal_embeddings( 21 | timesteps: jnp.ndarray, 22 | embedding_dim: int, 23 | freq_shift: float = 1, 24 | min_timescale: float = 1, 25 | max_timescale: float = 1.0e4, 26 | flip_sin_to_cos: bool = False, 27 | scale: float = 1.0, 28 | ) -> jnp.ndarray: 29 | """Returns the positional encoding (same as Tensor2Tensor). 30 | 31 | Args: 32 | timesteps: a 1-D Tensor of N indices, one per batch element. 33 | These may be fractional. 34 | embedding_dim: The number of output channels. 35 | min_timescale: The smallest time unit (should probably be 0.0). 36 | max_timescale: The largest time unit. 37 | Returns: 38 | a Tensor of timing signals [N, num_channels] 39 | """ 40 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array" 41 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" 42 | num_timescales = float(embedding_dim // 2) 43 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) 44 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) 45 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) 46 | 47 | # scale embeddings 48 | scaled_time = scale * emb 49 | 50 | if flip_sin_to_cos: 51 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) 52 | else: 53 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) 54 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) 55 | return signal 56 | 57 | 58 | class FlaxTimestepEmbedding(nn.Module): 59 | r""" 60 | Time step Embedding Module. Learns embeddings for input time steps. 
61 | 62 | Args: 63 | time_embed_dim (`int`, *optional*, defaults to `32`): 64 | Time step embedding dimension 65 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): 66 | Parameters `dtype` 67 | """ 68 | time_embed_dim: int = 32 69 | dtype: jnp.dtype = jnp.float32 70 | 71 | @nn.compact 72 | def __call__(self, temb): 73 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) 74 | temb = nn.silu(temb) 75 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) 76 | return temb 77 | 78 | 79 | class FlaxTimesteps(nn.Module): 80 | r""" 81 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 82 | 83 | Args: 84 | dim (`int`, *optional*, defaults to `32`): 85 | Time step embedding dimension 86 | """ 87 | dim: int = 32 88 | flip_sin_to_cos: bool = False 89 | freq_shift: float = 1 90 | 91 | @nn.compact 92 | def __call__(self, timesteps): 93 | return get_sinusoidal_embeddings( 94 | timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift 95 | ) 96 | -------------------------------------------------------------------------------- /diffuser/diffusers/models/modeling_flax_pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
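# A minimal sketch of how the PyTorch -> Flax conversion helper defined below might be invoked;
# the model class, checkpoint filename, config kwargs, and import path are illustrative
# assumptions rather than values taken from this repository:
#
#     import torch
#     from diffusers import FlaxUNet2DConditionModel
#     from diffusers.models.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
#
#     pt_state_dict = torch.load("pytorch_model.bin", map_location="cpu")   # hypothetical checkpoint
#     flax_model = FlaxUNet2DConditionModel(**model_kwargs)                 # hypothetical config kwargs
#     flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)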
15 | """ PyTorch - Flax general utilities.""" 16 | import re 17 | 18 | import jax.numpy as jnp 19 | from flax.traverse_util import flatten_dict, unflatten_dict 20 | from jax.random import PRNGKey 21 | 22 | from ..utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | def rename_key(key): 29 | regex = r"\w+[.]\d+" 30 | pats = re.findall(regex, key) 31 | for pat in pats: 32 | key = key.replace(pat, "_".join(pat.split("."))) 33 | return key 34 | 35 | 36 | ##################### 37 | # PyTorch => Flax # 38 | ##################### 39 | 40 | 41 | # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 42 | # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py 43 | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): 44 | """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" 45 | 46 | # conv norm or layer norm 47 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 48 | if ( 49 | any("norm" in str_ for str_ in pt_tuple_key) 50 | and (pt_tuple_key[-1] == "bias") 51 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) 52 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) 53 | ): 54 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 55 | return renamed_pt_tuple_key, pt_tensor 56 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: 57 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 58 | return renamed_pt_tuple_key, pt_tensor 59 | 60 | # embedding 61 | if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: 62 | pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) 63 | return renamed_pt_tuple_key, pt_tensor 64 | 65 | # conv layer 66 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 67 | if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: 68 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0) 69 | return renamed_pt_tuple_key, pt_tensor 70 | 71 | # linear layer 72 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 73 | if pt_tuple_key[-1] == "weight": 74 | pt_tensor = pt_tensor.T 75 | return renamed_pt_tuple_key, pt_tensor 76 | 77 | # old PyTorch layer norm weight 78 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) 79 | if pt_tuple_key[-1] == "gamma": 80 | return renamed_pt_tuple_key, pt_tensor 81 | 82 | # old PyTorch layer norm bias 83 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) 84 | if pt_tuple_key[-1] == "beta": 85 | return renamed_pt_tuple_key, pt_tensor 86 | 87 | return pt_tuple_key, pt_tensor 88 | 89 | 90 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): 91 | # Step 1: Convert pytorch tensor to numpy 92 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} 93 | 94 | # Step 2: Since the model is stateless, get random Flax params 95 | random_flax_params = flax_model.init_weights(PRNGKey(init_key)) 96 | 97 | random_flax_state_dict = flatten_dict(random_flax_params) 98 | flax_state_dict = {} 99 | 100 | # Need to change some parameters name to match Flax names 101 | for pt_key, pt_tensor in pt_state_dict.items(): 102 | renamed_pt_key = rename_key(pt_key) 103 | pt_tuple_key = tuple(renamed_pt_key.split(".")) 104 | 105 | # Correctly rename weight parameters 106 | flax_key, flax_tensor = 
rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) 107 | 108 | if flax_key in random_flax_state_dict: 109 | if flax_tensor.shape != random_flax_state_dict[flax_key].shape: 110 | raise ValueError( 111 | f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " 112 | f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 113 | ) 114 | 115 | # also add unexpected weight so that warning is thrown 116 | flax_state_dict[flax_key] = jnp.asarray(flax_tensor) 117 | 118 | return unflatten_dict(flax_state_dict) 119 | -------------------------------------------------------------------------------- /diffuser/diffusers/models/resnet_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import flax.linen as nn 15 | import jax 16 | import jax.numpy as jnp 17 | 18 | 19 | class FlaxUpsample2D(nn.Module): 20 | out_channels: int 21 | dtype: jnp.dtype = jnp.float32 22 | 23 | def setup(self): 24 | self.conv = nn.Conv( 25 | self.out_channels, 26 | kernel_size=(3, 3), 27 | strides=(1, 1), 28 | padding=((1, 1), (1, 1)), 29 | dtype=self.dtype, 30 | ) 31 | 32 | def __call__(self, hidden_states): 33 | batch, height, width, channels = hidden_states.shape 34 | hidden_states = jax.image.resize( 35 | hidden_states, 36 | shape=(batch, height * 2, width * 2, channels), 37 | method="nearest", 38 | ) 39 | hidden_states = self.conv(hidden_states) 40 | return hidden_states 41 | 42 | 43 | class FlaxDownsample2D(nn.Module): 44 | out_channels: int 45 | dtype: jnp.dtype = jnp.float32 46 | 47 | def setup(self): 48 | self.conv = nn.Conv( 49 | self.out_channels, 50 | kernel_size=(3, 3), 51 | strides=(2, 2), 52 | padding=((1, 1), (1, 1)), # padding="VALID", 53 | dtype=self.dtype, 54 | ) 55 | 56 | def __call__(self, hidden_states): 57 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim 58 | # hidden_states = jnp.pad(hidden_states, pad_width=pad) 59 | hidden_states = self.conv(hidden_states) 60 | return hidden_states 61 | 62 | 63 | class FlaxResnetBlock2D(nn.Module): 64 | in_channels: int 65 | out_channels: int = None 66 | dropout_prob: float = 0.0 67 | use_nin_shortcut: bool = None 68 | dtype: jnp.dtype = jnp.float32 69 | 70 | def setup(self): 71 | out_channels = self.in_channels if self.out_channels is None else self.out_channels 72 | 73 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 74 | self.conv1 = nn.Conv( 75 | out_channels, 76 | kernel_size=(3, 3), 77 | strides=(1, 1), 78 | padding=((1, 1), (1, 1)), 79 | dtype=self.dtype, 80 | ) 81 | 82 | self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) 83 | 84 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 85 | self.dropout = nn.Dropout(self.dropout_prob) 86 | self.conv2 = nn.Conv( 87 | out_channels, 88 | kernel_size=(3, 3), 89 | strides=(1, 1), 90 | padding=((1, 1), (1, 1)), 91 | 
dtype=self.dtype, 92 | ) 93 | 94 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut 95 | 96 | self.conv_shortcut = None 97 | if use_nin_shortcut: 98 | self.conv_shortcut = nn.Conv( 99 | out_channels, 100 | kernel_size=(1, 1), 101 | strides=(1, 1), 102 | padding="VALID", 103 | dtype=self.dtype, 104 | ) 105 | 106 | def __call__(self, hidden_states, temb, deterministic=True): 107 | residual = hidden_states 108 | hidden_states = self.norm1(hidden_states) 109 | hidden_states = nn.swish(hidden_states) 110 | hidden_states = self.conv1(hidden_states) 111 | 112 | temb = self.time_emb_proj(nn.swish(temb)) 113 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) 114 | hidden_states = hidden_states + temb 115 | 116 | hidden_states = self.norm2(hidden_states) 117 | hidden_states = nn.swish(hidden_states) 118 | hidden_states = self.dropout(hidden_states, deterministic) 119 | hidden_states = self.conv2(hidden_states) 120 | 121 | if self.conv_shortcut is not None: 122 | residual = self.conv_shortcut(residual) 123 | 124 | return hidden_states + residual 125 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 15 | 16 | # NOTE: This file is deprecated and will be removed in a future version. 
17 | # It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works 18 | 19 | from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 20 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/diffusers/pipelines/.DS_Store -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import ( 2 | OptionalDependencyNotAvailable, 3 | is_flax_available, 4 | is_k_diffusion_available, 5 | is_librosa_available, 6 | is_note_seq_available, 7 | is_onnx_available, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | try: 14 | if not is_torch_available(): 15 | raise OptionalDependencyNotAvailable() 16 | except OptionalDependencyNotAvailable: 17 | from ..utils.dummy_pt_objects import * # noqa F403 18 | else: 19 | from .dance_diffusion import DanceDiffusionPipeline 20 | from .ddim import DDIMPipeline 21 | from .ddpm import DDPMPipeline 22 | from .dit import DiTPipeline 23 | from .latent_diffusion import LDMSuperResolutionPipeline 24 | from .latent_diffusion_uncond import LDMPipeline 25 | from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput 26 | from .pndm import PNDMPipeline 27 | from .repaint import RePaintPipeline 28 | from .score_sde_ve import ScoreSdeVePipeline 29 | from .stochastic_karras_ve import KarrasVePipeline 30 | 31 | try: 32 | if not (is_torch_available() and is_librosa_available()): 33 | raise OptionalDependencyNotAvailable() 34 | except OptionalDependencyNotAvailable: 35 | from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 36 | else: 37 | from .audio_diffusion import AudioDiffusionPipeline, Mel 38 | 39 | try: 40 | if not (is_torch_available() and is_transformers_available()): 41 | raise OptionalDependencyNotAvailable() 42 | except OptionalDependencyNotAvailable: 43 | from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 44 | else: 45 | from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline 46 | from .audioldm import AudioLDMPipeline 47 | from .latent_diffusion import LDMTextToImagePipeline 48 | from .paint_by_example import PaintByExamplePipeline 49 | from .semantic_stable_diffusion import SemanticStableDiffusionPipeline 50 | from .stable_diffusion import ( 51 | CycleDiffusionPipeline, 52 | StableDiffusionAttendAndExcitePipeline, 53 | StableDiffusionControlNetPipeline, 54 | StableDiffusionDepth2ImgPipeline, 55 | StableDiffusionImageVariationPipeline, 56 | StableDiffusionImg2ImgPipeline, 57 | StableDiffusionInpaintPipeline, 58 | StableDiffusionInpaintPipelineLegacy, 59 | StableDiffusionInstructPix2PixPipeline, 60 | StableDiffusionLatentUpscalePipeline, 61 | StableDiffusionModelEditingPipeline, 62 | StableDiffusionPanoramaPipeline, 63 | StableDiffusionPipeline, 64 | StableDiffusionPix2PixZeroPipeline, 65 | StableDiffusionSAGPipeline, 66 | StableDiffusionUpscalePipeline, 67 | StableUnCLIPImg2ImgPipeline, 68 | StableUnCLIPPipeline, 69 | ) 70 | from .stable_diffusion_safe import StableDiffusionPipelineSafe 71 | from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline 72 | from .unclip
import UnCLIPImageVariationPipeline, UnCLIPPipeline 73 | from .versatile_diffusion import ( 74 | VersatileDiffusionDualGuidedPipeline, 75 | VersatileDiffusionImageVariationPipeline, 76 | VersatileDiffusionPipeline, 77 | VersatileDiffusionTextToImagePipeline, 78 | ) 79 | from .vq_diffusion import VQDiffusionPipeline 80 | 81 | try: 82 | if not is_onnx_available(): 83 | raise OptionalDependencyNotAvailable() 84 | except OptionalDependencyNotAvailable: 85 | from ..utils.dummy_onnx_objects import * # noqa F403 86 | else: 87 | from .onnx_utils import OnnxRuntimeModel 88 | 89 | try: 90 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()): 91 | raise OptionalDependencyNotAvailable() 92 | except OptionalDependencyNotAvailable: 93 | from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 94 | else: 95 | from .stable_diffusion import ( 96 | OnnxStableDiffusionImg2ImgPipeline, 97 | OnnxStableDiffusionInpaintPipeline, 98 | OnnxStableDiffusionInpaintPipelineLegacy, 99 | OnnxStableDiffusionPipeline, 100 | OnnxStableDiffusionUpscalePipeline, 101 | StableDiffusionOnnxPipeline, 102 | ) 103 | 104 | try: 105 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): 106 | raise OptionalDependencyNotAvailable() 107 | except OptionalDependencyNotAvailable: 108 | from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 109 | else: 110 | from .stable_diffusion import StableDiffusionKDiffusionPipeline 111 | 112 | try: 113 | if not is_flax_available(): 114 | raise OptionalDependencyNotAvailable() 115 | except OptionalDependencyNotAvailable: 116 | from ..utils.dummy_flax_objects import * # noqa F403 117 | else: 118 | from .pipeline_flax_utils import FlaxDiffusionPipeline 119 | 120 | 121 | try: 122 | if not (is_flax_available() and is_transformers_available()): 123 | raise OptionalDependencyNotAvailable() 124 | except OptionalDependencyNotAvailable: 125 | from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 126 | else: 127 | from .stable_diffusion import ( 128 | FlaxStableDiffusionControlNetPipeline, 129 | FlaxStableDiffusionImg2ImgPipeline, 130 | FlaxStableDiffusionInpaintPipeline, 131 | FlaxStableDiffusionPipeline, 132 | ) 133 | try: 134 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 135 | raise OptionalDependencyNotAvailable() 136 | except OptionalDependencyNotAvailable: 137 | from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 138 | else: 139 | from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline 140 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/alt_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | from PIL import Image 7 | 8 | from ...utils import BaseOutput, is_torch_available, is_transformers_available 9 | 10 | 11 | @dataclass 12 | # Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt 13 | class AltDiffusionPipelineOutput(BaseOutput): 14 | """ 15 | Output class for Alt Diffusion pipelines. 
16 | 17 | Args: 18 | images (`List[PIL.Image.Image]` or `np.ndarray`) 19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 20 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 21 | nsfw_content_detected (`List[bool]`) 22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 23 | (nsfw) content, or `None` if safety checking could not be performed. 24 | """ 25 | 26 | images: Union[List[PIL.Image.Image], np.ndarray] 27 | nsfw_content_detected: Optional[List[bool]] 28 | 29 | 30 | if is_transformers_available() and is_torch_available(): 31 | from .modeling_roberta_series import RobertaSeriesModelWithTransformation 32 | from .pipeline_alt_diffusion import AltDiffusionPipeline 33 | from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline 34 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel 7 | from transformers.utils import ModelOutput 8 | 9 | 10 | @dataclass 11 | class TransformationModelOutput(ModelOutput): 12 | """ 13 | Base class for text model's outputs that also contains a pooling of the last hidden states. 14 | 15 | Args: 16 | text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): 17 | The text embeddings obtained by applying the projection layer to the pooler_output. 18 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 19 | Sequence of hidden-states at the output of the last layer of the model. 20 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 21 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 22 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 23 | 24 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 25 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 26 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 27 | sequence_length)`. 28 | 29 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 30 | heads. 
31 | """ 32 | 33 | projection_state: Optional[torch.FloatTensor] = None 34 | last_hidden_state: torch.FloatTensor = None 35 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 36 | attentions: Optional[Tuple[torch.FloatTensor]] = None 37 | 38 | 39 | class RobertaSeriesConfig(XLMRobertaConfig): 40 | def __init__( 41 | self, 42 | pad_token_id=1, 43 | bos_token_id=0, 44 | eos_token_id=2, 45 | project_dim=512, 46 | pooler_fn="cls", 47 | learn_encoder=False, 48 | use_attention_mask=True, 49 | **kwargs, 50 | ): 51 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 52 | self.project_dim = project_dim 53 | self.pooler_fn = pooler_fn 54 | self.learn_encoder = learn_encoder 55 | self.use_attention_mask = use_attention_mask 56 | 57 | 58 | class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): 59 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 60 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 61 | base_model_prefix = "roberta" 62 | config_class = RobertaSeriesConfig 63 | 64 | def __init__(self, config): 65 | super().__init__(config) 66 | self.roberta = XLMRobertaModel(config) 67 | self.transformation = nn.Linear(config.hidden_size, config.project_dim) 68 | self.post_init() 69 | 70 | def forward( 71 | self, 72 | input_ids: Optional[torch.Tensor] = None, 73 | attention_mask: Optional[torch.Tensor] = None, 74 | token_type_ids: Optional[torch.Tensor] = None, 75 | position_ids: Optional[torch.Tensor] = None, 76 | head_mask: Optional[torch.Tensor] = None, 77 | inputs_embeds: Optional[torch.Tensor] = None, 78 | encoder_hidden_states: Optional[torch.Tensor] = None, 79 | encoder_attention_mask: Optional[torch.Tensor] = None, 80 | output_attentions: Optional[bool] = None, 81 | return_dict: Optional[bool] = None, 82 | output_hidden_states: Optional[bool] = None, 83 | ): 84 | r""" """ 85 | 86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 87 | 88 | outputs = self.base_model( 89 | input_ids=input_ids, 90 | attention_mask=attention_mask, 91 | token_type_ids=token_type_ids, 92 | position_ids=position_ids, 93 | head_mask=head_mask, 94 | inputs_embeds=inputs_embeds, 95 | encoder_hidden_states=encoder_hidden_states, 96 | encoder_attention_mask=encoder_attention_mask, 97 | output_attentions=output_attentions, 98 | output_hidden_states=output_hidden_states, 99 | return_dict=return_dict, 100 | ) 101 | 102 | projection_state = self.transformation(outputs.last_hidden_state) 103 | 104 | return TransformationModelOutput( 105 | projection_state=projection_state, 106 | last_hidden_state=outputs.last_hidden_state, 107 | hidden_states=outputs.hidden_states, 108 | attentions=outputs.attentions, 109 | ) 110 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/audio_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .mel import Mel 2 | from .pipeline_audio_diffusion import AudioDiffusionPipeline 3 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/audio_diffusion/mel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np # noqa: E402 17 | 18 | from ...configuration_utils import ConfigMixin, register_to_config 19 | from ...schedulers.scheduling_utils import SchedulerMixin 20 | 21 | 22 | try: 23 | import librosa # noqa: E402 24 | 25 | _librosa_can_be_imported = True 26 | _import_error = "" 27 | except Exception as e: 28 | _librosa_can_be_imported = False 29 | _import_error = ( 30 | f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to use it." 31 | ) 32 | 33 | 34 | from PIL import Image # noqa: E402 35 | 36 | 37 | class Mel(ConfigMixin, SchedulerMixin): 38 | """ 39 | Parameters: 40 | x_res (`int`): x resolution of spectrogram (time) 41 | y_res (`int`): y resolution of spectrogram (frequency bins) 42 | sample_rate (`int`): sample rate of audio 43 | n_fft (`int`): number of Fast Fourier Transforms 44 | hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res) 45 | top_db (`int`): loudest in decibels 46 | n_iter (`int`): number of iterations for Griffin-Lim mel inversion 47 | """ 48 | 49 | config_name = "mel_config.json" 50 | 51 | @register_to_config 52 | def __init__( 53 | self, 54 | x_res: int = 256, 55 | y_res: int = 256, 56 | sample_rate: int = 22050, 57 | n_fft: int = 2048, 58 | hop_length: int = 512, 59 | top_db: int = 80, 60 | n_iter: int = 32, 61 | ): 62 | self.hop_length = hop_length 63 | self.sr = sample_rate 64 | self.n_fft = n_fft 65 | self.top_db = top_db 66 | self.n_iter = n_iter 67 | self.set_resolution(x_res, y_res) 68 | self.audio = None 69 | 70 | if not _librosa_can_be_imported: 71 | raise ValueError(_import_error) 72 | 73 | def set_resolution(self, x_res: int, y_res: int): 74 | """Set resolution. 75 | 76 | Args: 77 | x_res (`int`): x resolution of spectrogram (time) 78 | y_res (`int`): y resolution of spectrogram (frequency bins) 79 | """ 80 | self.x_res = x_res 81 | self.y_res = y_res 82 | self.n_mels = self.y_res 83 | self.slice_size = self.x_res * self.hop_length - 1 84 | 85 | def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): 86 | """Load audio. 87 | 88 | Args: 89 | audio_file (`str`): must be a file on disk due to Librosa limitation or 90 | raw_audio (`np.ndarray`): audio as numpy array 91 | """ 92 | if audio_file is not None: 93 | self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) 94 | else: 95 | self.audio = raw_audio 96 | 97 | # Pad with silence if necessary. 98 | if len(self.audio) < self.x_res * self.hop_length: 99 | self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) 100 | 101 | def get_number_of_slices(self) -> int: 102 | """Get number of slices in audio. 103 | 104 | Returns: 105 | `int`: number of spectrograms audio can be sliced into 106 | """ 107 | return len(self.audio) // self.slice_size 108 | 109 | def get_audio_slice(self, slice: int = 0) -> np.ndarray: 110 | """Get slice of audio.
111 | 112 | Args: 113 | slice (`int`): slice number of audio (out of get_number_of_slices()) 114 | 115 | Returns: 116 | `np.ndarray`: audio as numpy array 117 | """ 118 | return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] 119 | 120 | def get_sample_rate(self) -> int: 121 | """Get sample rate: 122 | 123 | Returns: 124 | `int`: sample rate of audio 125 | """ 126 | return self.sr 127 | 128 | def audio_slice_to_image(self, slice: int) -> Image.Image: 129 | """Convert slice of audio to spectrogram. 130 | 131 | Args: 132 | slice (`int`): slice number of audio to convert (out of get_number_of_slices()) 133 | 134 | Returns: 135 | `PIL Image`: grayscale image of x_res x y_res 136 | """ 137 | S = librosa.feature.melspectrogram( 138 | y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels 139 | ) 140 | log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) 141 | bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) 142 | image = Image.fromarray(bytedata) 143 | return image 144 | 145 | def image_to_audio(self, image: Image.Image) -> np.ndarray: 146 | """Converts spectrogram to audio. 147 | 148 | Args: 149 | image (`PIL Image`): x_res x y_res grayscale image 150 | 151 | Returns: 152 | audio (`np.ndarray`): raw audio 153 | """ 154 | bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) 155 | log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db 156 | S = librosa.db_to_power(log_S) 157 | audio = librosa.feature.inverse.mel_to_audio( 158 | S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter 159 | ) 160 | return audio 161 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/audioldm/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import ( 2 | OptionalDependencyNotAvailable, 3 | is_torch_available, 4 | is_transformers_available, 5 | is_transformers_version, 6 | ) 7 | 8 | 9 | try: 10 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 11 | raise OptionalDependencyNotAvailable() 12 | except OptionalDependencyNotAvailable: 13 | from ...utils.dummy_torch_and_transformers_objects import ( 14 | AudioLDMPipeline, 15 | ) 16 | else: 17 | from .pipeline_audioldm import AudioLDMPipeline 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/dance_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_dance_diffusion import DanceDiffusionPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...utils import logging, randn_tensor 21 | from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline 22 | 23 | 24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 25 | 26 | 27 | class DanceDiffusionPipeline(DiffusionPipeline): 28 | r""" 29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 30 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 31 | 32 | Parameters: 33 | unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. 34 | scheduler ([`SchedulerMixin`]): 35 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 36 | [`IPNDMScheduler`]. 37 | """ 38 | 39 | def __init__(self, unet, scheduler): 40 | super().__init__() 41 | self.register_modules(unet=unet, scheduler=scheduler) 42 | 43 | @torch.no_grad() 44 | def __call__( 45 | self, 46 | batch_size: int = 1, 47 | num_inference_steps: int = 100, 48 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 49 | audio_length_in_s: Optional[float] = None, 50 | return_dict: bool = True, 51 | ) -> Union[AudioPipelineOutput, Tuple]: 52 | r""" 53 | Args: 54 | batch_size (`int`, *optional*, defaults to 1): 55 | The number of audio samples to generate. 56 | num_inference_steps (`int`, *optional*, defaults to 50): 57 | The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at 58 | the expense of slower inference. 59 | generator (`torch.Generator`, *optional*): 60 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 61 | to make generation deterministic. 62 | audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): 63 | The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* 64 | `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`. 65 | return_dict (`bool`, *optional*, defaults to `True`): 66 | Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. 67 | 68 | Returns: 69 | [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is 70 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 71 | """ 72 | 73 | if audio_length_in_s is None: 74 | audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate 75 | 76 | sample_size = audio_length_in_s * self.unet.config.sample_rate 77 | 78 | down_scale_factor = 2 ** len(self.unet.up_blocks) 79 | if sample_size < 3 * down_scale_factor: 80 | raise ValueError( 81 | f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" 82 | f" {3 * down_scale_factor / self.unet.config.sample_rate}." 
83 | ) 84 | 85 | original_sample_size = int(sample_size) 86 | if sample_size % down_scale_factor != 0: 87 | sample_size = ( 88 | (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 89 | ) * down_scale_factor 90 | logger.info( 91 | f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" 92 | f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" 93 | " process." 94 | ) 95 | sample_size = int(sample_size) 96 | 97 | dtype = next(iter(self.unet.parameters())).dtype 98 | shape = (batch_size, self.unet.config.in_channels, sample_size) 99 | if isinstance(generator, list) and len(generator) != batch_size: 100 | raise ValueError( 101 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 102 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 103 | ) 104 | 105 | audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype) 106 | 107 | # set step values 108 | self.scheduler.set_timesteps(num_inference_steps, device=audio.device) 109 | self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) 110 | 111 | for t in self.progress_bar(self.scheduler.timesteps): 112 | # 1. predict noise model_output 113 | model_output = self.unet(audio, t).sample 114 | 115 | # 2. compute previous image: x_t -> t_t-1 116 | audio = self.scheduler.step(model_output, t, audio).prev_sample 117 | 118 | audio = audio.clamp(-1, 1).float().cpu().numpy() 119 | 120 | audio = audio[:, :, :original_sample_size] 121 | 122 | if not return_dict: 123 | return (audio,) 124 | 125 | return AudioPipelineOutput(audios=audio) 126 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/ddim/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_ddim import DDIMPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/ddim/pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from ...schedulers import DDIMScheduler 20 | from ...utils import randn_tensor 21 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | 23 | 24 | class DDIMPipeline(DiffusionPipeline): 25 | r""" 26 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 27 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 28 | 29 | Parameters: 30 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 
31 | scheduler ([`SchedulerMixin`]): 32 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 33 | [`DDPMScheduler`], or [`DDIMScheduler`]. 34 | """ 35 | 36 | def __init__(self, unet, scheduler): 37 | super().__init__() 38 | 39 | # make sure scheduler can always be converted to DDIM 40 | scheduler = DDIMScheduler.from_config(scheduler.config) 41 | 42 | self.register_modules(unet=unet, scheduler=scheduler) 43 | 44 | @torch.no_grad() 45 | def __call__( 46 | self, 47 | batch_size: int = 1, 48 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 49 | eta: float = 0.0, 50 | num_inference_steps: int = 50, 51 | use_clipped_model_output: Optional[bool] = None, 52 | output_type: Optional[str] = "pil", 53 | return_dict: bool = True, 54 | ) -> Union[ImagePipelineOutput, Tuple]: 55 | r""" 56 | Args: 57 | batch_size (`int`, *optional*, defaults to 1): 58 | The number of images to generate. 59 | generator (`torch.Generator`, *optional*): 60 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 61 | to make generation deterministic. 62 | eta (`float`, *optional*, defaults to 0.0): 63 | The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). 64 | num_inference_steps (`int`, *optional*, defaults to 50): 65 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 66 | expense of slower inference. 67 | use_clipped_model_output (`bool`, *optional*, defaults to `None`): 68 | if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed 69 | downstream to the scheduler. So use `None` for schedulers which don't support this argument. 70 | output_type (`str`, *optional*, defaults to `"pil"`): 71 | The output format of the generate image. Choose between 72 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 73 | return_dict (`bool`, *optional*, defaults to `True`): 74 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 75 | 76 | Returns: 77 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 78 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 79 | """ 80 | 81 | # Sample gaussian noise to begin loop 82 | if isinstance(self.unet.config.sample_size, int): 83 | image_shape = ( 84 | batch_size, 85 | self.unet.config.in_channels, 86 | self.unet.config.sample_size, 87 | self.unet.config.sample_size, 88 | ) 89 | else: 90 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 91 | 92 | if isinstance(generator, list) and len(generator) != batch_size: 93 | raise ValueError( 94 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 95 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 96 | ) 97 | 98 | image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) 99 | 100 | # set step values 101 | self.scheduler.set_timesteps(num_inference_steps) 102 | 103 | for t in self.progress_bar(self.scheduler.timesteps): 104 | # 1. predict noise model_output 105 | model_output = self.unet(image, t).sample 106 | 107 | # 2. 
predict previous mean of image x_t-1 and add variance depending on eta 108 | # eta corresponds to η in paper and should be between [0, 1] 109 | # do x_t -> x_t-1 110 | image = self.scheduler.step( 111 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 112 | ).prev_sample 113 | 114 | image = (image / 2 + 0.5).clamp(0, 1) 115 | image = image.cpu().permute(0, 2, 3, 1).numpy() 116 | if output_type == "pil": 117 | image = self.numpy_to_pil(image) 118 | 119 | if not return_dict: 120 | return (image,) 121 | 122 | return ImagePipelineOutput(images=image) 123 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_ddpm import DDPMPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/ddpm/pipeline_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...utils import randn_tensor 21 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | 23 | 24 | class DDPMPipeline(DiffusionPipeline): 25 | r""" 26 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 27 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 28 | 29 | Parameters: 30 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 31 | scheduler ([`SchedulerMixin`]): 32 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 33 | [`DDPMScheduler`], or [`DDIMScheduler`]. 34 | """ 35 | 36 | def __init__(self, unet, scheduler): 37 | super().__init__() 38 | self.register_modules(unet=unet, scheduler=scheduler) 39 | 40 | @torch.no_grad() 41 | def __call__( 42 | self, 43 | batch_size: int = 1, 44 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 45 | num_inference_steps: int = 1000, 46 | output_type: Optional[str] = "pil", 47 | return_dict: bool = True, 48 | ) -> Union[ImagePipelineOutput, Tuple]: 49 | r""" 50 | Args: 51 | batch_size (`int`, *optional*, defaults to 1): 52 | The number of images to generate. 53 | generator (`torch.Generator`, *optional*): 54 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 55 | to make generation deterministic. 56 | num_inference_steps (`int`, *optional*, defaults to 1000): 57 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 58 | expense of slower inference. 
59 | output_type (`str`, *optional*, defaults to `"pil"`): 60 | The output format of the generate image. Choose between 61 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 62 | return_dict (`bool`, *optional*, defaults to `True`): 63 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 64 | 65 | Returns: 66 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 67 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 68 | """ 69 | # Sample gaussian noise to begin loop 70 | if isinstance(self.unet.config.sample_size, int): 71 | image_shape = ( 72 | batch_size, 73 | self.unet.config.in_channels, 74 | self.unet.config.sample_size, 75 | self.unet.config.sample_size, 76 | ) 77 | else: 78 | image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) 79 | 80 | if self.device.type == "mps": 81 | # randn does not work reproducibly on mps 82 | image = randn_tensor(image_shape, generator=generator) 83 | image = image.to(self.device) 84 | else: 85 | image = randn_tensor(image_shape, generator=generator, device=self.device) 86 | 87 | # set step values 88 | self.scheduler.set_timesteps(num_inference_steps) 89 | 90 | for t in self.progress_bar(self.scheduler.timesteps): 91 | # 1. predict noise model_output 92 | model_output = self.unet(image, t).sample 93 | 94 | # 2. compute previous image: x_t -> x_t-1 95 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample 96 | 97 | image = (image / 2 + 0.5).clamp(0, 1) 98 | image = image.cpu().permute(0, 2, 3, 1).numpy() 99 | if output_type == "pil": 100 | image = self.numpy_to_pil(image) 101 | 102 | if not return_dict: 103 | return (image,) 104 | 105 | return ImagePipelineOutput(images=image) 106 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/dit/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_dit import DiTPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_transformers_available 2 | from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline 3 | 4 | 5 | if is_transformers_available(): 6 | from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline 7 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/latent_diffusion_uncond/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_latent_diffusion_uncond import LDMPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...models import UNet2DModel, VQModel 21 | from ...schedulers import DDIMScheduler 22 | from ...utils import randn_tensor 23 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 24 | 25 | 26 | class LDMPipeline(DiffusionPipeline): 27 | r""" 28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 29 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 30 | 31 | Parameters: 32 | vqvae ([`VQModel`]): 33 | Vector-quantized (VQ) Model to encode and decode images to and from latent representations. 34 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. 35 | scheduler ([`SchedulerMixin`]): 36 | [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents. 37 | """ 38 | 39 | def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): 40 | super().__init__() 41 | self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) 42 | 43 | @torch.no_grad() 44 | def __call__( 45 | self, 46 | batch_size: int = 1, 47 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 48 | eta: float = 0.0, 49 | num_inference_steps: int = 50, 50 | output_type: Optional[str] = "pil", 51 | return_dict: bool = True, 52 | **kwargs, 53 | ) -> Union[Tuple, ImagePipelineOutput]: 54 | r""" 55 | Args: 56 | batch_size (`int`, *optional*, defaults to 1): 57 | Number of images to generate. 58 | generator (`torch.Generator`, *optional*): 59 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 60 | to make generation deterministic. 61 | num_inference_steps (`int`, *optional*, defaults to 50): 62 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 63 | expense of slower inference. 64 | output_type (`str`, *optional*, defaults to `"pil"`): 65 | The output format of the generate image. Choose between 66 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 67 | return_dict (`bool`, *optional*, defaults to `True`): 68 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 69 | 70 | Returns: 71 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 72 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
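        Example (a minimal sketch; the checkpoint id is an assumption for illustration, and any
        public unconditional latent-diffusion checkpoint with a VQ-VAE could be substituted):

            from diffusers import LDMPipeline

            pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
            image = pipe(num_inference_steps=50).images[0]
            image.save("ldm_sample.png")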
73 | """ 74 | 75 | latents = randn_tensor( 76 | (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), 77 | generator=generator, 78 | ) 79 | latents = latents.to(self.device) 80 | 81 | # scale the initial noise by the standard deviation required by the scheduler 82 | latents = latents * self.scheduler.init_noise_sigma 83 | 84 | self.scheduler.set_timesteps(num_inference_steps) 85 | 86 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 87 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 88 | 89 | extra_kwargs = {} 90 | if accepts_eta: 91 | extra_kwargs["eta"] = eta 92 | 93 | for t in self.progress_bar(self.scheduler.timesteps): 94 | latent_model_input = self.scheduler.scale_model_input(latents, t) 95 | # predict the noise residual 96 | noise_prediction = self.unet(latent_model_input, t).sample 97 | # compute the previous noisy sample x_t -> x_t-1 98 | latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample 99 | 100 | # decode the image latents with the VAE 101 | image = self.vqvae.decode(latents).sample 102 | 103 | image = (image / 2 + 0.5).clamp(0, 1) 104 | image = image.cpu().permute(0, 2, 3, 1).numpy() 105 | if output_type == "pil": 106 | image = self.numpy_to_pil(image) 107 | 108 | if not return_dict: 109 | return (image,) 110 | 111 | return ImagePipelineOutput(images=image) 112 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/paint_by_example/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | from PIL import Image 7 | 8 | from ...utils import is_torch_available, is_transformers_available 9 | 10 | 11 | if is_transformers_available() and is_torch_available(): 12 | from .image_encoder import PaintByExampleImageEncoder 13 | from .pipeline_paint_by_example import PaintByExamplePipeline 14 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/paint_by_example/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
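# Illustrative, self-contained sketch of the shape flow used by the Paint-by-Example image
# encoder defined below: a pooled CLIP embedding is treated as a length-1 token sequence,
# refined by a small stack of transformer blocks, layer-normed, and projected to the
# cross-attention width expected by the UNet. All sizes here (hidden_size, proj_size,
# num_layers, nhead) are hypothetical stand-ins, not values taken from a real checkpoint.
import torch
from torch import nn

hidden_size, proj_size, batch = 1024, 768, 2
pooled = torch.randn(batch, hidden_size)                      # stand-in for CLIP pooler_output
blocks = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(hidden_size, nhead=1, dim_feedforward=hidden_size, batch_first=True),
    num_layers=5,
)
tokens = blocks(pooled[:, None])                              # (batch, 1, hidden_size)
cond = nn.Linear(hidden_size, proj_size)(nn.LayerNorm(hidden_size)(tokens))
print(cond.shape)                                             # torch.Size([2, 1, 768])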
14 | import torch 15 | from torch import nn 16 | from transformers import CLIPPreTrainedModel, CLIPVisionModel 17 | 18 | from ...models.attention import BasicTransformerBlock 19 | from ...utils import logging 20 | 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class PaintByExampleImageEncoder(CLIPPreTrainedModel): 26 | def __init__(self, config, proj_size=768): 27 | super().__init__(config) 28 | self.proj_size = proj_size 29 | 30 | self.model = CLIPVisionModel(config) 31 | self.mapper = PaintByExampleMapper(config) 32 | self.final_layer_norm = nn.LayerNorm(config.hidden_size) 33 | self.proj_out = nn.Linear(config.hidden_size, self.proj_size) 34 | 35 | # uncondition for scaling 36 | self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) 37 | 38 | def forward(self, pixel_values, return_uncond_vector=False): 39 | clip_output = self.model(pixel_values=pixel_values) 40 | latent_states = clip_output.pooler_output 41 | latent_states = self.mapper(latent_states[:, None]) 42 | latent_states = self.final_layer_norm(latent_states) 43 | latent_states = self.proj_out(latent_states) 44 | if return_uncond_vector: 45 | return latent_states, self.uncond_vector 46 | 47 | return latent_states 48 | 49 | 50 | class PaintByExampleMapper(nn.Module): 51 | def __init__(self, config): 52 | super().__init__() 53 | num_layers = (config.num_hidden_layers + 1) // 5 54 | hid_size = config.hidden_size 55 | num_heads = 1 56 | self.blocks = nn.ModuleList( 57 | [ 58 | BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) 59 | for _ in range(num_layers) 60 | ] 61 | ) 62 | 63 | def forward(self, hidden_states): 64 | for block in self.blocks: 65 | hidden_states = block(hidden_states) 66 | 67 | return hidden_states 68 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/pndm/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_pndm import PNDMPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/pndm/pipeline_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...models import UNet2DModel 21 | from ...schedulers import PNDMScheduler 22 | from ...utils import randn_tensor 23 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 24 | 25 | 26 | class PNDMPipeline(DiffusionPipeline): 27 | r""" 28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 29 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
30 | 31 | Parameters: 32 | unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. 33 | scheduler ([`SchedulerMixin`]): 34 | The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. 35 | """ 36 | 37 | unet: UNet2DModel 38 | scheduler: PNDMScheduler 39 | 40 | def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): 41 | super().__init__() 42 | 43 | scheduler = PNDMScheduler.from_config(scheduler.config) 44 | 45 | self.register_modules(unet=unet, scheduler=scheduler) 46 | 47 | @torch.no_grad() 48 | def __call__( 49 | self, 50 | batch_size: int = 1, 51 | num_inference_steps: int = 50, 52 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 53 | output_type: Optional[str] = "pil", 54 | return_dict: bool = True, 55 | **kwargs, 56 | ) -> Union[ImagePipelineOutput, Tuple]: 57 | r""" 58 | Args: 59 | batch_size (`int`, `optional`, defaults to 1): The number of images to generate. 60 | num_inference_steps (`int`, `optional`, defaults to 50): 61 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 62 | expense of slower inference. 63 | generator (`torch.Generator`, `optional`): A [torch 64 | generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 65 | deterministic. 66 | output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose 67 | between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 68 | return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a 69 | [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 70 | 71 | Returns: 72 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 73 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
74 | """ 75 | # For more information on the sampling method you can take a look at Algorithm 2 of 76 | # the official paper: https://arxiv.org/pdf/2202.09778.pdf 77 | 78 | # Sample gaussian noise to begin loop 79 | image = randn_tensor( 80 | (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), 81 | generator=generator, 82 | device=self.device, 83 | ) 84 | 85 | self.scheduler.set_timesteps(num_inference_steps) 86 | for t in self.progress_bar(self.scheduler.timesteps): 87 | model_output = self.unet(image, t).sample 88 | 89 | image = self.scheduler.step(model_output, t, image).prev_sample 90 | 91 | image = (image / 2 + 0.5).clamp(0, 1) 92 | image = image.cpu().permute(0, 2, 3, 1).numpy() 93 | if output_type == "pil": 94 | image = self.numpy_to_pil(image) 95 | 96 | if not return_dict: 97 | return (image,) 98 | 99 | return ImagePipelineOutput(images=image) 100 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/repaint/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_repaint import RePaintPipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/score_sde_ve/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_score_sde_ve import ScoreSdeVePipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/semantic_stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | 9 | from ...utils import BaseOutput, is_torch_available, is_transformers_available 10 | 11 | 12 | @dataclass 13 | class SemanticStableDiffusionPipelineOutput(BaseOutput): 14 | """ 15 | Output class for Stable Diffusion pipelines. 16 | 17 | Args: 18 | images (`List[PIL.Image.Image]` or `np.ndarray`) 19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 20 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 21 | nsfw_content_detected (`List[bool]`) 22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 23 | (nsfw) content, or `None` if safety checking could not be performed. 
24 | """ 25 | 26 | images: Union[List[PIL.Image.Image], np.ndarray] 27 | nsfw_content_detected: Optional[List[bool]] 28 | 29 | 30 | if is_transformers_available() and is_torch_available(): 31 | from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline 32 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/spectrogram_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from ...utils import is_note_seq_available, is_transformers_available, is_torch_available 3 | from ...utils import OptionalDependencyNotAvailable 4 | 5 | 6 | try: 7 | if not (is_transformers_available() and is_torch_available()): 8 | raise OptionalDependencyNotAvailable() 9 | except OptionalDependencyNotAvailable: 10 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 11 | else: 12 | from .notes_encoder import SpectrogramNotesEncoder 13 | from .continous_encoder import SpectrogramContEncoder 14 | from .pipeline_spectrogram_diffusion import ( 15 | SpectrogramContEncoder, 16 | SpectrogramDiffusionPipeline, 17 | T5FilmDecoder, 18 | ) 19 | 20 | try: 21 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 22 | raise OptionalDependencyNotAvailable() 23 | except OptionalDependencyNotAvailable: 24 | from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 25 | else: 26 | from .midi_utils import MidiProcessor 27 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Music Spectrogram Diffusion Authors. 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
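# Self-contained sketch (hypothetical batch and mask) of the "extended attention mask" idea
# used by the spectrogram encoders below: a 1/0 padding mask over positions is broadcast to
# (batch, 1, 1, seq_len) and turned into an additive bias, so padded positions receive a very
# large negative value before the softmax inside each T5 block.
import torch

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])                        # (batch=2, seq_len=5), 1 = keep
extended = mask[:, None, None, :].to(torch.float32)           # (2, 1, 1, 5)
extended = (1.0 - extended) * torch.finfo(torch.float32).min  # 0 where kept, huge negative where padded
scores = torch.randn(2, 8, 5, 5)                              # hypothetical attention scores, 8 heads
probs = torch.softmax(scores + extended, dim=-1)              # padded keys get ~0 attention weight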
15 | 16 | import torch 17 | import torch.nn as nn 18 | from transformers.modeling_utils import ModuleUtilsMixin 19 | from transformers.models.t5.modeling_t5 import ( 20 | T5Block, 21 | T5Config, 22 | T5LayerNorm, 23 | ) 24 | 25 | from ...configuration_utils import ConfigMixin, register_to_config 26 | from ...models import ModelMixin 27 | 28 | 29 | class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): 30 | @register_to_config 31 | def __init__( 32 | self, 33 | input_dims: int, 34 | targets_context_length: int, 35 | d_model: int, 36 | dropout_rate: float, 37 | num_layers: int, 38 | num_heads: int, 39 | d_kv: int, 40 | d_ff: int, 41 | feed_forward_proj: str, 42 | is_decoder: bool = False, 43 | ): 44 | super().__init__() 45 | 46 | self.input_proj = nn.Linear(input_dims, d_model, bias=False) 47 | 48 | self.position_encoding = nn.Embedding(targets_context_length, d_model) 49 | self.position_encoding.weight.requires_grad = False 50 | 51 | self.dropout_pre = nn.Dropout(p=dropout_rate) 52 | 53 | t5config = T5Config( 54 | d_model=d_model, 55 | num_heads=num_heads, 56 | d_kv=d_kv, 57 | d_ff=d_ff, 58 | feed_forward_proj=feed_forward_proj, 59 | dropout_rate=dropout_rate, 60 | is_decoder=is_decoder, 61 | is_encoder_decoder=False, 62 | ) 63 | self.encoders = nn.ModuleList() 64 | for lyr_num in range(num_layers): 65 | lyr = T5Block(t5config) 66 | self.encoders.append(lyr) 67 | 68 | self.layer_norm = T5LayerNorm(d_model) 69 | self.dropout_post = nn.Dropout(p=dropout_rate) 70 | 71 | def forward(self, encoder_inputs, encoder_inputs_mask): 72 | x = self.input_proj(encoder_inputs) 73 | 74 | # terminal relative positional encodings 75 | max_positions = encoder_inputs.shape[1] 76 | input_positions = torch.arange(max_positions, device=encoder_inputs.device) 77 | 78 | seq_lens = encoder_inputs_mask.sum(-1) 79 | input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) 80 | x += self.position_encoding(input_positions) 81 | 82 | x = self.dropout_pre(x) 83 | 84 | # inverted the attention mask 85 | input_shape = encoder_inputs.size() 86 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) 87 | 88 | for lyr in self.encoders: 89 | x = lyr(x, extended_attention_mask)[0] 90 | x = self.layer_norm(x) 91 | 92 | return self.dropout_post(x), encoder_inputs_mask 93 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Music Spectrogram Diffusion Authors. 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
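# Minimal sketch (hypothetical vocabulary and sizes) of the embedding step used by the notes
# encoder defined below: token embeddings and a frozen position-embedding table are summed
# before the stack of T5 encoder blocks is applied.
import torch
from torch import nn

vocab_size, max_length, d_model = 1536, 2048, 768
token_embedder = nn.Embedding(vocab_size, d_model)
position_encoding = nn.Embedding(max_length, d_model)
position_encoding.weight.requires_grad = False                # position table is not trained

tokens = torch.randint(0, vocab_size, (2, 16))                # (batch, seq_len) of note/event ids
positions = torch.arange(tokens.shape[1])
x = token_embedder(tokens) + position_encoding(positions)     # (2, 16, 768)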
15 | 16 | import torch 17 | import torch.nn as nn 18 | from transformers.modeling_utils import ModuleUtilsMixin 19 | from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm 20 | 21 | from ...configuration_utils import ConfigMixin, register_to_config 22 | from ...models import ModelMixin 23 | 24 | 25 | class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): 26 | @register_to_config 27 | def __init__( 28 | self, 29 | max_length: int, 30 | vocab_size: int, 31 | d_model: int, 32 | dropout_rate: float, 33 | num_layers: int, 34 | num_heads: int, 35 | d_kv: int, 36 | d_ff: int, 37 | feed_forward_proj: str, 38 | is_decoder: bool = False, 39 | ): 40 | super().__init__() 41 | 42 | self.token_embedder = nn.Embedding(vocab_size, d_model) 43 | 44 | self.position_encoding = nn.Embedding(max_length, d_model) 45 | self.position_encoding.weight.requires_grad = False 46 | 47 | self.dropout_pre = nn.Dropout(p=dropout_rate) 48 | 49 | t5config = T5Config( 50 | vocab_size=vocab_size, 51 | d_model=d_model, 52 | num_heads=num_heads, 53 | d_kv=d_kv, 54 | d_ff=d_ff, 55 | dropout_rate=dropout_rate, 56 | feed_forward_proj=feed_forward_proj, 57 | is_decoder=is_decoder, 58 | is_encoder_decoder=False, 59 | ) 60 | 61 | self.encoders = nn.ModuleList() 62 | for lyr_num in range(num_layers): 63 | lyr = T5Block(t5config) 64 | self.encoders.append(lyr) 65 | 66 | self.layer_norm = T5LayerNorm(d_model) 67 | self.dropout_post = nn.Dropout(p=dropout_rate) 68 | 69 | def forward(self, encoder_input_tokens, encoder_inputs_mask): 70 | x = self.token_embedder(encoder_input_tokens) 71 | 72 | seq_length = encoder_input_tokens.shape[1] 73 | inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) 74 | x += self.position_encoding(inputs_positions) 75 | 76 | x = self.dropout_pre(x) 77 | 78 | # inverted the attention mask 79 | input_shape = encoder_input_tokens.size() 80 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) 81 | 82 | for lyr in self.encoders: 83 | x = lyr(x, extended_attention_mask)[0] 84 | x = self.layer_norm(x) 85 | 86 | return self.dropout_post(x), encoder_inputs_mask 87 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stable_diffusion/safety_checker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
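# Self-contained sketch (random stand-ins for the learned concept and image embeddings, and a
# hypothetical threshold) of the screening rule implemented below: cosine similarity between
# each image embedding and each concept embedding is compared against a per-concept threshold,
# and a small positive "adjustment" can be added to make the filter stricter.
import torch

torch.manual_seed(0)
image_embeds = torch.nn.functional.normalize(torch.randn(4, 768))
concept_embeds = torch.nn.functional.normalize(torch.randn(17, 768))
concept_thresholds = torch.full((17,), 0.05)                  # hypothetical per-concept thresholds

cos = image_embeds @ concept_embeds.t()                       # (4, 17) cosine similarities
adjustment = 0.0                                              # raise to filter more aggressively
scores = cos - concept_thresholds + adjustment
flagged = (scores > 0).any(dim=1)                             # one bool per image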
14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel 19 | 20 | from ...utils import logging 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | def cosine_distance(image_embeds, text_embeds): 27 | normalized_image_embeds = nn.functional.normalize(image_embeds) 28 | normalized_text_embeds = nn.functional.normalize(text_embeds) 29 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 30 | 31 | 32 | class StableDiffusionSafetyChecker(PreTrainedModel): 33 | config_class = CLIPConfig 34 | 35 | _no_split_modules = ["CLIPEncoderLayer"] 36 | 37 | def __init__(self, config: CLIPConfig): 38 | super().__init__(config) 39 | 40 | self.vision_model = CLIPVisionModel(config.vision_config) 41 | self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) 42 | 43 | self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) 44 | self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) 45 | 46 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) 47 | self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) 48 | 49 | @torch.no_grad() 50 | def forward(self, clip_input, images): 51 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 52 | image_embeds = self.visual_projection(pooled_output) 53 | 54 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 55 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() 56 | cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() 57 | 58 | result = [] 59 | batch_size = image_embeds.shape[0] 60 | for i in range(batch_size): 61 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} 62 | 63 | # increase this value to create a stronger `nfsw` filter 64 | # at the cost of increasing the possibility of filtering benign images 65 | adjustment = 0.0 66 | 67 | for concept_idx in range(len(special_cos_dist[0])): 68 | concept_cos = special_cos_dist[i][concept_idx] 69 | concept_threshold = self.special_care_embeds_weights[concept_idx].item() 70 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 71 | if result_img["special_scores"][concept_idx] > 0: 72 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) 73 | adjustment = 0.01 74 | 75 | for concept_idx in range(len(cos_dist[0])): 76 | concept_cos = cos_dist[i][concept_idx] 77 | concept_threshold = self.concept_embeds_weights[concept_idx].item() 78 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 79 | if result_img["concept_scores"][concept_idx] > 0: 80 | result_img["bad_concepts"].append(concept_idx) 81 | 82 | result.append(result_img) 83 | 84 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 85 | 86 | for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): 87 | if has_nsfw_concept: 88 | images[idx] = np.zeros(images[idx].shape) # black image 89 | 90 | if any(has_nsfw_concepts): 91 | logger.warning( 92 | "Potential NSFW content was detected in one or more images. A black image will be returned instead." 93 | " Try again with a different prompt and/or seed." 
94 | ) 95 | 96 | return images, has_nsfw_concepts 97 | 98 | @torch.no_grad() 99 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): 100 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 101 | image_embeds = self.visual_projection(pooled_output) 102 | 103 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 104 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 105 | 106 | # increase this value to create a stronger `nsfw` filter 107 | # at the cost of increasing the possibility of filtering benign images 108 | adjustment = 0.0 109 | 110 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment 111 | # special_scores = special_scores.round(decimals=3) 112 | special_care = torch.any(special_scores > 0, dim=1) 113 | special_adjustment = special_care * 0.01 114 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) 115 | 116 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 117 | # concept_scores = concept_scores.round(decimals=3) 118 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) 119 | 120 | images[has_nsfw_concepts] = 0.0 # black image 121 | 122 | return images, has_nsfw_concepts 123 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stable_diffusion/safety_checker_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
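# Sketch in JAX (random stand-ins, hypothetical thresholds) of the fully vectorized rule used
# by the Flax module below: if any "special care" concept fires for an image, a 0.01
# adjustment is added to all of its concept scores before thresholding, which lowers the
# effective threshold for that image.
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
image_embeds = jax.random.normal(k1, (4, 768))
special_embeds = jax.random.normal(k2, (3, 768))
concept_embeds = jax.random.normal(k3, (17, 768))

def cosine(a, b, eps=1e-12):
    a = a / jnp.clip(jnp.linalg.norm(a, axis=1, keepdims=True), a_min=eps)
    b = b / jnp.clip(jnp.linalg.norm(b, axis=1, keepdims=True), a_min=eps)
    return a @ b.T

special_scores = cosine(image_embeds, special_embeds) - 0.05           # (4, 3), 0.05 is illustrative
special_adjustment = jnp.any(special_scores > 0, axis=1, keepdims=True) * 0.01
concept_scores = cosine(image_embeds, concept_embeds) - 0.05 + special_adjustment
has_nsfw = jnp.any(concept_scores > 0, axis=1)                         # (4,)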
14 | 15 | from typing import Optional, Tuple 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from flax import linen as nn 20 | from flax.core.frozen_dict import FrozenDict 21 | from transformers import CLIPConfig, FlaxPreTrainedModel 22 | from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule 23 | 24 | 25 | def jax_cosine_distance(emb_1, emb_2, eps=1e-12): 26 | norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T 27 | norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T 28 | return jnp.matmul(norm_emb_1, norm_emb_2.T) 29 | 30 | 31 | class FlaxStableDiffusionSafetyCheckerModule(nn.Module): 32 | config: CLIPConfig 33 | dtype: jnp.dtype = jnp.float32 34 | 35 | def setup(self): 36 | self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) 37 | self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) 38 | 39 | self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) 40 | self.special_care_embeds = self.param( 41 | "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) 42 | ) 43 | 44 | self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) 45 | self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) 46 | 47 | def __call__(self, clip_input): 48 | pooled_output = self.vision_model(clip_input)[1] 49 | image_embeds = self.visual_projection(pooled_output) 50 | 51 | special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) 52 | cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) 53 | 54 | # increase this value to create a stronger `nfsw` filter 55 | # at the cost of increasing the possibility of filtering benign image inputs 56 | adjustment = 0.0 57 | 58 | special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment 59 | special_scores = jnp.round(special_scores, 3) 60 | is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) 61 | # Use a lower threshold if an image has any special care concept 62 | special_adjustment = is_special_care * 0.01 63 | 64 | concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment 65 | concept_scores = jnp.round(concept_scores, 3) 66 | has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) 67 | 68 | return has_nsfw_concepts 69 | 70 | 71 | class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): 72 | config_class = CLIPConfig 73 | main_input_name = "clip_input" 74 | module_class = FlaxStableDiffusionSafetyCheckerModule 75 | 76 | def __init__( 77 | self, 78 | config: CLIPConfig, 79 | input_shape: Optional[Tuple] = None, 80 | seed: int = 0, 81 | dtype: jnp.dtype = jnp.float32, 82 | _do_init: bool = True, 83 | **kwargs, 84 | ): 85 | if input_shape is None: 86 | input_shape = (1, 224, 224, 3) 87 | module = self.module_class(config=config, dtype=dtype, **kwargs) 88 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 89 | 90 | def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 91 | # init input tensor 92 | clip_input = jax.random.normal(rng, input_shape) 93 | 94 | params_rng, dropout_rng = jax.random.split(rng) 95 | rngs = {"params": params_rng, "dropout": dropout_rng} 96 | 97 | random_params = self.module.init(rngs, clip_input)["params"] 98 
| 99 | return random_params 100 | 101 | def __call__( 102 | self, 103 | clip_input, 104 | params: dict = None, 105 | ): 106 | clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) 107 | 108 | return self.module.apply( 109 | {"params": params or self.params}, 110 | jnp.array(clip_input, dtype=jnp.float32), 111 | rngs={}, 112 | ) 113 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Union 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from ...configuration_utils import ConfigMixin, register_to_config 21 | from ...models.modeling_utils import ModelMixin 22 | 23 | 24 | class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): 25 | """ 26 | This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. 27 | 28 | It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image 29 | embeddings. 
30 | """ 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | embedding_dim: int = 768, 36 | ): 37 | super().__init__() 38 | 39 | self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) 40 | self.std = nn.Parameter(torch.ones(1, embedding_dim)) 41 | 42 | def to( 43 | self, 44 | torch_device: Optional[Union[str, torch.device]] = None, 45 | torch_dtype: Optional[torch.dtype] = None, 46 | ): 47 | self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) 48 | self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) 49 | return self 50 | 51 | def scale(self, embeds): 52 | embeds = (embeds - self.mean) * 1.0 / self.std 53 | return embeds 54 | 55 | def unscale(self, embeds): 56 | embeds = (embeds * self.std) + self.mean 57 | return embeds 58 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stable_diffusion_safe/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | 9 | from ...utils import BaseOutput, is_torch_available, is_transformers_available 10 | 11 | 12 | @dataclass 13 | class SafetyConfig(object): 14 | WEAK = { 15 | "sld_warmup_steps": 15, 16 | "sld_guidance_scale": 20, 17 | "sld_threshold": 0.0, 18 | "sld_momentum_scale": 0.0, 19 | "sld_mom_beta": 0.0, 20 | } 21 | MEDIUM = { 22 | "sld_warmup_steps": 10, 23 | "sld_guidance_scale": 1000, 24 | "sld_threshold": 0.01, 25 | "sld_momentum_scale": 0.3, 26 | "sld_mom_beta": 0.4, 27 | } 28 | STRONG = { 29 | "sld_warmup_steps": 7, 30 | "sld_guidance_scale": 2000, 31 | "sld_threshold": 0.025, 32 | "sld_momentum_scale": 0.5, 33 | "sld_mom_beta": 0.7, 34 | } 35 | MAX = { 36 | "sld_warmup_steps": 0, 37 | "sld_guidance_scale": 5000, 38 | "sld_threshold": 1.0, 39 | "sld_momentum_scale": 0.5, 40 | "sld_mom_beta": 0.7, 41 | } 42 | 43 | 44 | @dataclass 45 | class StableDiffusionSafePipelineOutput(BaseOutput): 46 | """ 47 | Output class for Safe Stable Diffusion pipelines. 48 | 49 | Args: 50 | images (`List[PIL.Image.Image]` or `np.ndarray`) 51 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 52 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 53 | nsfw_content_detected (`List[bool]`) 54 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 55 | (nsfw) content, or `None` if safety checking could not be performed. 56 | images (`List[PIL.Image.Image]` or `np.ndarray`) 57 | List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" 58 | (nsfw) content, or `None` if no safety check was performed or no images were flagged. 
59 | applied_safety_concept (`str`) 60 | The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled 61 | """ 62 | 63 | images: Union[List[PIL.Image.Image], np.ndarray] 64 | nsfw_content_detected: Optional[List[bool]] 65 | unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] 66 | applied_safety_concept: Optional[str] 67 | 68 | 69 | if is_transformers_available() and is_torch_available(): 70 | from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe 71 | from .safety_checker import SafeStableDiffusionSafetyChecker 72 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stable_diffusion_safe/safety_checker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel 18 | 19 | from ...utils import logging 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | def cosine_distance(image_embeds, text_embeds): 26 | normalized_image_embeds = nn.functional.normalize(image_embeds) 27 | normalized_text_embeds = nn.functional.normalize(text_embeds) 28 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 29 | 30 | 31 | class SafeStableDiffusionSafetyChecker(PreTrainedModel): 32 | config_class = CLIPConfig 33 | 34 | _no_split_modules = ["CLIPEncoderLayer"] 35 | 36 | def __init__(self, config: CLIPConfig): 37 | super().__init__(config) 38 | 39 | self.vision_model = CLIPVisionModel(config.vision_config) 40 | self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) 41 | 42 | self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) 43 | self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) 44 | 45 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) 46 | self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) 47 | 48 | @torch.no_grad() 49 | def forward(self, clip_input, images): 50 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 51 | image_embeds = self.visual_projection(pooled_output) 52 | 53 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 54 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() 55 | cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() 56 | 57 | result = [] 58 | batch_size = image_embeds.shape[0] 59 | for i in range(batch_size): 60 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} 61 | 62 | # increase this value to create a stronger `nfsw` filter 63 | # at the 
cost of increasing the possibility of filtering benign images 64 | adjustment = 0.0 65 | 66 | for concept_idx in range(len(special_cos_dist[0])): 67 | concept_cos = special_cos_dist[i][concept_idx] 68 | concept_threshold = self.special_care_embeds_weights[concept_idx].item() 69 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 70 | if result_img["special_scores"][concept_idx] > 0: 71 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) 72 | adjustment = 0.01 73 | 74 | for concept_idx in range(len(cos_dist[0])): 75 | concept_cos = cos_dist[i][concept_idx] 76 | concept_threshold = self.concept_embeds_weights[concept_idx].item() 77 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 78 | if result_img["concept_scores"][concept_idx] > 0: 79 | result_img["bad_concepts"].append(concept_idx) 80 | 81 | result.append(result_img) 82 | 83 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 84 | 85 | return images, has_nsfw_concepts 86 | 87 | @torch.no_grad() 88 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): 89 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 90 | image_embeds = self.visual_projection(pooled_output) 91 | 92 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 93 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 94 | 95 | # increase this value to create a stronger `nsfw` filter 96 | # at the cost of increasing the possibility of filtering benign images 97 | adjustment = 0.0 98 | 99 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment 100 | # special_scores = special_scores.round(decimals=3) 101 | special_care = torch.any(special_scores > 0, dim=1) 102 | special_adjustment = special_care * 0.01 103 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) 104 | 105 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 106 | # concept_scores = concept_scores.round(decimals=3) 107 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) 108 | 109 | return images, has_nsfw_concepts 110 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stochastic_karras_ve/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_stochastic_karras_ve import KarrasVePipeline 2 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
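# Self-contained sketch of the stochastic sampling loop from Karras et al. (Algorithm 2) that
# the pipeline below wraps. A closed-form toy denoiser replaces the UNet: for data drawn from
# N(0, I), the optimal denoiser at noise level sigma is x / (1 + sigma**2). The sigma schedule,
# churn amount, and step count here are illustrative choices, not the pipeline's defaults.
import torch

def denoise(x, sigma):
    return x / (1.0 + sigma**2)

sigmas = torch.linspace(80.0, 0.0, 20)                    # illustrative schedule, last sigma is 0
s_churn, s_noise = 10.0, 1.0
x = torch.randn(4, 3, 8, 8) * sigmas[0]                   # sample x ~ N(0, sigma_0^2 I)

for i in range(len(sigmas) - 1):
    sigma, sigma_next = sigmas[i], sigmas[i + 1]
    gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
    sigma_hat = sigma * (1 + gamma)                        # temporarily increased noise level
    x_hat = x + (sigma_hat**2 - sigma**2).sqrt() * s_noise * torch.randn_like(x)
    d = (x_hat - denoise(x_hat, sigma_hat)) / sigma_hat    # dx/dsigma at sigma_hat
    x = x_hat + (sigma_next - sigma_hat) * d               # Euler step
    if sigma_next > 0:                                     # 2nd order (Heun) correction
        d_next = (x - denoise(x, sigma_next)) / sigma_next
        x = x_hat + (sigma_next - sigma_hat) * 0.5 * (d + d_next)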
14 | 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from ...models import UNet2DModel 20 | from ...schedulers import KarrasVeScheduler 21 | from ...utils import randn_tensor 22 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 23 | 24 | 25 | class KarrasVePipeline(DiffusionPipeline): 26 | r""" 27 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and 28 | the VE column of Table 1 from [1] for reference. 29 | 30 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." 31 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic 32 | differential equations." https://arxiv.org/abs/2011.13456 33 | 34 | Parameters: 35 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 36 | scheduler ([`KarrasVeScheduler`]): 37 | Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. 38 | """ 39 | 40 | # add type hints for linting 41 | unet: UNet2DModel 42 | scheduler: KarrasVeScheduler 43 | 44 | def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): 45 | super().__init__() 46 | self.register_modules(unet=unet, scheduler=scheduler) 47 | 48 | @torch.no_grad() 49 | def __call__( 50 | self, 51 | batch_size: int = 1, 52 | num_inference_steps: int = 50, 53 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 54 | output_type: Optional[str] = "pil", 55 | return_dict: bool = True, 56 | **kwargs, 57 | ) -> Union[Tuple, ImagePipelineOutput]: 58 | r""" 59 | Args: 60 | batch_size (`int`, *optional*, defaults to 1): 61 | The number of images to generate. 62 | generator (`torch.Generator`, *optional*): 63 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 64 | to make generation deterministic. 65 | num_inference_steps (`int`, *optional*, defaults to 50): 66 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 67 | expense of slower inference. 68 | output_type (`str`, *optional*, defaults to `"pil"`): 69 | The output format of the generate image. Choose between 70 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 71 | return_dict (`bool`, *optional*, defaults to `True`): 72 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 73 | 74 | Returns: 75 | [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is 76 | True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 77 | """ 78 | 79 | img_size = self.unet.config.sample_size 80 | shape = (batch_size, 3, img_size, img_size) 81 | 82 | model = self.unet 83 | 84 | # sample x_0 ~ N(0, sigma_0^2 * I) 85 | sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma 86 | 87 | self.scheduler.set_timesteps(num_inference_steps) 88 | 89 | for t in self.progress_bar(self.scheduler.timesteps): 90 | # here sigma_t == t_i from the paper 91 | sigma = self.scheduler.schedule[t] 92 | sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 93 | 94 | # 1. Select temporarily increased noise level sigma_hat 95 | # 2. 
Add new noise to move from sample_i to sample_hat 96 | sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) 97 | 98 | # 3. Predict the noise residual given the noise magnitude `sigma_hat` 99 | # The model inputs and output are adjusted by following eq. (213) in [1]. 100 | model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample 101 | 102 | # 4. Evaluate dx/dt at sigma_hat 103 | # 5. Take Euler step from sigma to sigma_prev 104 | step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) 105 | 106 | if sigma_prev != 0: 107 | # 6. Apply 2nd order correction 108 | # The model inputs and output are adjusted by following eq. (213) in [1]. 109 | model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample 110 | step_output = self.scheduler.step_correct( 111 | model_output, 112 | sigma_hat, 113 | sigma_prev, 114 | sample_hat, 115 | step_output.prev_sample, 116 | step_output["derivative"], 117 | ) 118 | sample = step_output.prev_sample 119 | 120 | sample = (sample / 2 + 0.5).clamp(0, 1) 121 | image = sample.cpu().permute(0, 2, 3, 1).numpy() 122 | if output_type == "pil": 123 | image = self.numpy_to_pil(image) 124 | 125 | if not return_dict: 126 | return (image,) 127 | 128 | return ImagePipelineOutput(images=image) 129 | -------------------------------------------------------------------------------- /diffuser/diffusers/pipelines/text_to_video_synthesis/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available 8 | 9 | 10 | @dataclass 11 | class TextToVideoSDPipelineOutput(BaseOutput): 12 | """ 13 | Output class for text to video pipelines. 14 | 15 | Args: 16 | frames (`List[np.ndarray]` or `torch.FloatTensor`) 17 | List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as 18 | a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list 19 | denotes the video length i.e., the number of frames. 20 | """ 21 | 22 | frames: Union[List[np.ndarray], torch.FloatTensor] 23 | 24 | 25 | try: 26 | if not (is_transformers_available() and is_torch_available()): 27 | raise OptionalDependencyNotAvailable() 28 | except OptionalDependencyNotAvailable: 29 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 30 | else: 31 | from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401 32 | from .pipeline_text_to_video_zero import TextToVideoZeroPipeline 33 | -------------------------------------------------------------------------------- /diffuser/diffusers/schedulers/README.md: -------------------------------------------------------------------------------- 1 | # Schedulers 2 | 3 | For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview). -------------------------------------------------------------------------------- /diffuser/diffusers/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available 17 | 18 | 19 | try: 20 | if not is_torch_available(): 21 | raise OptionalDependencyNotAvailable() 22 | except OptionalDependencyNotAvailable: 23 | from ..utils.dummy_pt_objects import * # noqa F403 24 | else: 25 | from .scheduling_ddim import DDIMScheduler 26 | from .scheduling_sde import SDEScheduler 27 | from .scheduling_ddim_inverse import DDIMInverseScheduler 28 | from .scheduling_ddpm import DDPMScheduler 29 | from .scheduling_deis_multistep import DEISMultistepScheduler 30 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler 31 | from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler 32 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler 33 | from .scheduling_euler_discrete import EulerDiscreteScheduler 34 | from .scheduling_heun_discrete import HeunDiscreteScheduler 35 | from .scheduling_ipndm import IPNDMScheduler 36 | from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler 37 | from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler 38 | from .scheduling_karras_ve import KarrasVeScheduler 39 | from .scheduling_pndm import PNDMScheduler 40 | from .scheduling_repaint import RePaintScheduler 41 | from .scheduling_sde_ve import ScoreSdeVeScheduler 42 | from .scheduling_sde_vp import ScoreSdeVpScheduler 43 | from .scheduling_unclip import UnCLIPScheduler 44 | from .scheduling_unipc_multistep import UniPCMultistepScheduler 45 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 46 | from .scheduling_vq_diffusion import VQDiffusionScheduler 47 | 48 | try: 49 | if not is_flax_available(): 50 | raise OptionalDependencyNotAvailable() 51 | except OptionalDependencyNotAvailable: 52 | from ..utils.dummy_flax_objects import * # noqa F403 53 | else: 54 | from .scheduling_ddim_flax import FlaxDDIMScheduler 55 | from .scheduling_ddpm_flax import FlaxDDPMScheduler 56 | from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler 57 | from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler 58 | from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler 59 | from .scheduling_pndm_flax import FlaxPNDMScheduler 60 | from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler 61 | from .scheduling_utils_flax import ( 62 | FlaxKarrasDiffusionSchedulers, 63 | FlaxSchedulerMixin, 64 | FlaxSchedulerOutput, 65 | broadcast_to_shape_from_left, 66 | ) 67 | 68 | 69 | try: 70 | if not (is_torch_available() and is_scipy_available()): 71 | raise OptionalDependencyNotAvailable() 72 | except OptionalDependencyNotAvailable: 73 | from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 74 | else: 75 | from .scheduling_lms_discrete import LMSDiscreteScheduler 76 | -------------------------------------------------------------------------------- 
/diffuser/diffusers/schedulers/scheduling_sde_vp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch 16 | 17 | import math 18 | from typing import Union 19 | 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config 23 | from ..utils import randn_tensor 24 | from .scheduling_utils import SchedulerMixin 25 | 26 | 27 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): 28 | """ 29 | The variance preserving stochastic differential equation (SDE) scheduler. 30 | 31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 33 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 34 | [`~SchedulerMixin.from_pretrained`] functions. 35 | 36 | For more information, see the original paper: https://arxiv.org/abs/2011.13456 37 | 38 | UNDER CONSTRUCTION 39 | 40 | """ 41 | 42 | order = 1 43 | 44 | @register_to_config 45 | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): 46 | self.sigmas = None 47 | self.discrete_sigmas = None 48 | self.timesteps = None 49 | 50 | def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): 51 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) 52 | 53 | def step_pred(self, score, x, t, generator=None): 54 | if self.timesteps is None: 55 | raise ValueError( 56 | "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" 57 | ) 58 | 59 | # TODO(Patrick) better comments + non-PyTorch 60 | # postprocess model score 61 | log_mean_coeff = ( 62 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min 63 | ) 64 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) 65 | std = std.flatten() 66 | while len(std.shape) < len(score.shape): 67 | std = std.unsqueeze(-1) 68 | score = -score / std 69 | 70 | # compute 71 | dt = -1.0 / len(self.timesteps) 72 | 73 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) 74 | beta_t = beta_t.flatten() 75 | while len(beta_t.shape) < len(x.shape): 76 | beta_t = beta_t.unsqueeze(-1) 77 | drift = -0.5 * beta_t * x 78 | 79 | diffusion = torch.sqrt(beta_t) 80 | drift = drift - diffusion**2 * score 81 | x_mean = x + drift * dt 82 | 83 | # add noise 84 | noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype) 85 | x = x_mean + diffusion * math.sqrt(-dt) * noise 86 | 87 | return x, x_mean 88 | 89 | def __len__(self): 90 | return 
self.config.num_train_timesteps 91 | -------------------------------------------------------------------------------- /diffuser/diffusers/test.py: -------------------------------------------------------------------------------- 1 | # make sure you're logged in with \`huggingface-cli login\` 2 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler 3 | 4 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") 5 | pipe = pipe.to("cuda") 6 | 7 | prompt = "a photo of an astronaut riding a horse on mars" 8 | image = pipe(prompt).images[0] 9 | 10 | image.save("astronaut_rides_horse.png") -------------------------------------------------------------------------------- /diffuser/diffusers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | from packaging import version 19 | 20 | from .. import __version__ 21 | from .accelerate_utils import apply_forward_hook 22 | from .constants import ( 23 | CONFIG_NAME, 24 | DEPRECATED_REVISION_ARGS, 25 | DIFFUSERS_CACHE, 26 | DIFFUSERS_DYNAMIC_MODULE_NAME, 27 | FLAX_WEIGHTS_NAME, 28 | HF_MODULES_CACHE, 29 | HUGGINGFACE_CO_RESOLVE_ENDPOINT, 30 | ONNX_EXTERNAL_WEIGHTS_NAME, 31 | ONNX_WEIGHTS_NAME, 32 | SAFETENSORS_WEIGHTS_NAME, 33 | TEXT_ENCODER_TARGET_MODULES, 34 | WEIGHTS_NAME, 35 | ) 36 | from .deprecation_utils import deprecate 37 | from .doc_utils import replace_example_docstring 38 | from .dynamic_modules_utils import get_class_from_dynamic_module 39 | from .hub_utils import ( 40 | HF_HUB_OFFLINE, 41 | _add_variant, 42 | _get_model_file, 43 | extract_commit_hash, 44 | http_user_agent, 45 | ) 46 | from .import_utils import ( 47 | ENV_VARS_TRUE_AND_AUTO_VALUES, 48 | ENV_VARS_TRUE_VALUES, 49 | USE_JAX, 50 | USE_TF, 51 | USE_TORCH, 52 | DummyObject, 53 | OptionalDependencyNotAvailable, 54 | is_accelerate_available, 55 | is_accelerate_version, 56 | is_flax_available, 57 | is_inflect_available, 58 | is_k_diffusion_available, 59 | is_k_diffusion_version, 60 | is_librosa_available, 61 | is_note_seq_available, 62 | is_omegaconf_available, 63 | is_onnx_available, 64 | is_safetensors_available, 65 | is_scipy_available, 66 | is_tensorboard_available, 67 | is_tf_available, 68 | is_torch_available, 69 | is_torch_version, 70 | is_transformers_available, 71 | is_transformers_version, 72 | is_unidecode_available, 73 | is_wandb_available, 74 | is_xformers_available, 75 | requires_backends, 76 | ) 77 | from .logging import get_logger 78 | from .outputs import BaseOutput 79 | from .pil_utils import PIL_INTERPOLATION 80 | from .torch_utils import is_compiled_module, randn_tensor 81 | 82 | 83 | if is_torch_available(): 84 | from .testing_utils import ( 85 | floats_tensor, 86 | load_hf_numpy, 87 | load_image, 88 | load_numpy, 89 | load_pt, 90 | nightly, 91 | parse_flag_from_env, 92 | print_tensor_test, 93 | 
require_torch_2, 94 | require_torch_gpu, 95 | skip_mps, 96 | slow, 97 | torch_all_close, 98 | torch_device, 99 | ) 100 | 101 | from .testing_utils import export_to_video 102 | 103 | 104 | logger = get_logger(__name__) 105 | 106 | 107 | def check_min_version(min_version): 108 | if version.parse(__version__) < version.parse(min_version): 109 | if "dev" in min_version: 110 | error_message = ( 111 | "This example requires a source install from HuggingFace diffusers (see " 112 | "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," 113 | ) 114 | else: 115 | error_message = f"This example requires a minimum version of {min_version}," 116 | error_message += f" but the version found is {__version__}.\n" 117 | raise ImportError(error_message) 118 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/accelerate_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Accelerate utilities: Utilities related to accelerate 16 | """ 17 | 18 | from packaging import version 19 | 20 | from .import_utils import is_accelerate_available 21 | 22 | 23 | if is_accelerate_available(): 24 | import accelerate 25 | 26 | 27 | def apply_forward_hook(method): 28 | """ 29 | Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful 30 | for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the 31 | appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. 32 | 33 | This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. 34 | 35 | :param method: The method to decorate. This method should be a method of a PyTorch module. 36 | """ 37 | if not is_accelerate_available(): 38 | return method 39 | accelerate_version = version.parse(accelerate.__version__).base_version 40 | if version.parse(accelerate_version) < version.parse("0.17.0"): 41 | return method 42 | 43 | def wrapper(self, *args, **kwargs): 44 | if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): 45 | self._hf_hook.pre_forward(self) 46 | return method(self, *args, **kwargs) 47 | 48 | return wrapper 49 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | 16 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home 17 | 18 | 19 | default_cache_path = HUGGINGFACE_HUB_CACHE 20 | 21 | 22 | CONFIG_NAME = "config.json" 23 | WEIGHTS_NAME = "diffusion_pytorch_model.bin" 24 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" 25 | ONNX_WEIGHTS_NAME = "model.onnx" 26 | SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" 27 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" 28 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" 29 | DIFFUSERS_CACHE = default_cache_path 30 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" 31 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) 32 | DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] 33 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] 34 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/deprecation_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from typing import Any, Dict, Optional, Union 4 | 5 | from packaging import version 6 | 7 | 8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): 9 | from .. import __version__ 10 | 11 | deprecated_kwargs = take_from 12 | values = () 13 | if not isinstance(args[0], tuple): 14 | args = (args,) 15 | 16 | for attribute, version_name, message in args: 17 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): 18 | raise ValueError( 19 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" 20 | f" version {__version__} is >= {version_name}" 21 | ) 22 | 23 | warning = None 24 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: 25 | values += (deprecated_kwargs.pop(attribute),) 26 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." 27 | elif hasattr(deprecated_kwargs, attribute): 28 | values += (getattr(deprecated_kwargs, attribute),) 29 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." 30 | elif deprecated_kwargs is None: 31 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 
32 | 33 | if warning is not None: 34 | warning = warning + " " if standard_warn else "" 35 | warnings.warn(warning + message, FutureWarning, stacklevel=2) 36 | 37 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: 38 | call_frame = inspect.getouterframes(inspect.currentframe())[1] 39 | filename = call_frame.filename 40 | line_number = call_frame.lineno 41 | function = call_frame.function 42 | key, value = next(iter(deprecated_kwargs.items())) 43 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") 44 | 45 | if len(values) == 0: 46 | return 47 | elif len(values) == 1: 48 | return values[0] 49 | return values 50 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/doc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Doc utilities: Utilities related to documentation 16 | """ 17 | import re 18 | 19 | 20 | def replace_example_docstring(example_docstring): 21 | def docstring_decorator(fn): 22 | func_doc = fn.__doc__ 23 | lines = func_doc.split("\n") 24 | i = 0 25 | while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: 26 | i += 1 27 | if i < len(lines): 28 | lines[i] = example_docstring 29 | func_doc = "\n".join(lines) 30 | else: 31 | raise ValueError( 32 | f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " 33 | f"current docstring is:\n{func_doc}" 34 | ) 35 | fn.__doc__ = func_doc 36 | return fn 37 | 38 | return docstring_decorator 39 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_flax_and_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): 6 | _backends = ["flax", "transformers"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["flax", "transformers"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["flax", "transformers"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["flax", "transformers"]) 18 | 19 | 20 | class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 21 | _backends = ["flax", "transformers"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["flax", "transformers"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["flax", "transformers"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["flax", "transformers"]) 33 | 34 | 35 | class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): 36 | _backends = ["flax", "transformers"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["flax", "transformers"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["flax", "transformers"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["flax", "transformers"]) 48 | 49 | 50 | class FlaxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["flax", "transformers"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["flax", "transformers"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["flax", "transformers"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["flax", "transformers"]) 63 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_flax_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class FlaxControlNetModel(metaclass=DummyObject): 6 | _backends = ["flax"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["flax"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["flax"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["flax"]) 18 | 19 | 20 | class FlaxModelMixin(metaclass=DummyObject): 21 | _backends = ["flax"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["flax"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["flax"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["flax"]) 33 | 34 | 35 | class FlaxUNet2DConditionModel(metaclass=DummyObject): 36 | _backends = ["flax"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["flax"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["flax"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["flax"]) 48 | 49 | 50 | class FlaxAutoencoderKL(metaclass=DummyObject): 51 | _backends = ["flax"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["flax"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["flax"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["flax"]) 63 | 64 | 65 | class FlaxDiffusionPipeline(metaclass=DummyObject): 66 | _backends = ["flax"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["flax"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["flax"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["flax"]) 78 | 79 | 80 | class FlaxDDIMScheduler(metaclass=DummyObject): 81 | _backends = ["flax"] 82 | 83 | def __init__(self, *args, **kwargs): 84 | requires_backends(self, ["flax"]) 85 | 86 | @classmethod 87 | def from_config(cls, *args, **kwargs): 88 | requires_backends(cls, ["flax"]) 89 | 90 | @classmethod 91 | def from_pretrained(cls, *args, **kwargs): 92 | requires_backends(cls, ["flax"]) 93 | 94 | 95 | class FlaxDDPMScheduler(metaclass=DummyObject): 96 | _backends = ["flax"] 97 | 98 | def __init__(self, *args, **kwargs): 99 | requires_backends(self, ["flax"]) 100 | 101 | @classmethod 102 | def from_config(cls, *args, **kwargs): 103 | requires_backends(cls, ["flax"]) 104 | 105 | @classmethod 106 | def from_pretrained(cls, *args, **kwargs): 107 | requires_backends(cls, ["flax"]) 108 | 109 | 110 | class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): 111 | _backends = ["flax"] 112 | 113 | def __init__(self, *args, **kwargs): 114 | requires_backends(self, ["flax"]) 115 | 116 | @classmethod 117 | def from_config(cls, *args, **kwargs): 118 | requires_backends(cls, ["flax"]) 119 | 120 | @classmethod 121 | def from_pretrained(cls, *args, **kwargs): 122 | requires_backends(cls, ["flax"]) 123 | 124 | 125 | class FlaxKarrasVeScheduler(metaclass=DummyObject): 126 | _backends = ["flax"] 127 | 128 | def __init__(self, *args, **kwargs): 129 | requires_backends(self, ["flax"]) 130 | 131 | @classmethod 132 | def from_config(cls, *args, **kwargs): 133 | requires_backends(cls, ["flax"]) 134 | 135 | @classmethod 136 | def 
from_pretrained(cls, *args, **kwargs): 137 | requires_backends(cls, ["flax"]) 138 | 139 | 140 | class FlaxLMSDiscreteScheduler(metaclass=DummyObject): 141 | _backends = ["flax"] 142 | 143 | def __init__(self, *args, **kwargs): 144 | requires_backends(self, ["flax"]) 145 | 146 | @classmethod 147 | def from_config(cls, *args, **kwargs): 148 | requires_backends(cls, ["flax"]) 149 | 150 | @classmethod 151 | def from_pretrained(cls, *args, **kwargs): 152 | requires_backends(cls, ["flax"]) 153 | 154 | 155 | class FlaxPNDMScheduler(metaclass=DummyObject): 156 | _backends = ["flax"] 157 | 158 | def __init__(self, *args, **kwargs): 159 | requires_backends(self, ["flax"]) 160 | 161 | @classmethod 162 | def from_config(cls, *args, **kwargs): 163 | requires_backends(cls, ["flax"]) 164 | 165 | @classmethod 166 | def from_pretrained(cls, *args, **kwargs): 167 | requires_backends(cls, ["flax"]) 168 | 169 | 170 | class FlaxSchedulerMixin(metaclass=DummyObject): 171 | _backends = ["flax"] 172 | 173 | def __init__(self, *args, **kwargs): 174 | requires_backends(self, ["flax"]) 175 | 176 | @classmethod 177 | def from_config(cls, *args, **kwargs): 178 | requires_backends(cls, ["flax"]) 179 | 180 | @classmethod 181 | def from_pretrained(cls, *args, **kwargs): 182 | requires_backends(cls, ["flax"]) 183 | 184 | 185 | class FlaxScoreSdeVeScheduler(metaclass=DummyObject): 186 | _backends = ["flax"] 187 | 188 | def __init__(self, *args, **kwargs): 189 | requires_backends(self, ["flax"]) 190 | 191 | @classmethod 192 | def from_config(cls, *args, **kwargs): 193 | requires_backends(cls, ["flax"]) 194 | 195 | @classmethod 196 | def from_pretrained(cls, *args, **kwargs): 197 | requires_backends(cls, ["flax"]) 198 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class MidiProcessor(metaclass=DummyObject): 6 | _backends = ["note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["note_seq"]) 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxRuntimeModel(metaclass=DummyObject): 6 | _backends = ["onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["onnx"]) 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_torch_and_librosa_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class AudioDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "librosa"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "librosa"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "librosa"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "librosa"]) 18 | 19 | 20 | class Mel(metaclass=DummyObject): 21 | _backends = ["torch", "librosa"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "librosa"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "librosa"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "librosa"]) 33 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_torch_and_scipy_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class LMSDiscreteScheduler(metaclass=DummyObject): 6 | _backends = ["torch", "scipy"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "scipy"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "scipy"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "scipy"]) 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "k_diffusion"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "k_diffusion"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "onnx"]) 18 | 19 | 20 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): 21 | _backends = ["torch", "transformers", "onnx"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "transformers", "onnx"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "transformers", "onnx"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "transformers", "onnx"]) 33 | 34 | 35 | class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): 36 | _backends = ["torch", "transformers", "onnx"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["torch", "transformers", "onnx"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["torch", "transformers", "onnx"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["torch", "transformers", "onnx"]) 48 | 49 | 50 | class OnnxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["torch", "transformers", "onnx"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["torch", "transformers", "onnx"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["torch", "transformers", "onnx"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["torch", "transformers", "onnx"]) 63 | 64 | 65 | class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): 66 | _backends = ["torch", "transformers", "onnx"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["torch", "transformers", "onnx"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["torch", "transformers", "onnx"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["torch", "transformers", "onnx"]) 78 | 79 | 80 | class StableDiffusionOnnxPipeline(metaclass=DummyObject): 81 | _backends = ["torch", "transformers", "onnx"] 82 | 83 | def __init__(self, *args, **kwargs): 84 | requires_backends(self, ["torch", "transformers", "onnx"]) 85 | 86 | @classmethod 87 | def from_config(cls, *args, **kwargs): 88 | requires_backends(cls, ["torch", "transformers", "onnx"]) 89 | 90 | @classmethod 91 | def from_pretrained(cls, *args, **kwargs): 92 | requires_backends(cls, ["torch", "transformers", "onnx"]) 93 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class SpectrogramDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["transformers", "torch", "note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["transformers", "torch", "note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 18 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 7 | 8 | # {{ model_name | default("Diffusion Model") }} 9 | 10 | ## Model description 11 | 12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library 13 | on the `{{ dataset_name }}` dataset. 14 | 15 | ## Intended uses & limitations 16 | 17 | #### How to use 18 | 19 | ```python 20 | # TODO: add an example code snippet for running this diffusion pipeline 21 | ``` 22 | 23 | #### Limitations and bias 24 | 25 | [TODO: provide examples of latent issues and potential remediations] 26 | 27 | ## Training data 28 | 29 | [TODO: describe the data used to train the model] 30 | 31 | ### Training hyperparameters 32 | 33 | The following hyperparameters were used during training: 34 | - learning_rate: {{ learning_rate }} 35 | - train_batch_size: {{ train_batch_size }} 36 | - eval_batch_size: {{ eval_batch_size }} 37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }} 38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} 39 | - lr_scheduler: {{ lr_scheduler }} 40 | - lr_warmup_steps: {{ lr_warmup_steps }} 41 | - ema_inv_gamma: {{ ema_inv_gamma }} 42 | - ema_inv_gamma: {{ ema_power }} 43 | - ema_inv_gamma: {{ ema_max_decay }} 44 | - mixed_precision: {{ mixed_precision }} 45 | 46 | ### Training results 47 | 48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) 49 | 50 | 51 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/outputs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Generic utilities 16 | """ 17 | 18 | from collections import OrderedDict 19 | from dataclasses import fields 20 | from typing import Any, Tuple 21 | 22 | import numpy as np 23 | 24 | from .import_utils import is_torch_available 25 | 26 | 27 | def is_tensor(x): 28 | """ 29 | Tests if `x` is a `torch.Tensor` or `np.ndarray`. 
30 | """ 31 | if is_torch_available(): 32 | import torch 33 | 34 | if isinstance(x, torch.Tensor): 35 | return True 36 | 37 | return isinstance(x, np.ndarray) 38 | 39 | 40 | class BaseOutput(OrderedDict): 41 | """ 42 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 43 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 44 | python dictionary. 45 | 46 | 47 | 48 | You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple 49 | before. 50 | 51 | 52 | """ 53 | 54 | def __post_init__(self): 55 | class_fields = fields(self) 56 | 57 | # Safety and consistency checks 58 | if not len(class_fields): 59 | raise ValueError(f"{self.__class__.__name__} has no fields.") 60 | 61 | first_field = getattr(self, class_fields[0].name) 62 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 63 | 64 | if other_fields_are_none and isinstance(first_field, dict): 65 | for key, value in first_field.items(): 66 | self[key] = value 67 | else: 68 | for field in class_fields: 69 | v = getattr(self, field.name) 70 | if v is not None: 71 | self[field.name] = v 72 | 73 | def __delitem__(self, *args, **kwargs): 74 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 75 | 76 | def setdefault(self, *args, **kwargs): 77 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 78 | 79 | def pop(self, *args, **kwargs): 80 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 81 | 82 | def update(self, *args, **kwargs): 83 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 84 | 85 | def __getitem__(self, k): 86 | if isinstance(k, str): 87 | inner_dict = dict(self.items()) 88 | return inner_dict[k] 89 | else: 90 | return self.to_tuple()[k] 91 | 92 | def __setattr__(self, name, value): 93 | if name in self.keys() and value is not None: 94 | # Don't call self.__setitem__ to avoid recursion errors 95 | super().__setitem__(name, value) 96 | super().__setattr__(name, value) 97 | 98 | def __setitem__(self, key, value): 99 | # Will raise a KeyException if needed 100 | super().__setitem__(key, value) 101 | # Don't call self.__setattr__ to avoid recursion errors 102 | super().__setattr__(key, value) 103 | 104 | def to_tuple(self) -> Tuple[Any]: 105 | """ 106 | Convert self to a tuple containing all the attributes/keys that are not `None`. 
107 | """ 108 | return tuple(self[k] for k in self.keys()) 109 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/pil_utils.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import PIL.ImageOps 3 | from packaging import version 4 | 5 | 6 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 7 | PIL_INTERPOLATION = { 8 | "linear": PIL.Image.Resampling.BILINEAR, 9 | "bilinear": PIL.Image.Resampling.BILINEAR, 10 | "bicubic": PIL.Image.Resampling.BICUBIC, 11 | "lanczos": PIL.Image.Resampling.LANCZOS, 12 | "nearest": PIL.Image.Resampling.NEAREST, 13 | } 14 | else: 15 | PIL_INTERPOLATION = { 16 | "linear": PIL.Image.LINEAR, 17 | "bilinear": PIL.Image.BILINEAR, 18 | "bicubic": PIL.Image.BICUBIC, 19 | "lanczos": PIL.Image.LANCZOS, 20 | "nearest": PIL.Image.NEAREST, 21 | } 22 | -------------------------------------------------------------------------------- /diffuser/diffusers/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | PyTorch utilities: Utilities related to PyTorch 16 | """ 17 | from typing import List, Optional, Tuple, Union 18 | 19 | from . import logging 20 | from .import_utils import is_torch_available, is_torch_version 21 | 22 | 23 | if is_torch_available(): 24 | import torch 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | 29 | def randn_tensor( 30 | shape: Union[Tuple, List], 31 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 32 | device: Optional["torch.device"] = None, 33 | dtype: Optional["torch.dtype"] = None, 34 | layout: Optional["torch.layout"] = None, 35 | ): 36 | """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When 37 | passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor 38 | will always be created on CPU. 39 | """ 40 | # device on which tensor is created defaults to device 41 | rand_device = device 42 | batch_size = shape[0] 43 | 44 | layout = layout or torch.strided 45 | device = device or torch.device("cpu") 46 | 47 | if generator is not None: 48 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 49 | if gen_device_type != device.type and gen_device_type == "cpu": 50 | rand_device = "cpu" 51 | if device != "mps": 52 | logger.info( 53 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 54 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 55 | f" slighly speed up this function by passing a generator that was created on the {device} device." 
56 | ) 57 | elif gen_device_type != device.type and gen_device_type == "cuda": 58 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 59 | 60 | if isinstance(generator, list): 61 | shape = (1,) + shape[1:] 62 | latents = [ 63 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 64 | for i in range(batch_size) 65 | ] 66 | latents = torch.cat(latents, dim=0).to(device) 67 | else: 68 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 69 | 70 | return latents 71 | 72 | 73 | def is_compiled_module(module): 74 | """Check whether the module was compiled with torch.compile()""" 75 | if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): 76 | return False 77 | return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) 78 | -------------------------------------------------------------------------------- /diffuser/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /diffuser/dnnlib/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/dnnlib/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/dnnlib/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/dnnlib/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/eval_clip_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | from coco_data_loader import text_image_pair 5 | import argparse 6 | from tqdm import tqdm 7 | import clip 8 | import aesthetic_score 9 | 10 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion') 11 | parser.add_argument('--steps', type=int, default=50, help='number of inference steps during sampling') 12 | parser.add_argument('--generate_seed', type=int, default=6) 13 | parser.add_argument('--bs', type=int, default=16) 14 | parser.add_argument('--max_cnt', type=int, default=100, help='number of maximum geneated samples') 15 | parser.add_argument('--csv_path', type=str, default='./generated_images/subset.csv') 16 | parser.add_argument('--dir_path', type=str, default='./generated_images/subset') 17 | parser.add_argument('--scheduler', type=str, default='DDPM') 18 | args = parser.parse_args() 19 | 20 | # define dataset / data_loader 21 | text2img_dataset = text_image_pair(dir_path=args.dir_path, csv_path=args.csv_path) 22 | text2img_loader = torch.utils.data.DataLoader(dataset=text2img_dataset, 
batch_size=args.bs, shuffle=False) 23 | 24 | print("total length:", len(text2img_dataset)) 25 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k') 26 | model2, _ = clip.load("ViT-L/14", device='cuda') #RN50x64 27 | model = model.cuda().eval() 28 | model2 = model2.eval() 29 | tokenizer = open_clip.get_tokenizer('ViT-g-14') 30 | 31 | 32 | model_aes = aesthetic_score.MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14 33 | s = torch.load("./clip-refs/sac+logos+ava1-l14-linearMSE.pth") # load the model you trained previously or the model available in this repo 34 | model_aes.load_state_dict(s) 35 | model_aes.to("cuda") 36 | model_aes.eval() 37 | 38 | # text = tokenizer(["a horse", "a dog", "a cat"]) 39 | cnt = 0. 40 | total_clip_score = 0. 41 | total_aesthetic_score = 0. 42 | 43 | 44 | with torch.no_grad(), torch.cuda.amp.autocast(): 45 | for idx, (image, image2, text) in tqdm(enumerate(text2img_loader)): 46 | image = image.cuda().float() 47 | image2 = image2.cuda().float() 48 | text = list(text) 49 | text = tokenizer(text).cuda() 50 | # print('text:') 51 | # print(text.shape) 52 | image_features = model.encode_image(image).float() 53 | text_features = model.encode_text(text).float() 54 | # (bs, 1024) 55 | image_features /= image_features.norm(dim=-1, keepdim=True) 56 | text_features /= text_features.norm(dim=-1, keepdim=True) 57 | 58 | total_clip_score += (image_features * text_features).sum() 59 | 60 | image_features = model2.encode_image(image) 61 | im_emb_arr = aesthetic_score.normalized(image_features.cpu().detach().numpy()) 62 | total_aesthetic_score += model_aes(torch.from_numpy(im_emb_arr).to(image.device).type(torch.cuda.FloatTensor)).sum() 63 | 64 | cnt += len(image) 65 | 66 | 67 | print("Average ClIP score :", total_clip_score.item() / cnt) 68 | print("Average Aesthetic score :", total_aesthetic_score.item() / cnt) -------------------------------------------------------------------------------- /diffuser/ffhq.py: -------------------------------------------------------------------------------- 1 | 2 | from diffusers import DiffusionPipeline 3 | import torch 4 | import argparse 5 | from torchvision.utils import make_grid, save_image 6 | import os 7 | import numpy as np 8 | # setup args 9 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion') 10 | parser.add_argument('--name', type=str, default=None) 11 | args = parser.parse_args() 12 | 13 | model_id = "google/ncsnpp-ffhq-1024" 14 | 15 | # load model and scheduler 16 | sde_ve = DiffusionPipeline.from_pretrained(model_id).to("cuda") 17 | 18 | # run pipeline in inference (sample random noise and denoise) 19 | #gen = torch.Generator(torch.device('cpu')).manual_seed(int(123)) 20 | image_list = [] 21 | seed_list = [i for i in range(1)] 22 | for seed in seed_list: 23 | image = sde_ve(batch_size=1, seed=seed, output_type='tensor') 24 | print(image.shape) 25 | image_list.append(image) 26 | 27 | images = torch.cat(image_list, dim=0) 28 | #images_ = (images + 1) / 2. 
29 | images_ = images 30 | print("len:", len(images)) 31 | image_grid = make_grid(images_, nrow=int(np.sqrt(len(images)))) 32 | save_image(image_grid, os.path.join(f'{args.name}.png')) 33 | 34 | # # save image 35 | # if args.name is not None: 36 | # image.save(f"{args.name}.png") 37 | # else: 38 | # image.save("sde_ve_generated_image.png") 39 | 40 | -------------------------------------------------------------------------------- /diffuser/generate.py: -------------------------------------------------------------------------------- 1 | # make sure you're logged in with \`huggingface-cli login\` 2 | from diffusers import StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, SDEScheduler, EulerDiscreteScheduler 3 | import torch 4 | from coco_data_loader import text_image_pair 5 | from PIL import Image 6 | import os 7 | import pandas as pd 8 | import argparse 9 | import torch.nn as nn 10 | from torch_utils import distributed as dist 11 | import numpy as np 12 | import tqdm 13 | 14 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion') 15 | parser.add_argument('--steps', type=int, default=30, help='number of inference steps during sampling') 16 | parser.add_argument('--generate_seed', type=int, default=6) 17 | parser.add_argument('--w', type=float, default=7.5) 18 | parser.add_argument('--s_noise', type=float, default=1.) 19 | parser.add_argument('--bs', type=int, default=16) 20 | parser.add_argument('--max_cnt', type=int, default=5000, help='number of maximum geneated samples') 21 | parser.add_argument('--save_path', type=str, default='./generated_images') 22 | parser.add_argument('--scheduler', type=str, default='DDPM') 23 | parser.add_argument('--name', type=str, default=None) 24 | parser.add_argument('--restart', action='store_true', default=False) 25 | parser.add_argument('--second', action='store_true', default=False, help='second order ODE') 26 | parser.add_argument('--sigma', action='store_true', default=False, help='use sigma') 27 | args = parser.parse_args() 28 | 29 | 30 | dist.init() 31 | 32 | dist.print0('Args:') 33 | for k, v in sorted(vars(args).items()): 34 | dist.print0('\t{}: {}'.format(k, v)) 35 | # define dataset / data_loader 36 | 37 | df = pd.read_csv('./coco/coco/subset.csv') 38 | all_text = list(df['caption']) 39 | all_text = all_text[: args.max_cnt] 40 | 41 | num_batches = ((len(all_text) - 1) // (args.bs * dist.get_world_size()) + 1) * dist.get_world_size() 42 | all_batches = np.array_split(np.array(all_text), num_batches) 43 | rank_batches = all_batches[dist.get_rank():: dist.get_world_size()] 44 | 45 | 46 | index_list = np.arange(len(all_text)) 47 | all_batches_index = np.array_split(index_list, num_batches) 48 | rank_batches_index = all_batches_index[dist.get_rank():: dist.get_world_size()] 49 | 50 | 51 | ##### load stable diffusion models ##### 52 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") 53 | dist.print0("default scheduler config:") 54 | dist.print0(pipe.scheduler.config) 55 | pipe = pipe.to("cuda") 56 | 57 | if args.scheduler == 'DDPM': 58 | pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) 59 | elif args.scheduler == 'DDIM': 60 | # recommend using DDIM with Restart 61 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 62 | pipe.scheduler.use_sigma = args.sigma 63 | elif args.scheduler == 'SDE': 64 | pipe.scheduler = SDEScheduler.from_config(pipe.scheduler.config) 65 | elif args.scheduler == 'ODE': 66 | pipe.scheduler = 
EulerDiscreteScheduler.from_config(pipe.scheduler.config) 67 | pipe.scheduler.use_karras_sigmas = False 68 | else: 69 | raise NotImplementedError 70 | 71 | generator = torch.Generator(device="cuda").manual_seed(args.generate_seed) 72 | 73 | ##### setup save configuration ####### 74 | if args.name is None: 75 | save_dir = os.path.join(args.save_path, 76 | f'scheduler_{args.scheduler}_steps_{args.steps}_restart_{args.restart}_w_{args.w}_second_{args.second}_seed_{args.generate_seed}_sigma_{args.sigma}') 77 | else: 78 | save_dir = os.path.join(args.save_path, 79 | f'scheduler_{args.scheduler}_steps_{args.steps}_restart_{args.restart}_w_{args.w}_second_{args.second}_seed_{args.generate_seed}_sigma_{args.sigma}_name_{args.name}') 80 | 81 | dist.print0("save images to {}".format(save_dir)) 82 | 83 | if dist.get_rank() == 0 and not os.path.exists(save_dir): 84 | os.mkdir(save_dir) 85 | 86 | ## generate images ## 87 | for cnt, mini_batch in enumerate(tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0))): 88 | torch.distributed.barrier() 89 | text = list(mini_batch) 90 | image = pipe(text, generator=generator, num_inference_steps=args.steps, guidance_scale=args.w, restart=args.restart, second_order=args.second, dist=dist, S_noise=args.s_noise).images 91 | 92 | for text_idx, global_idx in enumerate(rank_batches_index[cnt]): 93 | image[text_idx].save(os.path.join(save_dir, f'{global_idx}.png')) 94 | 95 | # Done. 96 | torch.distributed.barrier() 97 | if dist.get_rank() == 0: 98 | d = {'caption': all_text} 99 | df = pd.DataFrame(data=d) 100 | df.to_csv(os.path.join(save_dir, 'subset.csv')) -------------------------------------------------------------------------------- /diffuser/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | flax==0.6.9 3 | huggingface_hub==0.13.4 4 | importlib_metadata==4.11.3 5 | jax==0.4.8 6 | k_diffusion==0.0.14 7 | librosa==0.10.0.post2 8 | msgpack_python==0.5.6 9 | numpy==1.22.4 10 | omegaconf==2.3.0 11 | onnxruntime==1.14.1 12 | open_clip_torch==2.16.2 13 | opencv_python==4.7.0.72 14 | packaging==21.3 15 | pandas==1.4.2 16 | Pillow==9.5.0 17 | psutil==5.8.0 18 | pyspng==0.1.1 19 | pytest==7.1.1 20 | pytorch_fid==0.3.0 21 | pytorch_lightning==2.0.1.post0 22 | requests==2.27.1 23 | safetensors==0.3.0 24 | scipy==1.10.1 25 | torch==1.12.1 26 | torchvision==0.13.1 27 | tqdm==4.64.0 28 | transformers==4.28.1 29 | -------------------------------------------------------------------------------- /diffuser/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /diffuser/torch_utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/torch_utils/__pycache__/distributed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/distributed.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/torch_utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/torch_utils/__pycache__/persistence.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/persistence.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/torch_utils/__pycache__/training_stats.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newbeeer/diffusion_restart_sampling/8cbcb076330381216b2f60578bb6e381ce182683/diffuser/torch_utils/__pycache__/training_stats.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /diffuser/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /diffuser/visualization.py: -------------------------------------------------------------------------------- 1 | # make sure you're logged in with \`huggingface-cli login\` 2 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler, DDIMScheduler, DDPMScheduler, SDEScheduler, EulerDiscreteScheduler 3 | import torch 4 | import argparse 5 | import os 6 | from torchvision.utils import make_grid, save_image 7 | import numpy as np 8 | from PIL import Image 9 | 10 | parser = argparse.ArgumentParser(description='Generate images with stable diffusion') 11 | parser.add_argument('--steps', type=int, default=30, help='number of inference steps during sampling') 12 | parser.add_argument('--generate_seed', type=int, default=6) 13 | parser.add_argument('--w', type=float, default=8) 14 | parser.add_argument('--bs', type=int, default=16) 15 | parser.add_argument('--max_cnt', type=int, default=5000, help='number of maximum geneated samples') 16 | parser.add_argument('--save_path', type=str, default='./generated_images') 17 | parser.add_argument('--scheduler', type=str, default='DDIM') 18 | parser.add_argument('--restart', action='store_true', default=False) 19 | parser.add_argument('--second', action='store_true', default=False, help='second order ODE') 20 | parser.add_argument('--sigma', action='store_true', default=False, help='use sigma') 21 | parser.add_argument('--prompt', type=str, default='a photo of an astronaut riding a horse on mars') 22 | 23 | args = parser.parse_args() 24 | 25 | print('Args:') 26 | for k, v in sorted(vars(args).items()): 27 | print('\t{}: {}'.format(k, v)) 28 | 29 | os.makedirs('./vis', exist_ok=True) 30 | 31 | # prompt_list = ["a photo of an astronaut riding a horse on mars", "a raccoon playing table tennis", 32 | # "Intricate origami of a fox in a snowy forest", "A transparent sculpture of a duck made out of glass"] 33 | prompt_list = [args.prompt] 34 | 35 | 36 | for prompt_ in prompt_list: 37 | 38 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") 39 | print("default scheduler config:") 40 | print(pipe.scheduler.config) 41 | 42 | pipe = pipe.to("cuda") 43 | 44 | if args.scheduler == 'DDPM': 45 | pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) 46 | elif args.scheduler == 'DDIM': 47 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 48 | pipe.scheduler.use_sigma = args.sigma 49 | else: 50 | raise NotImplementedError 51 | 52 | prompt = [prompt_] * 16 53 | 54 | # Restart 55 | generator = torch.Generator(device="cuda").manual_seed(args.generate_seed) 56 | out, _ = pipe(prompt, generator=generator, num_inference_steps=args.steps, guidance_scale=args.w, 57 | restart=args.restart, second_order=args.second, output_type='tensor') 58 | image = out.images 59 | 60 | print(f'image saving to ./vis/{prompt_}_{args.scheduler}_restart_True_steps_{args.steps}_w_{args.w}_seed_{args.generate_seed}.png') 61 | image_grid = make_grid(torch.from_numpy(image).permute(0, 3, 1, 2), nrow=int(np.sqrt(len(image)))) 62 | save_image(image_grid, 63 | f"./vis/{prompt_}_{args.scheduler}_restart_True_steps_{args.steps}_w_{args.w}_seed_{args.generate_seed}.png") 64 | --------------------------------------------------------------------------------
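Usage sketch (illustrative only, not a file in the repository): the call below mirrors how generate.py and visualization.py invoke the patched StableDiffusionPipeline, with the default values taken from those scripts' argparse flags. The `restart` and `second_order` keyword arguments are additions made by this repository's modified pipeline rather than stock diffusers, and the output filename is a hypothetical example.

# Minimal sketch of calling the patched StableDiffusionPipeline with Restart enabled,
# following the call pattern used in generate.py / visualization.py.
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # DDIM is the scheduler recommended with Restart in generate.py

generator = torch.Generator(device="cuda").manual_seed(6)           # --generate_seed default
images = pipe(
    ["a photo of an astronaut riding a horse on mars"],             # --prompt default in visualization.py
    generator=generator,
    num_inference_steps=30,                                         # --steps default
    guidance_scale=7.5,                                             # --w default in generate.py
    restart=True,                                                   # repository-specific kwarg: enable Restart sampling
    second_order=True,                                              # repository-specific kwarg: second-order ODE steps
).images
images[0].save("restart_example.png")                               # hypothetical output path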