├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── example.py ├── run.py ├── run_safety_checker.py ├── setup.py └── stable_diffusion_jax ├── __init__.py ├── configuration_unet2d.py ├── configuration_vae.py ├── convert_diffusers_to_jax.py ├── modeling_unet2d.py ├── modeling_vae.py ├── pipeline_stable_diffusion.py ├── safety_checker.py └── scheduling_pndm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # tests and logs 12 | tests/fixtures/cached_*_text.txt 13 | logs/ 14 | lightning_logs/ 15 | lang_code_data/ 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # vscode 125 | .vs 126 | .vscode 127 | 128 | # Pycharm 129 | .idea 130 | 131 | # TF code 132 | tensorflow_code 133 | 134 | # Models 135 | proc_data 136 | 137 | # data 138 | /data 139 | serialization_dir 140 | 141 | # emacs 142 | *.*~ 143 | debug.env 144 | 145 | # vim 146 | .*.swp 147 | 148 | #ctags 149 | tags 150 | 151 | # pre-commit 152 | .pre-commit* 153 | 154 | # .lock 155 | *.lock -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Suraj Patil 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to 
do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := stable_diffusion_jax 2 | 3 | quality: 4 | black -l 119 --check --preview $(check_dirs) 5 | isort --check-only $(check_dirs) 6 | flake8 $(check_dirs) 7 | 8 | style: 9 | black --preview -l 119 $(check_dirs) 10 | isort $(check_dirs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TODOs: 2 | 3 | - [x] Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port weights of any existing LDM unet from diffusers and verify equivalence. I've added the skeleton of modules that we need to implement in the file. 4 | - [x] Adapt the `PNDMScheduler` from `diffusers` for JAX: use `jnp` arrays and make it stateless (see the sketch after this list). 5 | - [x] Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) in the `modeling_vae.py` file. For inference we don't really need it, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence. 6 | - [x] Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
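A minimal sketch of the "stateless" pattern that the scheduler TODO refers to: everything the scheduler would normally mutate is carried in an explicit pytree that the step function takes and returns, so the whole sampling loop can sit inside `jit`/`pmap`. The names (`SchedulerState`, `init_state`, `step`) and the toy update rule below are illustrative only, not the actual `PNDMScheduler` API or math.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    # everything the scheduler would otherwise mutate lives in this pytree
    timesteps: jnp.ndarray  # discretized diffusion timesteps, high to low
    counter: jnp.ndarray    # current position in `timesteps`


def init_state(num_inference_steps: int, num_train_timesteps: int = 1000) -> SchedulerState:
    stride = num_train_timesteps // num_inference_steps
    timesteps = (jnp.arange(num_inference_steps) * stride)[::-1]
    return SchedulerState(timesteps=timesteps, counter=jnp.array(0))


@jax.jit
def step(state: SchedulerState, sample: jnp.ndarray, model_output: jnp.ndarray):
    # toy Euler-style update standing in for the real PNDM update rule
    prev_sample = sample - 0.1 * model_output
    # return an updated state instead of mutating anything in place
    return prev_sample, state._replace(counter=state.counter + 1)
```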
7 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from flax.jax_utils import replicate 5 | from flax.training.common_utils import shard 6 | from PIL import Image 7 | from transformers import CLIPTokenizer, FlaxCLIPTextModel, CLIPConfig 8 | 9 | from stable_diffusion_jax import ( 10 | AutoencoderKL, 11 | InferenceState, 12 | PNDMScheduler, 13 | StableDiffusionPipeline, 14 | UNet2D, 15 | StableDiffusionSafetyCheckerModel, 16 | ) 17 | from stable_diffusion_jax.convert_diffusers_to_jax import convert_diffusers_to_jax 18 | 19 | 20 | # convert diffusers checkpoint to jax 21 | pt_path = "path_to_diffusers_pt_ckpt" 22 | fx_path = "save_path" 23 | convert_diffusers_to_jax(pt_path, fx_path) 24 | 25 | 26 | # inference with jax 27 | dtype = jnp.bfloat16 28 | clip_model, clip_params = FlaxCLIPTextModel.from_pretrained( 29 | "openai/clip-vit-large-patch14", _do_init=False, dtype=dtype 30 | ) 31 | unet, unet_params = UNet2D.from_pretrained(f"{fx_path}/unet", _do_init=False, dtype=dtype) 32 | vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype) 33 | safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{fx_path}/safety_model", _do_init=False, dtype=dtype) 34 | 35 | config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14") 36 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 37 | scheduler = PNDMScheduler() 38 | 39 | # create inference state and replicate it across all TPU devices 40 | inference_state = InferenceState(text_encoder_params=clip_params, unet_params=unet_params, vae_params=vae_params) 41 | inference_state = replicate(inference_state) 42 | 43 | 44 | # create pipeline 45 | pipe = StableDiffusionPipeline(text_encoder=clip_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vae=vae) 46 | 47 | 48 | 49 | # prepare inputs 50 | num_samples = 8 51 | p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic" 52 | 53 | input_ids = tokenizer( 54 | [p] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 55 | ).input_ids 56 | uncond_input_ids = tokenizer( 57 | [""] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 58 | ).input_ids 59 | prng_seed = jax.random.PRNGKey(42) 60 | 61 | # shard inputs and rng 62 | input_ids = shard(input_ids) 63 | uncond_input_ids = shard(uncond_input_ids) 64 | prng_seed = jax.random.split(prng_seed, 8) 65 | 66 | # pmap the sample function 67 | num_inference_steps = 50 68 | guidance_scale = 1.0 69 | 70 | sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4, 5)) 71 | 72 | # sample images 73 | images = sample( 74 | input_ids, 75 | uncond_input_ids, 76 | prng_seed, 77 | inference_state, 78 | num_inference_steps, 79 | guidance_scale, 80 | ) 81 | 82 | 83 | # convert images to PIL images 84 | images = images / 2 + 0.5 85 | images = jnp.clip(images, 0, 1) 86 | images = (images * 255).round().astype("uint8") 87 | images = np.asarray(images).reshape((num_samples, 512, 512, 3)) 88 | 89 | pil_images = [Image.fromarray(image) for image in images] 90 | -------------------------------------------------------------------------------- /run.py:
-------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from flax.jax_utils import replicate 5 | from flax.training.common_utils import shard 6 | from PIL import Image 7 | from transformers import CLIPTokenizer, FlaxCLIPTextModel, AutoFeatureExtractor 8 | import warnings 9 | 10 | from stable_diffusion_jax import ( 11 | AutoencoderKL, 12 | InferenceState, 13 | PNDMScheduler, 14 | StableDiffusionPipeline, 15 | UNet2D, 16 | StableDiffusionSafetyCheckerModel 17 | ) 18 | from stable_diffusion_jax.convert_diffusers_to_jax import convert_diffusers_to_jax 19 | 20 | 21 | def image_grid(imgs, rows, cols): 22 | assert len(imgs) == rows*cols 23 | 24 | w, h = imgs[0].size 25 | grid = Image.new('RGB', size=(cols*w, rows*h)) 26 | grid_w, grid_h = grid.size 27 | 28 | for i, img in enumerate(imgs): 29 | grid.paste(img, box=(i%cols*w, i//cols*h)) 30 | return grid 31 | 32 | 33 | # convert diffusers checkpoint to jax 34 | #pt_path = "/home/patrick/stable-diffusion-v1-3" 35 | #fx_path = pt_path + "_jax" 36 | #convert_diffusers_to_jax(pt_path, fx_path) 37 | 38 | fx_path = "/home/patrick_huggingface_co/sd-v1-4-flax" 39 | 40 | # inference with jax 41 | dtype = jnp.bfloat16 42 | clip_model, clip_params = FlaxCLIPTextModel.from_pretrained( 43 | "openai/clip-vit-large-patch14", _do_init=False, dtype=dtype 44 | ) 45 | unet, unet_params = UNet2D.from_pretrained(f"{fx_path}/unet", _do_init=False, dtype=dtype) 46 | vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype) 47 | safety_checker, safety_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{fx_path}/safety_checker", _do_init=False) 48 | scheduler = PNDMScheduler.from_config(f"{fx_path}/scheduler", use_auth_token=True) 49 | feature_extractor = AutoFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") 50 | 51 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 52 | 53 | # create inference state and replicate it across all TPU devices 54 | inference_state = InferenceState(text_encoder_params=clip_params, unet_params=unet_params, vae_params=vae_params) 55 | inference_state = replicate(inference_state) 56 | 57 | 58 | # create pipeline 59 | pipe = StableDiffusionPipeline(text_encoder=clip_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vae=vae) 60 | 61 | 62 | # prepare inputs 63 | num_samples = 8 64 | p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic" 65 | p = "a photograph of an astronaut riding a horse" 66 | 67 | input_ids = tokenizer( 68 | [p] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 69 | ).input_ids 70 | uncond_input_ids = tokenizer( 71 | [""] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 72 | ).input_ids 73 | prng_seed = jax.random.PRNGKey(1) 74 | 75 | # shard inputs and rng 76 | input_ids = shard(input_ids) 77 | uncond_input_ids = shard(uncond_input_ids) 78 | prng_seed = jax.random.split(prng_seed, 8) 79 | 80 | 81 | # pmap the sample function 82 | num_inference_steps = 50 83 | guidance_scale = 7.5 84 | 85 | sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4, 5)) 86 | 87 | # sample images 88 | images = sample( 89 | input_ids, 90 | uncond_input_ids, 91 | prng_seed, 92 | inference_state, 93 | num_inference_steps, 94 | guidance_scale, 95 | ) 96 | 97 | images = images / 2 + 
0.5 98 | images = jnp.clip(images, 0, 1) 99 | images = (images * 255).round().astype("uint8") 100 | images = np.asarray(images).reshape((num_samples, 512, 512, 3)) 101 | 102 | pil_images = [Image.fromarray(image) for image in images] 103 | 104 | # run safety checker 105 | safety_checker_input = feature_extractor(pil_images, return_tensors="np") 106 | 107 | images, has_nsfw_concept = safety_checker(safety_checker_input.pixel_values, params=safety_params, images=images) 108 | 109 | pil_images = [Image.fromarray(image) for image in images] 110 | 111 | grid = image_grid(pil_images, rows=1, cols=num_samples) 112 | grid.save(f"/home/patrick_huggingface_co/images/image_{p}.png") 113 | -------------------------------------------------------------------------------- /run_safety_checker.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from flax.jax_utils import replicate 5 | from flax.training.common_utils import shard 6 | from PIL import Image 7 | from flax.traverse_util import flatten_dict, unflatten_dict 8 | import warnings 9 | import torch 10 | from transformers import CLIPTokenizer, FlaxCLIPTextModel, CLIPConfig, CLIPFeatureExtractor 11 | from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax 12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionSafetyChecker 13 | 14 | from stable_diffusion_jax import ( 15 | AutoencoderKL, 16 | InferenceState, 17 | PNDMScheduler, 18 | StableDiffusionPipeline, 19 | UNet2D, 20 | StableDiffusionSafetyCheckerModel, 21 | ) 22 | from stable_diffusion_jax.convert_diffusers_to_jax import convert_diffusers_to_jax 23 | 24 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") 25 | 26 | shape = (2, 224, 224, 3) 27 | images = np.random.rand(*shape) 28 | images = (images * 255).round().astype("uint8") 29 | pil_images = [Image.fromarray(image) for image in images] 30 | safety_checker_input_np = feature_extractor(pil_images, return_tensors="np").pixel_values 31 | safety_checker_input_pt = feature_extractor(pil_images, return_tensors="pt").pixel_values 32 | 33 | 34 | model = StableDiffusionSafetyCheckerModel.from_pretrained("/home/patrick/sd-v1-4-flax/safety_checker") 35 | pt_model = StableDiffusionSafetyChecker.from_pretrained("/home/patrick/stable-diffusion-v1-1/safety_checker") 36 | 37 | result = model(safety_checker_input_np) 38 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 39 | for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): 40 | if has_nsfw_concept: 41 | images[idx] = np.zeros(images[idx].shape) # black image 42 | 43 | if any(has_nsfw_concepts): 44 | warnings.warn( 45 | "Potential NSFW content was detected in one or more images. A black image will be returned instead." 46 | " Try again with a different prompt and/or seed."
47 | ) 48 | 49 | 50 | import ipdb; ipdb.set_trace() 51 | 52 | # inference with jax 53 | dtype = jnp.bfloat16 54 | clip_model, clip_params = FlaxCLIPTextModel.from_pretrained( 55 | "openai/clip-vit-large-patch14", _do_init=False, dtype=dtype 56 | ) 57 | unet, unet_params = UNet2D.from_pretrained(f"{fx_path}/unet", _do_init=False, dtype=dtype) 58 | vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype) 59 | 60 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 61 | 62 | scheduler = PNDMScheduler() 63 | 64 | 65 | 66 | # create inference state and replicate it across all TPU devices 67 | inference_state = InferenceState(text_encoder_params=clip_params, unet_params=unet_params, vae_params=vae_params) 68 | inference_state = replicate(inference_state) 69 | 70 | 71 | # create pipeline 72 | pipe = StableDiffusionPipeline(text_encoder=clip_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vae=vae) 73 | 74 | 75 | 76 | # prepare inputs 77 | num_samples = 8 78 | p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic" 79 | 80 | input_ids = tokenizer( 81 | [p] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 82 | ).input_ids 83 | uncond_input_ids = tokenizer( 84 | [""] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax" 85 | ).input_ids 86 | prng_seed = jax.random.PRNGKey(42) 87 | 88 | # shard inputs and rng 89 | input_ids = shard(input_ids) 90 | uncond_input_ids = shard(uncond_input_ids) 91 | prng_seed = jax.random.split(prng_seed, 8) 92 | 93 | 94 | # pmap the sample function 95 | num_inference_steps = 50 96 | guidance_scale = 1.0 97 | 98 | sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4, 5)) 99 | 100 | # sample images 101 | images = sample( 102 | input_ids, 103 | uncond_input_ids, 104 | prng_seed, 105 | inference_state, 106 | num_inference_steps, 107 | guidance_scale, 108 | ) 109 | 110 | 111 | # convert images to PIL images 112 | images = images / 2 + 0.5 113 | images = jnp.clip(images, 0, 1) 114 | images = (images * 255).round().astype("uint8") 115 | images = np.asarray(images).reshape((num_samples, 512, 512, 3)) 116 | 117 | pil_images = [Image.fromarray(image) for image in images] 118 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | # To use a consistent encoding 3 | from codecs import open 4 | import os 5 | 6 | here = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | with open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 9 | long_description = f.read() 10 | 11 | setuptools.setup( 12 | name='stable-diffusion-jax', 13 | version='0.0.1', 14 | description='JAX implementation of Stable diffusion', 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | packages=setuptools.find_packages(), 18 | install_requires=['jax>=0.2.6', 'flax', 'transformers'], 19 | ) 20 | -------------------------------------------------------------------------------- /stable_diffusion_jax/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_unet2d import UNet2DConfig 2 | from .configuration_vae import VAEConfig 3 | from .modeling_unet2d import UNet2D 4 | from .modeling_vae import AutoencoderKL 5 | from 
.pipeline_stable_diffusion import InferenceState, StableDiffusionPipeline 6 | from .scheduling_pndm import PNDMScheduler 7 | from .safety_checker import StableDiffusionSafetyCheckerModel 8 | -------------------------------------------------------------------------------- /stable_diffusion_jax/configuration_unet2d.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class UNet2DConfig(PretrainedConfig): 5 | def __init__( 6 | self, 7 | sample_size=32, 8 | in_channels=4, 9 | out_channels=4, 10 | down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), 11 | up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 12 | block_out_channels=(224, 448, 672, 896), 13 | layers_per_block=2, 14 | attention_head_dim=8, 15 | cross_attention_dim=768, 16 | dropout=0.1, 17 | **kwargs, 18 | ): 19 | super().__init__(**kwargs) 20 | self.sample_size = sample_size 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.down_block_types = down_block_types 24 | self.up_block_types = up_block_types 25 | self.block_out_channels = block_out_channels 26 | self.layers_per_block = layers_per_block 27 | self.attention_head_dim = attention_head_dim 28 | self.cross_attention_dim = cross_attention_dim 29 | self.dropout = dropout 30 | -------------------------------------------------------------------------------- /stable_diffusion_jax/configuration_vae.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class VAEConfig(PretrainedConfig): 5 | def __init__( 6 | self, 7 | in_channels=3, 8 | out_channels=3, 9 | down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), 10 | up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), 11 | block_out_channels=(128, 256, 512, 512), 12 | layers_per_block=2, 13 | act_fn="silu", 14 | latent_channels=4, 15 | sample_size=512, 16 | double_z=True, 17 | **kwargs, 18 | ): 19 | super().__init__(**kwargs) 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.down_block_types = down_block_types 23 | self.up_block_types = up_block_types 24 | self.block_out_channels = block_out_channels 25 | self.layers_per_block = layers_per_block 26 | self.act_fn = act_fn 27 | self.latent_channels = latent_channels 28 | self.sample_size = sample_size 29 | self.double_z = double_z 30 | -------------------------------------------------------------------------------- /stable_diffusion_jax/convert_diffusers_to_jax.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | 4 | import jax.numpy as jnp 5 | from diffusers import AutoencoderKL as Autoencoder 6 | from diffusers import UNet2DConditionModel 7 | from flax.traverse_util import flatten_dict, unflatten_dict 8 | 9 | from . 
import AutoencoderKL, UNet2D, UNet2DConfig, VAEConfig 10 | 11 | regex = r"\w+[.]\d+" 12 | 13 | 14 | def rename_key(key): 15 | key = key.replace("downsamplers.0", "downsample") 16 | key = key.replace("upsamplers.0", "upsample") 17 | key = key.replace("net.0.proj", "dense1") 18 | key = key.replace("net.2", "dense2") 19 | key = key.replace("to_out.0", "to_out") 20 | key = key.replace("attn1", "self_attn") 21 | key = key.replace("attn2", "cross_attn") 22 | 23 | pats = re.findall(regex, key) 24 | for pat in pats: 25 | key = key.replace(pat, "_".join(pat.split("."))) 26 | return key 27 | 28 | 29 | # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61 30 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): 31 | # convert pytorch tensor to numpy 32 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} 33 | 34 | random_flax_state_dict = flatten_dict(flax_model.params_shape_tree) 35 | flax_state_dict = {} 36 | 37 | remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params_shape_tree) and ( 38 | flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) 39 | ) 40 | add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params_shape_tree) and ( 41 | flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) 42 | ) 43 | 44 | # Need to change some parameters name to match Flax names so that we don't have to fork any layer 45 | for pt_key, pt_tensor in pt_state_dict.items(): 46 | pt_tuple_key = tuple(pt_key.split(".")) 47 | 48 | has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix 49 | require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict 50 | 51 | if remove_base_model_prefix and has_base_model_prefix: 52 | pt_tuple_key = pt_tuple_key[1:] 53 | elif add_base_model_prefix and require_base_model_prefix: 54 | pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key 55 | 56 | # Correctly rename weight parameters 57 | if ( 58 | "norm" in pt_key 59 | and (pt_tuple_key[-1] == "bias") 60 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) 61 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) 62 | ): 63 | pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 64 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: 65 | pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 66 | if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: 67 | pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) 68 | elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict: 69 | # conv layer 70 | pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 71 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0) 72 | elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict: 73 | # linear layer 74 | pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 75 | pt_tensor = pt_tensor.T 76 | elif pt_tuple_key[-1] == "gamma": 77 | pt_tuple_key = pt_tuple_key[:-1] + ("weight",) 78 | elif pt_tuple_key[-1] == "beta": 79 | pt_tuple_key = pt_tuple_key[:-1] + ("bias",) 80 | 81 | if pt_tuple_key in random_flax_state_dict: 82 | if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape: 83 | raise ValueError( 84 | f"PyTorch checkpoint seems to be incorrect. 
Weight {pt_key} was expected to be of shape " 85 | f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}." 86 | ) 87 | 88 | # also add unexpected weight so that warning is thrown 89 | flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor) 90 | 91 | return unflatten_dict(flax_state_dict) 92 | 93 | 94 | def convert_params(pt_model, fx_model): 95 | state_dict = pt_model.state_dict() 96 | keys = list(state_dict.keys()) 97 | for key in keys: 98 | renamed_key = rename_key(key) 99 | state_dict[renamed_key] = state_dict.pop(key) 100 | 101 | fx_params = convert_pytorch_state_dict_to_flax(state_dict, fx_model) 102 | return fx_params 103 | 104 | 105 | def convert_diffusers_to_jax(pt_model_path, save_path): 106 | unet_pt = UNet2DConditionModel.from_pretrained(pt_model_path, subfolder="unet", use_auth_token=True) 107 | 108 | # create UNet flax config and model 109 | config = UNet2DConfig( 110 | sample_size=unet_pt.config.sample_size, 111 | in_channels=unet_pt.config.in_channels, 112 | out_channels=unet_pt.config.out_channels, 113 | down_block_types=unet_pt.config.down_block_types, 114 | up_block_types=unet_pt.config.up_block_types, 115 | block_out_channels=unet_pt.config.block_out_channels, 116 | layers_per_block=unet_pt.config.layers_per_block, 117 | attention_head_dim=unet_pt.config.attention_head_dim, 118 | cross_attention_dim=unet_pt.config.cross_attention_dim, 119 | ) 120 | unet_fx = UNet2D(config, _do_init=False) 121 | 122 | # convert unet pt params to jax 123 | params = convert_params(unet_pt, unet_fx) 124 | # save unet 125 | unet_fx.save_pretrained(f"{save_path}/unet", params=params) 126 | 127 | vae_pt = Autoencoder.from_pretrained(pt_model_path, subfolder="vae", use_auth_token=True) 128 | 129 | # create AutoEncoder flax config and model 130 | config = VAEConfig( 131 | sample_size=vae_pt.config.sample_size, 132 | in_channels=vae_pt.config.in_channels, 133 | out_channels=vae_pt.config.out_channels, 134 | down_block_types=vae_pt.config.down_block_types, 135 | up_block_types=vae_pt.config.up_block_types, 136 | block_out_channels=vae_pt.config.block_out_channels, 137 | layers_per_block=vae_pt.config.layers_per_block, 138 | latent_channels=vae_pt.config.latent_channels, 139 | ) 140 | vae_fx = AutoencoderKL(config, _do_init=False) 141 | 142 | # convert vae pt params to jax 143 | params = convert_params(vae_pt, vae_fx) 144 | # save vae 145 | vae_fx.save_pretrained(f"{save_path}/vae", params=params) 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--pt_model_path", type=str, required=True) 151 | parser.add_argument("--save_path", type=str, required=True) 152 | args = parser.parse_args() 153 | 154 | convert_diffusers_to_jax(args.pt_model_path, args.save_path) 155 | -------------------------------------------------------------------------------- /stable_diffusion_jax/modeling_unet2d.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | from flax.core.frozen_dict import FrozenDict 8 | from transformers.modeling_flax_utils import FlaxPreTrainedModel 9 | 10 | from .configuration_unet2d import UNet2DConfig 11 | 12 | 13 | def get_sinusoidal_embeddings(timesteps, embedding_dim): 14 | half_dim = embedding_dim // 2 15 | emb = -math.log(10000) * jnp.arange(half_dim, dtype=jnp.float32) 16 | emb = emb / half_dim 17 | emb = jnp.exp(emb) 18 | 19 | emb = timesteps[:, None] * emb[None, 
:] 20 | emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) 21 | return emb 22 | 23 | 24 | ############################################################################################## 25 | # Transformer Blocks 26 | ############################################################################################## 27 | 28 | 29 | class Attention(nn.Module): 30 | query_dim: int 31 | heads: int = 8 32 | dim_head: int = 64 33 | dropout: float = 0.0 34 | dtype: jnp.dtype = jnp.float32 35 | 36 | def setup(self): 37 | inner_dim = self.dim_head * self.heads 38 | self.scale = self.dim_head**-0.5 39 | 40 | self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) 41 | self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) 42 | self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) 43 | 44 | self.to_out = nn.Dense(self.query_dim, dtype=self.dtype) 45 | 46 | def reshape_heads_to_batch_dim(self, tensor): 47 | batch_size, seq_len, dim = tensor.shape 48 | head_size = self.heads 49 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 50 | tensor = jnp.transpose(tensor, (0, 2, 1, 3)) 51 | tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) 52 | return tensor 53 | 54 | def reshape_batch_dim_to_heads(self, tensor): 55 | batch_size, seq_len, dim = tensor.shape 56 | head_size = self.heads 57 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 58 | tensor = jnp.transpose(tensor, (0, 2, 1, 3)) 59 | tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) 60 | return tensor 61 | 62 | def __call__(self, hidden_states, context=None, deterministic=True): 63 | context = hidden_states if context is None else context 64 | 65 | q = self.to_q(hidden_states) 66 | k = self.to_k(context) 67 | v = self.to_v(context) 68 | 69 | q = self.reshape_heads_to_batch_dim(q) 70 | k = self.reshape_heads_to_batch_dim(k) 71 | v = self.reshape_heads_to_batch_dim(v) 72 | 73 | # compute attentions 74 | attn_weights = jnp.einsum("b i d, b j d->b i j", q, k) 75 | attn_weights = attn_weights * self.scale 76 | attn_weights = nn.softmax(attn_weights, axis=2) 77 | 78 | ## attend to values 79 | hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v) 80 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 81 | hidden_states = self.to_out(hidden_states) 82 | return hidden_states 83 | 84 | 85 | class GluFeedForward(nn.Module): 86 | dim: int 87 | dropout: float = 0.0 88 | dtype: jnp.dtype = jnp.float32 89 | 90 | def setup(self): 91 | inner_dim = self.dim * 4 92 | self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) 93 | self.dense2 = nn.Dense(self.dim, dtype=self.dtype) 94 | 95 | def __call__(self, hidden_states, deterministic=True): 96 | hidden_states = self.dense1(hidden_states) 97 | hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) 98 | hidden_states = hidden_linear * nn.gelu(hidden_gelu) 99 | hidden_states = self.dense2(hidden_states) 100 | return hidden_states 101 | 102 | 103 | class TransformerBlock(nn.Module): 104 | dim: int 105 | n_heads: int 106 | d_head: int 107 | dropout: float = 0.0 108 | dtype: jnp.dtype = jnp.float32 109 | 110 | def setup(self): 111 | # self attention 112 | self.self_attn = Attention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) 113 | # cross attention 114 | self.cross_attn = Attention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) 115 | self.ff = GluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) 116 
| self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) 117 | self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) 118 | self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) 119 | 120 | def __call__(self, hidden_states, context, deterministic=True): 121 | # self attention 122 | residual = hidden_states 123 | hidden_states = self.self_attn(self.norm1(hidden_states)) 124 | hidden_states = hidden_states + residual 125 | 126 | # cross attention 127 | residual = hidden_states 128 | hidden_states = self.cross_attn(self.norm2(hidden_states), context) 129 | hidden_states = hidden_states + residual 130 | 131 | # feed forward 132 | residual = hidden_states 133 | hidden_states = self.ff(self.norm3(hidden_states)) 134 | hidden_states = hidden_states + residual 135 | 136 | return hidden_states 137 | 138 | 139 | class SpatialTransformer(nn.Module): 140 | in_channels: int 141 | n_heads: int 142 | d_head: int 143 | depth: int = 1 144 | dropout: float = 0.0 145 | dtype: jnp.dtype = jnp.float32 146 | 147 | def setup(self): 148 | self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) 149 | 150 | inner_dim = self.n_heads * self.d_head 151 | self.proj_in = nn.Conv( 152 | inner_dim, 153 | kernel_size=(1, 1), 154 | strides=(1, 1), 155 | padding="VALID", 156 | dtype=self.dtype, 157 | ) 158 | 159 | self.transformer_blocks = [ 160 | TransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) 161 | for _ in range(self.depth) 162 | ] 163 | 164 | self.proj_out = nn.Conv( 165 | inner_dim, 166 | kernel_size=(1, 1), 167 | strides=(1, 1), 168 | padding="VALID", 169 | dtype=self.dtype, 170 | ) 171 | 172 | def __call__(self, hidden_states, context, deterministic=True): 173 | batch, height, width, channels = hidden_states.shape 174 | # import ipdb; ipdb.set_trace() 175 | residual = hidden_states 176 | hidden_states = self.norm(hidden_states) 177 | hidden_states = self.proj_in(hidden_states) 178 | 179 | # hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 1)) 180 | hidden_states = hidden_states.reshape(batch, height * width, channels) 181 | 182 | for transformer_block in self.transformer_blocks: 183 | hidden_states = transformer_block(hidden_states, context) 184 | 185 | hidden_states = hidden_states.reshape(batch, height, width, channels) 186 | # hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) 187 | 188 | hidden_states = self.proj_out(hidden_states) 189 | hidden_states = hidden_states + residual 190 | 191 | return hidden_states 192 | 193 | 194 | ############################################################################################## 195 | # UNet Blocks 196 | ############################################################################################## 197 | 198 | 199 | class Timesteps(nn.Module): 200 | dim: int = 32 201 | 202 | @nn.compact 203 | def __call__(self, timesteps): 204 | return get_sinusoidal_embeddings(timesteps, self.dim) 205 | 206 | 207 | class TimestepEmbedding(nn.Module): 208 | time_embed_dim: int = 32 209 | dtype: jnp.dtype = jnp.float32 210 | 211 | @nn.compact 212 | def __call__(self, temb): 213 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) 214 | temb = nn.silu(temb) 215 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) 216 | return temb 217 | 218 | 219 | class Upsample(nn.Module): 220 | out_channels: int 221 | dtype: jnp.dtype = jnp.float32 222 | 223 | def setup(self): 224 | self.conv = nn.Conv( 225 | self.out_channels, 226 | kernel_size=(3, 3), 227 | strides=(1, 1), 228 | 
padding=((1, 1), (1, 1)), 229 | dtype=self.dtype, 230 | ) 231 | 232 | def __call__(self, hidden_states): 233 | batch, height, width, channels = hidden_states.shape 234 | hidden_states = jax.image.resize( 235 | hidden_states, 236 | shape=(batch, height * 2, width * 2, channels), 237 | method="nearest", 238 | ) 239 | hidden_states = self.conv(hidden_states) 240 | return hidden_states 241 | 242 | 243 | class Downsample(nn.Module): 244 | out_channels: int 245 | dtype: jnp.dtype = jnp.float32 246 | 247 | def setup(self): 248 | self.conv = nn.Conv( 249 | self.out_channels, 250 | kernel_size=(3, 3), 251 | strides=(2, 2), 252 | padding=((1, 1), (1, 1)), # padding="VALID", 253 | dtype=self.dtype, 254 | ) 255 | 256 | def __call__(self, hidden_states): 257 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim 258 | # hidden_states = jnp.pad(hidden_states, pad_width=pad) 259 | hidden_states = self.conv(hidden_states) 260 | return hidden_states 261 | 262 | 263 | class ResnetBlock(nn.Module): 264 | in_channels: int 265 | out_channels: int = None 266 | dropout_prob: float = 0.0 267 | use_nin_shortcut: bool = None 268 | dtype: jnp.dtype = jnp.float32 269 | 270 | def setup(self): 271 | out_channels = self.in_channels if self.out_channels is None else self.out_channels 272 | 273 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 274 | self.conv1 = nn.Conv( 275 | out_channels, 276 | kernel_size=(3, 3), 277 | strides=(1, 1), 278 | padding=((1, 1), (1, 1)), 279 | dtype=self.dtype, 280 | ) 281 | 282 | self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) 283 | 284 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 285 | self.dropout = nn.Dropout(self.dropout_prob) 286 | self.conv2 = nn.Conv( 287 | out_channels, 288 | kernel_size=(3, 3), 289 | strides=(1, 1), 290 | padding=((1, 1), (1, 1)), 291 | dtype=self.dtype, 292 | ) 293 | 294 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut 295 | 296 | self.conv_shortcut = None 297 | if use_nin_shortcut: 298 | self.conv_shortcut = nn.Conv( 299 | out_channels, 300 | kernel_size=(1, 1), 301 | strides=(1, 1), 302 | padding="VALID", 303 | dtype=self.dtype, 304 | ) 305 | 306 | def __call__(self, hidden_states, temb, deterministic=True): 307 | residual = hidden_states 308 | hidden_states = self.norm1(hidden_states) 309 | hidden_states = nn.swish(hidden_states) 310 | hidden_states = self.conv1(hidden_states) 311 | 312 | temb = self.time_emb_proj(nn.swish(temb)) 313 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) 314 | hidden_states = hidden_states + temb 315 | 316 | hidden_states = self.norm2(hidden_states) 317 | hidden_states = nn.swish(hidden_states) 318 | hidden_states = self.dropout(hidden_states, deterministic) 319 | hidden_states = self.conv2(hidden_states) 320 | 321 | if self.conv_shortcut is not None: 322 | residual = self.conv_shortcut(residual) 323 | 324 | return hidden_states + residual 325 | 326 | 327 | class CrossAttnDownBlock2D(nn.Module): 328 | in_channels: int 329 | out_channels: int 330 | dropout: float = 0.0 331 | num_layers: int = 1 332 | attn_num_head_channels: int = 1 333 | add_downsample: bool = True 334 | dtype: jnp.dtype = jnp.float32 335 | 336 | def setup(self): 337 | resnets = [] 338 | attentions = [] 339 | 340 | for i in range(self.num_layers): 341 | in_channels = self.in_channels if i == 0 else self.out_channels 342 | 343 | res_block = ResnetBlock( 344 | in_channels=in_channels, 345 | out_channels=self.out_channels, 346 | 
dropout_prob=self.dropout, 347 | dtype=self.dtype, 348 | ) 349 | resnets.append(res_block) 350 | 351 | attn_block = SpatialTransformer( 352 | in_channels=self.out_channels, 353 | n_heads=self.attn_num_head_channels, 354 | d_head=self.out_channels // self.attn_num_head_channels, 355 | depth=1, 356 | dtype=self.dtype, 357 | ) 358 | attentions.append(attn_block) 359 | 360 | self.resnets = resnets 361 | self.attentions = attentions 362 | 363 | if self.add_downsample: 364 | self.downsample = Downsample(self.out_channels, dtype=self.dtype) 365 | 366 | def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): 367 | output_states = () 368 | 369 | for resnet, attn in zip(self.resnets, self.attentions): 370 | hidden_states = resnet(hidden_states, temb) 371 | hidden_states = attn(hidden_states, encoder_hidden_states) 372 | output_states += (hidden_states,) 373 | 374 | if self.add_downsample: 375 | hidden_states = self.downsample(hidden_states) 376 | output_states += (hidden_states,) 377 | 378 | return hidden_states, output_states 379 | 380 | 381 | class DownBlock2D(nn.Module): 382 | in_channels: int 383 | out_channels: int 384 | dropout: float = 0.0 385 | num_layers: int = 1 386 | add_downsample: bool = True 387 | dtype: jnp.dtype = jnp.float32 388 | 389 | def setup(self): 390 | resnets = [] 391 | 392 | for i in range(self.num_layers): 393 | in_channels = self.in_channels if i == 0 else self.out_channels 394 | 395 | res_block = ResnetBlock( 396 | in_channels=in_channels, 397 | out_channels=self.out_channels, 398 | dropout_prob=self.dropout, 399 | dtype=self.dtype, 400 | ) 401 | resnets.append(res_block) 402 | self.resnets = resnets 403 | 404 | if self.add_downsample: 405 | self.downsample = Downsample(self.out_channels, dtype=self.dtype) 406 | 407 | def __call__(self, hidden_states, temb, deterministic=True): 408 | output_states = () 409 | 410 | for resnet in self.resnets: 411 | hidden_states = resnet(hidden_states, temb) 412 | output_states += (hidden_states,) 413 | 414 | if self.add_downsample: 415 | hidden_states = self.downsample(hidden_states) 416 | output_states += (hidden_states,) 417 | 418 | return hidden_states, output_states 419 | 420 | 421 | class CrossAttnUpBlock2D(nn.Module): 422 | in_channels: int 423 | out_channels: int 424 | prev_output_channel: int 425 | dropout: float = 0.0 426 | num_layers: int = 1 427 | attn_num_head_channels: int = 1 428 | add_upsample: bool = True 429 | dtype: jnp.dtype = jnp.float32 430 | 431 | def setup(self): 432 | resnets = [] 433 | attentions = [] 434 | 435 | for i in range(self.num_layers): 436 | res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels 437 | resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels 438 | 439 | res_block = ResnetBlock( 440 | in_channels=resnet_in_channels + res_skip_channels, 441 | out_channels=self.out_channels, 442 | dropout_prob=self.dropout, 443 | dtype=self.dtype, 444 | ) 445 | resnets.append(res_block) 446 | 447 | attn_block = SpatialTransformer( 448 | in_channels=self.out_channels, 449 | n_heads=self.attn_num_head_channels, 450 | d_head=self.out_channels // self.attn_num_head_channels, 451 | depth=1, 452 | dtype=self.dtype, 453 | ) 454 | attentions.append(attn_block) 455 | 456 | self.resnets = resnets 457 | self.attentions = attentions 458 | 459 | if self.add_upsample: 460 | self.upsample = Upsample(self.out_channels, dtype=self.dtype) 461 | 462 | def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, 
deterministic=True): 463 | 464 | for resnet, attn in zip(self.resnets, self.attentions): 465 | # pop res hidden states 466 | res_hidden_states = res_hidden_states_tuple[-1] 467 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 468 | hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) 469 | 470 | hidden_states = resnet(hidden_states, temb) 471 | hidden_states = attn(hidden_states, encoder_hidden_states) 472 | 473 | if self.add_upsample: 474 | hidden_states = self.upsample(hidden_states) 475 | 476 | return hidden_states 477 | 478 | 479 | class UpBlock2D(nn.Module): 480 | in_channels: int 481 | out_channels: int 482 | prev_output_channel: int 483 | dropout: float = 0.0 484 | num_layers: int = 1 485 | add_upsample: bool = True 486 | dtype: jnp.dtype = jnp.float32 487 | 488 | def setup(self): 489 | resnets = [] 490 | 491 | for i in range(self.num_layers): 492 | res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels 493 | resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels 494 | 495 | res_block = ResnetBlock( 496 | in_channels=resnet_in_channels + res_skip_channels, 497 | out_channels=self.out_channels, 498 | dropout_prob=self.dropout, 499 | dtype=self.dtype, 500 | ) 501 | resnets.append(res_block) 502 | 503 | self.resnets = resnets 504 | 505 | if self.add_upsample: 506 | self.upsample = Upsample(self.out_channels, dtype=self.dtype) 507 | 508 | def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): 509 | for resnet in self.resnets: 510 | # pop res hidden states 511 | res_hidden_states = res_hidden_states_tuple[-1] 512 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 513 | hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) 514 | 515 | hidden_states = resnet(hidden_states, temb) 516 | 517 | if self.add_upsample: 518 | hidden_states = self.upsample(hidden_states) 519 | 520 | return hidden_states 521 | 522 | 523 | class UNetMidBlock2DCrossAttn(nn.Module): 524 | in_channels: int 525 | dropout: float = 0.0 526 | num_layers: int = 1 527 | attn_num_head_channels: int = 1 528 | dtype: jnp.dtype = jnp.float32 529 | 530 | def setup(self): 531 | # there is always at least one resnet 532 | resnets = [ 533 | ResnetBlock( 534 | in_channels=self.in_channels, 535 | out_channels=self.in_channels, 536 | dropout_prob=self.dropout, 537 | dtype=self.dtype, 538 | ) 539 | ] 540 | 541 | attentions = [] 542 | 543 | for _ in range(self.num_layers): 544 | attn_block = SpatialTransformer( 545 | in_channels=self.in_channels, 546 | n_heads=self.attn_num_head_channels, 547 | d_head=self.in_channels // self.attn_num_head_channels, 548 | depth=1, 549 | dtype=self.dtype, 550 | ) 551 | attentions.append(attn_block) 552 | 553 | res_block = ResnetBlock( 554 | in_channels=self.in_channels, 555 | out_channels=self.in_channels, 556 | dropout_prob=self.dropout, 557 | dtype=self.dtype, 558 | ) 559 | resnets.append(res_block) 560 | 561 | self.resnets = resnets 562 | self.attentions = attentions 563 | 564 | def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): 565 | hidden_states = self.resnets[0](hidden_states, temb) 566 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 567 | hidden_states = attn(hidden_states, encoder_hidden_states) 568 | hidden_states = resnet(hidden_states, temb) 569 | 570 | return hidden_states 571 | 572 | 573 | class UNet2DModule(nn.Module): 574 | config: UNet2DConfig 575 | dtype: jnp.dtype = jnp.float32 576 | 577 | def 
setup(self): 578 | config = self.config 579 | 580 | self.sample_size = config.sample_size 581 | block_out_channels = config.block_out_channels 582 | time_embed_dim = block_out_channels[0] * 4 583 | 584 | # input 585 | self.conv_in = nn.Conv( 586 | block_out_channels[0], 587 | kernel_size=(3, 3), 588 | strides=(1, 1), 589 | padding=((1, 1), (1, 1)), 590 | dtype=self.dtype, 591 | ) 592 | 593 | # time 594 | self.time_proj = Timesteps(block_out_channels[0]) 595 | self.time_embedding = TimestepEmbedding(time_embed_dim, dtype=self.dtype) 596 | 597 | # down 598 | down_blocks = [] 599 | output_channel = block_out_channels[0] 600 | for i, down_block_type in enumerate(config.down_block_types): 601 | input_channel = output_channel 602 | output_channel = block_out_channels[i] 603 | is_final_block = i == len(block_out_channels) - 1 604 | 605 | if down_block_type == "CrossAttnDownBlock2D": 606 | down_block = CrossAttnDownBlock2D( 607 | in_channels=input_channel, 608 | out_channels=output_channel, 609 | dropout=config.dropout, 610 | num_layers=config.layers_per_block, 611 | attn_num_head_channels=config.attention_head_dim, 612 | add_downsample=not is_final_block, 613 | dtype=self.dtype, 614 | ) 615 | else: 616 | down_block = DownBlock2D( 617 | in_channels=input_channel, 618 | out_channels=output_channel, 619 | dropout=config.dropout, 620 | num_layers=config.layers_per_block, 621 | add_downsample=not is_final_block, 622 | dtype=self.dtype, 623 | ) 624 | 625 | down_blocks.append(down_block) 626 | self.down_blocks = down_blocks 627 | 628 | # mid 629 | self.mid_block = UNetMidBlock2DCrossAttn( 630 | in_channels=block_out_channels[-1], 631 | dropout=config.dropout, 632 | attn_num_head_channels=config.attention_head_dim, 633 | dtype=self.dtype, 634 | ) 635 | 636 | # up 637 | up_blocks = [] 638 | reversed_block_out_channels = list(reversed(block_out_channels)) 639 | output_channel = reversed_block_out_channels[0] 640 | for i, up_block_type in enumerate(config.up_block_types): 641 | prev_output_channel = output_channel 642 | output_channel = reversed_block_out_channels[i] 643 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 644 | 645 | is_final_block = i == len(block_out_channels) - 1 646 | 647 | if up_block_type == "CrossAttnUpBlock2D": 648 | up_block = CrossAttnUpBlock2D( 649 | in_channels=input_channel, 650 | out_channels=output_channel, 651 | prev_output_channel=prev_output_channel, 652 | num_layers=config.layers_per_block + 1, 653 | attn_num_head_channels=config.attention_head_dim, 654 | add_upsample=not is_final_block, 655 | dropout=config.dropout, 656 | dtype=self.dtype, 657 | ) 658 | else: 659 | up_block = UpBlock2D( 660 | in_channels=input_channel, 661 | out_channels=output_channel, 662 | prev_output_channel=prev_output_channel, 663 | num_layers=config.layers_per_block + 1, 664 | add_upsample=not is_final_block, 665 | dropout=config.dropout, 666 | dtype=self.dtype, 667 | ) 668 | 669 | up_blocks.append(up_block) 670 | prev_output_channel = output_channel 671 | self.up_blocks = up_blocks 672 | 673 | # out 674 | self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) 675 | self.conv_out = nn.Conv( 676 | config.out_channels, 677 | kernel_size=(3, 3), 678 | strides=(1, 1), 679 | padding=((1, 1), (1, 1)), 680 | dtype=self.dtype, 681 | ) 682 | 683 | def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True): 684 | 685 | # 1. 
time 686 | # broadcast to batch dimension 687 | # timesteps = jnp.broadcast_to(timesteps, (sample.shape[0],) + timesteps.shape) 688 | t_emb = self.time_proj(timesteps) 689 | t_emb = self.time_embedding(t_emb) 690 | 691 | # 2. pre-process 692 | sample = self.conv_in(sample) 693 | 694 | # 3. down 695 | down_block_res_samples = (sample,) 696 | for down_block in self.down_blocks: 697 | if isinstance(down_block, CrossAttnDownBlock2D): 698 | sample, res_samples = down_block(sample, t_emb, encoder_hidden_states) 699 | else: 700 | sample, res_samples = down_block(sample, t_emb) 701 | down_block_res_samples += res_samples 702 | 703 | # 4. mid 704 | sample = self.mid_block(sample, t_emb, encoder_hidden_states) 705 | 706 | # 5. up 707 | for up_block in self.up_blocks: 708 | res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :] 709 | down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)] 710 | if isinstance(up_block, CrossAttnUpBlock2D): 711 | sample = up_block( 712 | sample, 713 | temb=t_emb, 714 | encoder_hidden_states=encoder_hidden_states, 715 | res_hidden_states_tuple=res_samples, 716 | ) 717 | else: 718 | sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples) 719 | 720 | # 6. post-process 721 | sample = self.conv_norm_out(sample) 722 | sample = nn.silu(sample) 723 | sample = self.conv_out(sample) 724 | 725 | return sample 726 | 727 | 728 | class UNet2DPretrainedModel(FlaxPreTrainedModel): 729 | config_class = UNet2DConfig 730 | base_model_prefix = "model" 731 | module_class: nn.Module = None 732 | 733 | def __init__( 734 | self, 735 | config: UNet2DConfig, 736 | input_shape: Tuple = (1, 32, 32, 4), 737 | seed: int = 0, 738 | dtype: jnp.dtype = jnp.float32, 739 | _do_init: bool = True, 740 | **kwargs, 741 | ): 742 | module = self.module_class(config=config, dtype=dtype, **kwargs) 743 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 744 | 745 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: 746 | # init input tensors 747 | sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) 748 | sample = jnp.zeros(sample_shape, dtype=jnp.float32) 749 | timestpes = jnp.ones((1,), dtype=jnp.int32) 750 | encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32) 751 | 752 | params_rng, dropout_rng = jax.random.split(rng) 753 | rngs = {"params": params_rng, "dropout": dropout_rng} 754 | 755 | return self.module.init(rngs, sample, timestpes, encoder_hidden_states)["params"] 756 | 757 | def __call__( 758 | self, 759 | sample, 760 | timesteps, 761 | encoder_hidden_states, 762 | params: dict = None, 763 | dropout_rng: jax.random.PRNGKey = None, 764 | train: bool = False, 765 | ): 766 | # Handle any PRNG if needed 767 | rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} 768 | 769 | return self.module.apply( 770 | {"params": params or self.params}, 771 | jnp.array(sample), 772 | jnp.array(timesteps, dtype=jnp.int32), 773 | encoder_hidden_states, 774 | not train, 775 | rngs=rngs, 776 | ) 777 | 778 | 779 | class UNet2D(UNet2DPretrainedModel): 780 | module_class = UNet2DModule 781 | -------------------------------------------------------------------------------- /stable_diffusion_jax/modeling_vae.py: -------------------------------------------------------------------------------- 1 | # JAX implementation of VQGAN from taming-transformers 
https://github.com/CompVis/taming-transformers 2 | 3 | import math 4 | from functools import partial 5 | from typing import Tuple 6 | 7 | import flax.linen as nn 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from flax.core.frozen_dict import FrozenDict 12 | from transformers.modeling_flax_utils import FlaxPreTrainedModel 13 | 14 | from .configuration_vae import VAEConfig 15 | 16 | 17 | class Upsample(nn.Module): 18 | in_channels: int 19 | dtype: jnp.dtype = jnp.float32 20 | 21 | def setup(self): 22 | self.conv = nn.Conv( 23 | self.in_channels, 24 | kernel_size=(3, 3), 25 | strides=(1, 1), 26 | padding=((1, 1), (1, 1)), 27 | dtype=self.dtype, 28 | ) 29 | 30 | def __call__(self, hidden_states): 31 | batch, height, width, channels = hidden_states.shape 32 | hidden_states = jax.image.resize( 33 | hidden_states, 34 | shape=(batch, height * 2, width * 2, channels), 35 | method="nearest", 36 | ) 37 | hidden_states = self.conv(hidden_states) 38 | return hidden_states 39 | 40 | 41 | class Downsample(nn.Module): 42 | in_channels: int 43 | dtype: jnp.dtype = jnp.float32 44 | 45 | def setup(self): 46 | self.conv = nn.Conv( 47 | self.in_channels, 48 | kernel_size=(3, 3), 49 | strides=(2, 2), 50 | padding="VALID", 51 | dtype=self.dtype, 52 | ) 53 | 54 | def __call__(self, hidden_states): 55 | pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim 56 | hidden_states = jnp.pad(hidden_states, pad_width=pad) 57 | hidden_states = self.conv(hidden_states) 58 | return hidden_states 59 | 60 | 61 | class ResnetBlock(nn.Module): 62 | in_channels: int 63 | out_channels: int = None 64 | dropout_prob: float = 0.0 65 | use_nin_shortcut: bool = None 66 | dtype: jnp.dtype = jnp.float32 67 | 68 | def setup(self): 69 | out_channels = self.in_channels if self.out_channels is None else self.out_channels 70 | 71 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6) 72 | self.conv1 = nn.Conv( 73 | out_channels, 74 | kernel_size=(3, 3), 75 | strides=(1, 1), 76 | padding=((1, 1), (1, 1)), 77 | dtype=self.dtype, 78 | ) 79 | 80 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6) 81 | self.dropout = nn.Dropout(self.dropout_prob) 82 | self.conv2 = nn.Conv( 83 | out_channels, 84 | kernel_size=(3, 3), 85 | strides=(1, 1), 86 | padding=((1, 1), (1, 1)), 87 | dtype=self.dtype, 88 | ) 89 | 90 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut 91 | 92 | self.conv_shortcut = None 93 | if use_nin_shortcut: 94 | self.conv_shortcut = nn.Conv( 95 | out_channels, 96 | kernel_size=(1, 1), 97 | strides=(1, 1), 98 | padding="VALID", 99 | dtype=self.dtype, 100 | ) 101 | 102 | def __call__(self, hidden_states, deterministic=True): 103 | residual = hidden_states 104 | hidden_states = self.norm1(hidden_states) 105 | hidden_states = nn.swish(hidden_states) 106 | hidden_states = self.conv1(hidden_states) 107 | 108 | hidden_states = self.norm2(hidden_states) 109 | hidden_states = nn.swish(hidden_states) 110 | hidden_states = self.dropout(hidden_states, deterministic) 111 | hidden_states = self.conv2(hidden_states) 112 | 113 | # import ipdb; ipdb.set_trace() 114 | if self.conv_shortcut is not None: 115 | residual = self.conv_shortcut(residual) 116 | 117 | return hidden_states + residual 118 | 119 | 120 | class AttnBlock(nn.Module): 121 | channels: int 122 | num_head_channels: int = None 123 | dtype: jnp.dtype = jnp.float32 124 | 125 | def setup(self): 126 | self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not 
None else 1 127 | 128 | dense = partial(nn.Dense, self.channels, dtype=self.dtype) 129 | 130 | self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6) 131 | self.query, self.key, self.value = dense(), dense(), dense() 132 | self.proj_attn = dense() 133 | 134 | def transpose_for_scores(self, projection): 135 | new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) 136 | # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) 137 | new_projection = projection.reshape(new_projection_shape) 138 | # (B, T, H, D) -> (B, H, T, D) 139 | new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) 140 | return new_projection 141 | 142 | def __call__(self, hidden_states): 143 | residual = hidden_states 144 | batch, height, width, channels = hidden_states.shape 145 | 146 | hidden_states = self.group_norm(hidden_states) 147 | 148 | hidden_states = hidden_states.reshape((batch, height * width, channels)) 149 | 150 | query = self.query(hidden_states) 151 | key = self.key(hidden_states) 152 | value = self.value(hidden_states) 153 | 154 | # transpose 155 | query = self.transpose_for_scores(query) 156 | key = self.transpose_for_scores(key) 157 | value = self.transpose_for_scores(value) 158 | 159 | # compute attentions 160 | scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) 161 | attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) 162 | attn_weights = nn.softmax(attn_weights, axis=-1) 163 | 164 | # attend to values 165 | hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) 166 | 167 | hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) 168 | new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) 169 | hidden_states = hidden_states.reshape(new_hidden_states_shape) 170 | 171 | hidden_states = self.proj_attn(hidden_states) 172 | hidden_states = hidden_states.reshape((batch, height, width, channels)) 173 | hidden_states = hidden_states + residual 174 | return hidden_states 175 | 176 | 177 | class DownBlock2D(nn.Module): 178 | in_channels: int 179 | out_channels: int 180 | dropout: float = 0.0 181 | num_layers: int = 1 182 | add_downsample: bool = True 183 | dtype: jnp.dtype = jnp.float32 184 | 185 | def setup(self): 186 | resnets = [] 187 | for i in range(self.num_layers): 188 | in_channels = self.in_channels if i == 0 else self.out_channels 189 | 190 | res_block = ResnetBlock( 191 | in_channels=in_channels, 192 | out_channels=self.out_channels, 193 | dropout_prob=self.dropout, 194 | dtype=self.dtype, 195 | ) 196 | resnets.append(res_block) 197 | self.resnets = resnets 198 | 199 | if self.add_downsample: 200 | self.downsample = Downsample(self.out_channels, dtype=self.dtype) 201 | 202 | def __call__(self, hidden_states, deterministic=True): 203 | for resnet in self.resnets: 204 | hidden_states = resnet(hidden_states, deterministic=deterministic) 205 | 206 | if self.add_downsample: 207 | hidden_states = self.downsample(hidden_states) 208 | 209 | return hidden_states 210 | 211 | 212 | class UpBlock2D(nn.Module): 213 | in_channels: int 214 | out_channels: int 215 | dropout: float = 0.0 216 | num_layers: int = 1 217 | add_upsample: bool = True 218 | dtype: jnp.dtype = jnp.float32 219 | 220 | def setup(self): 221 | resnets = [] 222 | for i in range(self.num_layers): 223 | in_channels = self.in_channels if i == 0 else self.out_channels 224 | res_block = ResnetBlock( 225 | in_channels=in_channels, 226 | out_channels=self.out_channels, 227 | dropout_prob=self.dropout, 228 | dtype=self.dtype, 229 | ) 230 | 
resnets.append(res_block) 231 | 232 | self.resnets = resnets 233 | 234 | if self.add_upsample: 235 | self.upsample = Upsample(self.out_channels, dtype=self.dtype) 236 | 237 | def __call__(self, hidden_states, deterministic=True): 238 | for resnet in self.resnets: 239 | hidden_states = resnet(hidden_states, deterministic=deterministic) 240 | 241 | if self.add_upsample: 242 | hidden_states = self.upsample(hidden_states) 243 | 244 | return hidden_states 245 | 246 | 247 | class UNetMidBlock2D(nn.Module): 248 | in_channels: int 249 | dropout: float = 0.0 250 | num_layers: int = 1 251 | attn_num_head_channels: int = 1 252 | dtype: jnp.dtype = jnp.float32 253 | 254 | def setup(self): 255 | # there is always at least one resnet 256 | resnets = [ 257 | ResnetBlock( 258 | in_channels=self.in_channels, 259 | out_channels=self.in_channels, 260 | dropout_prob=self.dropout, 261 | dtype=self.dtype, 262 | ) 263 | ] 264 | 265 | attentions = [] 266 | 267 | for _ in range(self.num_layers): 268 | attn_block = AttnBlock( 269 | channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype 270 | ) 271 | attentions.append(attn_block) 272 | 273 | res_block = ResnetBlock( 274 | in_channels=self.in_channels, 275 | out_channels=self.in_channels, 276 | dropout_prob=self.dropout, 277 | dtype=self.dtype, 278 | ) 279 | resnets.append(res_block) 280 | 281 | self.resnets = resnets 282 | self.attentions = attentions 283 | 284 | def __call__(self, hidden_states, deterministic=True): 285 | hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) 286 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 287 | hidden_states = attn(hidden_states) 288 | hidden_states = resnet(hidden_states, deterministic=deterministic) 289 | 290 | return hidden_states 291 | 292 | 293 | class Encoder(nn.Module): 294 | config: VAEConfig 295 | dtype: jnp.dtype = jnp.float32 296 | 297 | def setup(self): 298 | block_out_channels = self.config.block_out_channels 299 | # in 300 | self.conv_in = nn.Conv( 301 | block_out_channels[0], 302 | kernel_size=(3, 3), 303 | strides=(1, 1), 304 | padding=((1, 1), (1, 1)), 305 | dtype=self.dtype, 306 | ) 307 | 308 | # downsampling 309 | down_blocks = [] 310 | output_channel = block_out_channels[0] 311 | for i, _ in enumerate(self.config.down_block_types): 312 | input_channel = output_channel 313 | output_channel = block_out_channels[i] 314 | is_final_block = i == len(block_out_channels) - 1 315 | 316 | down_block = DownBlock2D( 317 | in_channels=input_channel, 318 | out_channels=output_channel, 319 | num_layers=self.config.layers_per_block, 320 | add_downsample=not is_final_block, 321 | dtype=self.dtype, 322 | ) 323 | down_blocks.append(down_block) 324 | self.down_blocks = down_blocks 325 | 326 | # middle 327 | self.mid_block = UNetMidBlock2D( 328 | in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype 329 | ) 330 | 331 | # end 332 | conv_out_channels = 2 * self.config.latent_channels if self.config.double_z else self.config.latent_channels 333 | self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) 334 | self.conv_out = nn.Conv( 335 | conv_out_channels, 336 | kernel_size=(3, 3), 337 | strides=(1, 1), 338 | padding=((1, 1), (1, 1)), 339 | dtype=self.dtype, 340 | ) 341 | 342 | def __call__(self, sample, deterministic: bool = True): 343 | # in 344 | sample = self.conv_in(sample) 345 | 346 | # downsampling 347 | for block in self.down_blocks: 348 | sample = block(sample, deterministic=deterministic) 349 | 350 | # middle 351 | 
sample = self.mid_block(sample, deterministic=deterministic) 352 | 353 | # end 354 | sample = self.conv_norm_out(sample) 355 | sample = nn.swish(sample) 356 | sample = self.conv_out(sample) 357 | 358 | return sample 359 | 360 | 361 | class Decoder(nn.Module): 362 | config: VAEConfig 363 | dtype: jnp.dtype = jnp.float32 364 | 365 | def setup(self): 366 | block_out_channels = self.config.block_out_channels 367 | 368 | # z to block_in 369 | self.conv_in = nn.Conv( 370 | block_out_channels[-1], 371 | kernel_size=(3, 3), 372 | strides=(1, 1), 373 | padding=((1, 1), (1, 1)), 374 | dtype=self.dtype, 375 | ) 376 | 377 | # middle 378 | self.mid_block = UNetMidBlock2D( 379 | in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype 380 | ) 381 | 382 | # upsampling 383 | reversed_block_out_channels = list(reversed(block_out_channels)) 384 | output_channel = reversed_block_out_channels[0] 385 | up_blocks = [] 386 | for i, _ in enumerate(self.config.up_block_types): 387 | prev_output_channel = output_channel 388 | output_channel = reversed_block_out_channels[i] 389 | 390 | is_final_block = i == len(block_out_channels) - 1 391 | 392 | up_block = UpBlock2D( 393 | in_channels=prev_output_channel, 394 | out_channels=output_channel, 395 | num_layers=self.config.layers_per_block + 1, 396 | add_upsample=not is_final_block, 397 | dtype=self.dtype, 398 | ) 399 | up_blocks.append(up_block) 400 | prev_output_channel = output_channel 401 | 402 | self.up_blocks = up_blocks 403 | 404 | # end 405 | self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) 406 | self.conv_out = nn.Conv( 407 | self.config.out_channels, 408 | kernel_size=(3, 3), 409 | strides=(1, 1), 410 | padding=((1, 1), (1, 1)), 411 | dtype=self.dtype, 412 | ) 413 | 414 | def __call__(self, sample, deterministic: bool = True): 415 | # z to block_in 416 | sample = self.conv_in(sample) 417 | 418 | # middle 419 | sample = self.mid_block(sample, deterministic=deterministic) 420 | 421 | # upsampling 422 | for block in self.up_blocks: 423 | sample = block(sample, deterministic=deterministic) 424 | 425 | sample = self.conv_norm_out(sample) 426 | sample = nn.swish(sample) 427 | sample = self.conv_out(sample) 428 | 429 | return sample 430 | 431 | 432 | class DiagonalGaussianDistribution(object): 433 | # TODO: should we pass dtype? 
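# The encoder's `parameters` output holds the mean and log-variance of a diagonal
# Gaussian, concatenated along the last (channels-last) axis. `sample` draws
# z = mean + std * eps via the reparameterization trick, `kl` is the closed-form
# KL divergence against a standard normal (or another diagonal Gaussian), `nll`
# the negative log-likelihood of a given sample, and `mode` simply returns the mean.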
434 | def __init__(self, parameters, deterministic=False): 435 | # Last axis to account for channels-last 436 | self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) 437 | self.logvar = jnp.clip(self.logvar, -30.0, 20.0) 438 | self.deterministic = deterministic 439 | self.std = jnp.exp(0.5 * self.logvar) 440 | self.var = jnp.exp(self.logvar) 441 | if self.deterministic: 442 | self.var = self.std = jnp.zeros_like(self.mean) 443 | 444 | def sample(self, key): 445 | return self.mean + self.std * jax.random.normal(key, self.mean.shape) 446 | 447 | def kl(self, other=None): 448 | if self.deterministic: 449 | return jnp.array([0.0]) 450 | 451 | if other is None: 452 | return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) 453 | 454 | return 0.5 * jnp.sum( 455 | jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, 456 | axis=[1, 2, 3], 457 | ) 458 | 459 | def nll(self, sample, axis=[1, 2, 3]): 460 | if self.deterministic: 461 | return jnp.array([0.0]) 462 | 463 | logtwopi = jnp.log(2.0 * jnp.pi) 464 | return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) 465 | 466 | def mode(self): 467 | return self.mean 468 | 469 | 470 | class AutoencoderKLModule(nn.Module): 471 | config: VAEConfig 472 | dtype: jnp.dtype = jnp.float32 473 | 474 | def setup(self): 475 | self.encoder = Encoder(self.config, dtype=self.dtype) 476 | self.decoder = Decoder(self.config, dtype=self.dtype) 477 | self.quant_conv = nn.Conv( 478 | 2 * self.config.latent_channels, 479 | kernel_size=(1, 1), 480 | strides=(1, 1), 481 | padding="VALID", 482 | dtype=self.dtype, 483 | ) 484 | self.post_quant_conv = nn.Conv( 485 | self.config.latent_channels, 486 | kernel_size=(1, 1), 487 | strides=(1, 1), 488 | padding="VALID", 489 | dtype=self.dtype, 490 | ) 491 | 492 | def encode(self, pixel_values, deterministic: bool = True): 493 | hidden_states = self.encoder(pixel_values, deterministic=deterministic) 494 | moments = self.quant_conv(hidden_states) 495 | posterior = DiagonalGaussianDistribution(moments) 496 | return posterior 497 | 498 | def decode(self, latents, deterministic: bool = True): 499 | hidden_states = self.post_quant_conv(latents) 500 | hidden_states = self.decoder(hidden_states, deterministic=deterministic) 501 | return hidden_states 502 | 503 | def __call__(self, sample, sample_posterior=False, deterministic: bool = True): 504 | posterior = self.encode(sample, deterministic=deterministic) 505 | if sample_posterior: 506 | rng = self.make_rng("gaussian") 507 | hidden_states = posterior.sample(rng) 508 | else: 509 | hidden_states = posterior.mode() 510 | # import ipdb; ipdb.set_trace() 511 | hidden_states = self.decode(hidden_states) 512 | return hidden_states, posterior 513 | 514 | 515 | class AutoencoderKLPreTrainedModel(FlaxPreTrainedModel): 516 | """ 517 | An abstract class to handle weights initialization and a simple interface 518 | for downloading and loading pretrained models. 
519 | """ 520 | 521 | config_class = VAEConfig 522 | base_model_prefix = "model" 523 | module_class: nn.Module = None 524 | 525 | def __init__( 526 | self, 527 | config: VAEConfig, 528 | input_shape: Tuple = (1, 256, 256, 3), 529 | seed: int = 0, 530 | dtype: jnp.dtype = jnp.float32, 531 | _do_init: bool = True, 532 | **kwargs, 533 | ): 534 | module = self.module_class(config=config, dtype=dtype, **kwargs) 535 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 536 | 537 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: 538 | # init input tensors 539 | sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) 540 | sample = jnp.zeros(sample_shape, dtype=jnp.float32) 541 | params_rng, dropout_rng = jax.random.split(rng) 542 | rngs = {"params": params_rng, "dropout": dropout_rng} 543 | 544 | return self.module.init(rngs, sample)["params"] 545 | 546 | def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False): 547 | # Handle any PRNG if needed 548 | rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} 549 | 550 | return self.module.apply( 551 | {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode 552 | ) 553 | 554 | def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False): 555 | # Handle any PRNG if needed 556 | rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} 557 | 558 | return self.module.apply( 559 | {"params": params or self.params}, 560 | jnp.array(hidden_states), 561 | not train, 562 | rngs=rngs, 563 | method=self.module.decode, 564 | ) 565 | 566 | def decode_code(self, indices, params: dict = None): 567 | return self.module.apply( 568 | {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code 569 | ) 570 | 571 | def __call__( 572 | self, 573 | pixel_values, 574 | sample_posterior: bool = False, 575 | params: dict = None, 576 | dropout_rng: jax.random.PRNGKey = None, 577 | train: bool = False, 578 | ): 579 | # Handle any PRNG if needed 580 | rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} 581 | 582 | return self.module.apply( 583 | {"params": params or self.params}, 584 | jnp.array(pixel_values), 585 | sample_posterior, 586 | not train, 587 | rngs=rngs, 588 | ) 589 | 590 | 591 | class AutoencoderKL(AutoencoderKLPreTrainedModel): 592 | module_class = AutoencoderKLModule 593 | -------------------------------------------------------------------------------- /stable_diffusion_jax/pipeline_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import flax.linen as nn 3 | import jax 4 | import jax.numpy as jnp 5 | from PIL import Image 6 | from transformers import CLIPTokenizer, FlaxCLIPTextModel 7 | 8 | from stable_diffusion_jax.scheduling_pndm import PNDMScheduler 9 | 10 | 11 | @flax.struct.dataclass 12 | class InferenceState: 13 | text_encoder_params: flax.core.FrozenDict 14 | unet_params: flax.core.FrozenDict 15 | vae_params: flax.core.FrozenDict 16 | 17 | 18 | class StableDiffusionPipeline: 19 | def __init__(self, vae, text_encoder, tokenizer, unet, scheduler): 20 | scheduler = scheduler.set_format("np") 21 | self.vae = vae 22 | self.text_encoder = text_encoder 23 | self.tokenizer = tokenizer 24 | self.unet = unet 25 | self.scheduler = scheduler 26 
| 27 | def numpy_to_pil(images): 28 | """ 29 | Convert a numpy image or a batch of images to a PIL image. 30 | """ 31 | if images.ndim == 3: 32 | images = images[None, ...] 33 | images = (images * 255).round().astype("uint8") 34 | pil_images = [Image.fromarray(image) for image in images] 35 | 36 | return pil_images 37 | 38 | def sample( 39 | self, 40 | input_ids: jnp.ndarray, 41 | uncond_input_ids: jnp.ndarray, 42 | prng_seed: jax.random.PRNGKey, 43 | inference_state: InferenceState, 44 | num_inference_steps: int = 50, 45 | guidance_scale: float = 1.0, 46 | debug: bool = False, 47 | ): 48 | 49 | self.scheduler.set_timesteps(num_inference_steps, offset=1) 50 | 51 | text_embeddings = self.text_encoder(input_ids, params=inference_state.text_encoder_params)[0] 52 | uncond_embeddings = self.text_encoder(uncond_input_ids, params=inference_state.text_encoder_params)[0] 53 | context = jnp.concatenate([uncond_embeddings, text_embeddings]) 54 | 55 | latents_shape = ( 56 | input_ids.shape[0], 57 | self.unet.config.sample_size, 58 | self.unet.config.sample_size, 59 | self.unet.config.in_channels, 60 | ) 61 | latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) 62 | 63 | def loop_body(step, latents): 64 | # For classifier free guidance, we need to do two forward passes. 65 | # Here we concatenate the unconditional and text embeddings into a single batch 66 | # to avoid doing two forward passes 67 | latents_input = jnp.concatenate([latents] * 2) 68 | 69 | t = jnp.array(self.scheduler.timesteps)[step] 70 | timestep = jnp.broadcast_to(t, latents_input.shape[0]) 71 | 72 | # predict the noise residual 73 | noise_pred = self.unet( 74 | latents_input, timestep, encoder_hidden_states=context, params=inference_state.unet_params 75 | ) 76 | # perform guidance 77 | noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) 78 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 79 | 80 | # compute the previous noisy sample x_t -> x_t-1 81 | latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"] 82 | return latents 83 | 84 | if debug: 85 | # run with python for loop 86 | for i in range(num_inference_steps): 87 | latents = loop_body(i, latents) 88 | else: 89 | latents = jax.lax.fori_loop(0, num_inference_steps, loop_body, latents) 90 | 91 | # scale and decode the image latents with vae 92 | latents = 1 / 0.18215 * latents 93 | image = self.vae.decode(latents, params=inference_state.vae_params) 94 | 95 | return image 96 | -------------------------------------------------------------------------------- /stable_diffusion_jax/safety_checker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from flax import linen as nn 3 | import jax.numpy as jnp 4 | import jax 5 | from typing import Optional, Tuple 6 | import optax 7 | import warnings 8 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 9 | from flax.traverse_util import flatten_dict, unflatten_dict 10 | 11 | from transformers import CLIPConfig, FlaxCLIPVisionModel, FlaxPreTrainedModel 12 | from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule 13 | 14 | 15 | def cosine_distance(ten_1, ten_2, eps=1e-12): 16 | norm_ten_1 = jnp.divide(ten_1.T, jnp.clip(jnp.linalg.norm(ten_1, axis=1), a_min=eps)).T 17 | norm_ten_2 = jnp.divide(ten_2.T, jnp.clip(jnp.linalg.norm(ten_2, axis=1), a_min=eps)).T 18 | 19 | return jnp.matmul(norm_ten_1, norm_ten_2.T) 20 | 21 | 22 | class 
StableDiffusionSafetyCheckerModule(nn.Module): 23 | config: CLIPConfig 24 | dtype: jnp.dtype = jnp.float32 25 | 26 | def setup(self): 27 | self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) 28 | self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False) 29 | 30 | self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) 31 | self.special_care_embeds = self.param("special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)) 32 | 33 | self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) 34 | self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) 35 | 36 | def __call__(self, clip_input, images=None): 37 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 38 | image_embeds = self.visual_projection(pooled_output) 39 | 40 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 41 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 42 | 43 | if images is None: 44 | return special_cos_dist, cos_dist 45 | 46 | special_cos_dist = np.asarray(special_cos_dist) 47 | cos_dist = np.asarray(cos_dist) 48 | 49 | result = [] 50 | batch_size = image_embeds.shape[0] 51 | for i in range(batch_size): 52 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} 53 | 54 | # increase this value to create a stronger `nsfw` filter 55 | # at the cost of increasing the possibility of filtering benign image inputs 56 | adjustment = 0.0 57 | 58 | for concept_idx in range(len(special_cos_dist[0])): 59 | concept_cos = special_cos_dist[i][concept_idx] 60 | concept_threshold = self.special_care_embeds_weights[concept_idx].item() 61 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 62 | if result_img["special_scores"][concept_idx] > 0: 63 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) 64 | adjustment = 0.01 65 | 66 | for concept_idx in range(len(cos_dist[0])): 67 | concept_cos = cos_dist[i][concept_idx] 68 | concept_threshold = self.concept_embeds_weights[concept_idx].item() 69 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 70 | if result_img["concept_scores"][concept_idx] > 0: 71 | result_img["bad_concepts"].append(concept_idx) 72 | 73 | result.append(result_img) 74 | 75 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 76 | 77 | images_was_copied = False 78 | for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): 79 | if has_nsfw_concept: 80 | if not images_was_copied: 81 | images_was_copied = True 82 | images = images.copy() 83 | 84 | images[idx] = np.zeros(images[idx].shape) # black image 85 | 86 | if any(has_nsfw_concepts): 87 | warnings.warn( 88 | "Potential NSFW content was detected in one or more images. A black image will be returned instead." 89 | " Try again with a different prompt and/or seed."
90 | ) 91 | 92 | return images, has_nsfw_concepts 93 | 94 | 95 | class StableDiffusionSafetyCheckerModel(FlaxPreTrainedModel): 96 | config_class = CLIPConfig 97 | main_input_name = "pixel_values" 98 | module_class = StableDiffusionSafetyCheckerModule 99 | 100 | def __init__( 101 | self, 102 | config: CLIPConfig, 103 | input_shape: Optional[Tuple] = None, 104 | seed: int = 0, 105 | dtype: jnp.dtype = jnp.float32, 106 | _do_init: bool = True, 107 | **kwargs 108 | ): 109 | if input_shape is None: 110 | input_shape = (1, 224, 224, 3) 111 | module = self.module_class(config=config, dtype=dtype, **kwargs) 112 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 113 | 114 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 115 | # init input tensor 116 | pixel_values = jax.random.normal(rng, input_shape) 117 | 118 | 119 | params_rng, dropout_rng = jax.random.split(rng) 120 | rngs = {"params": params_rng, "dropout": dropout_rng} 121 | 122 | random_params = self.module.init(rngs, pixel_values)["params"] 123 | 124 | return random_params 125 | 126 | def __call__( 127 | self, 128 | pixel_values, 129 | params: dict = None, 130 | images=None, 131 | ): 132 | pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) 133 | 134 | return self.module.apply( 135 | {"params": params or self.params}, 136 | jnp.array(pixel_values, dtype=jnp.float32), 137 | images, 138 | rngs={}, 139 | ) 140 | -------------------------------------------------------------------------------- /stable_diffusion_jax/scheduling_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | import math 18 | 19 | import jax.numpy as jnp 20 | import numpy as np 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 23 | 24 | 25 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 26 | """ 27 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 28 | (1-beta) over time from t = [0,1]. 29 | 30 | :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t 31 | from 0 to 1 and 32 | produces the cumulative product of (1-beta) up to that part of the diffusion process. 33 | :param max_beta: the maximum beta to use; use values lower than 1 to 34 | prevent singularities. 
35 | """ 36 | 37 | def alpha_bar(time_step): 38 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 39 | 40 | betas = [] 41 | for i in range(num_diffusion_timesteps): 42 | t1 = i / num_diffusion_timesteps 43 | t2 = (i + 1) / num_diffusion_timesteps 44 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 45 | return jnp.array(betas, dtype=jnp.float32) 46 | 47 | 48 | class PNDMScheduler(SchedulerMixin, ConfigMixin): 49 | @register_to_config 50 | def __init__( 51 | self, 52 | num_train_timesteps=1000, 53 | beta_start=0.00085, 54 | beta_end=0.012, 55 | beta_schedule="scaled_linear", 56 | tensor_format="np", 57 | skip_prk_steps=True, 58 | ): 59 | 60 | if beta_schedule == "linear": 61 | self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) 62 | elif beta_schedule == "scaled_linear": 63 | # this schedule is very specific to the latent diffusion model. 64 | self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 65 | elif beta_schedule == "squaredcos_cap_v2": 66 | # Glide cosine schedule 67 | self.betas = betas_for_alpha_bar(num_train_timesteps) 68 | else: 69 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 70 | 71 | self.alphas = 1.0 - self.betas 72 | self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) 73 | 74 | self.one = jnp.array(1.0) 75 | 76 | # For now we only support F-PNDM, i.e. the runge-kutta method 77 | # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf 78 | # mainly at formula (9), (12), (13) and the Algorithm 2. 79 | self.pndm_order = 4 80 | 81 | # running values 82 | self.cur_model_output = 0 83 | self.counter = 0 84 | self.cur_sample = None 85 | self.ets = [] 86 | 87 | # setable values 88 | self.num_inference_steps = None 89 | self._timesteps = jnp.arange(0, num_train_timesteps)[::-1].copy() 90 | self._offset = 0 91 | self.prk_timesteps = None 92 | self.plms_timesteps = None 93 | self.timesteps = None 94 | 95 | self.tensor_format = tensor_format 96 | self.set_format(tensor_format=tensor_format) 97 | 98 | def set_timesteps(self, num_inference_steps, offset=0): 99 | self.num_inference_steps = num_inference_steps 100 | # self._timesteps = list( 101 | # range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) 102 | # ) 103 | self._timesteps = jnp.arange( 104 | 0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps 105 | ) 106 | self._offset = offset 107 | # self._timesteps = [t + self._offset for t in self._timesteps] 108 | self._timesteps = self._timesteps + self._offset 109 | 110 | if self.config.skip_prk_steps: 111 | # for some models like stable diffusion the prk steps can/should be skipped to 112 | # produce better results. 
When using PNDM with `self.config.skip_prk_steps` the implementation 113 | # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 114 | self.prk_timesteps = jnp.array([]) 115 | # self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])) 116 | self.plms_timesteps = jnp.concatenate( 117 | (self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]) 118 | )[::-1] 119 | else: 120 | prk_timesteps = self._timesteps[-self.pndm_order :].repeat(2) + jnp.tile( 121 | jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order 122 | ) 123 | self.prk_timesteps = prk_timesteps[:-1].repeat(2)[1:-1][::-1] 124 | self.plms_timesteps = self._timesteps[:-3][::-1] 125 | 126 | timesteps = jnp.concatenate((self.prk_timesteps, self.plms_timesteps)) 127 | self.timesteps = jnp.array(timesteps, dtype=jnp.int32) 128 | 129 | self.ets = [] 130 | self.counter = 0 131 | self.set_format(tensor_format=self.tensor_format) 132 | 133 | def step( 134 | self, 135 | model_output: jnp.ndarray, 136 | timestep: int, 137 | sample: jnp.ndarray, 138 | ): 139 | if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: 140 | return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) 141 | else: 142 | return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) 143 | 144 | def step_prk( 145 | self, 146 | model_output: jnp.ndarray, 147 | timestep: int, 148 | sample: jnp.ndarray, 149 | ): 150 | """ 151 | Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the 152 | solution to the differential equation. 153 | """ 154 | diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 155 | prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) 156 | timestep = self.prk_timesteps[self.counter // 4 * 4] 157 | 158 | if self.counter % 4 == 0: 159 | self.cur_model_output += 1 / 6 * model_output 160 | self.ets.append(model_output) 161 | self.cur_sample = sample 162 | elif (self.counter - 1) % 4 == 0: 163 | self.cur_model_output += 1 / 3 * model_output 164 | elif (self.counter - 2) % 4 == 0: 165 | self.cur_model_output += 1 / 3 * model_output 166 | elif (self.counter - 3) % 4 == 0: 167 | model_output = self.cur_model_output + 1 / 6 * model_output 168 | self.cur_model_output = 0 169 | 170 | # cur_sample should not be `None` 171 | cur_sample = self.cur_sample if self.cur_sample is not None else sample 172 | 173 | prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) 174 | self.counter += 1 175 | 176 | return {"prev_sample": prev_sample} 177 | 178 | def step_plms( 179 | self, 180 | model_output: jnp.ndarray, 181 | timestep: int, 182 | sample: jnp.ndarray, 183 | ): 184 | """ 185 | Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple 186 | times to approximate the solution. 187 | """ 188 | if not self.config.skip_prk_steps and len(self.ets) < 3: 189 | raise ValueError( 190 | f"{self.__class__} can only be run AFTER scheduler has been run " 191 | "in 'prk' mode for at least 12 iterations " 192 | "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " 193 | "for more information." 
194 | ) 195 | 196 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps 197 | prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) 198 | 199 | if self.counter != 1: 200 | self.ets.append(model_output) 201 | else: 202 | prev_timestep = timestep 203 | timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps 204 | 205 | if len(self.ets) == 1 and self.counter == 0: 206 | model_output = model_output 207 | self.cur_sample = sample 208 | elif len(self.ets) == 1 and self.counter == 1: 209 | model_output = (model_output + self.ets[-1]) / 2 210 | sample = self.cur_sample 211 | self.cur_sample = None 212 | elif len(self.ets) == 2: 213 | model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 214 | elif len(self.ets) == 3: 215 | model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 216 | else: 217 | model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) 218 | 219 | prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) 220 | self.counter += 1 221 | 222 | return {"prev_sample": prev_sample} 223 | 224 | def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): 225 | # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf 226 | # this function computes x_(t−δ) using the formula of (9) 227 | # Note that x_t needs to be added to both sides of the equation 228 | 229 | # Notation (<variable name> -> <name in paper>) 230 | # alpha_prod_t -> α_t 231 | # alpha_prod_t_prev -> α_(t−δ) 232 | # beta_prod_t -> (1 - α_t) 233 | # beta_prod_t_prev -> (1 - α_(t−δ)) 234 | # sample -> x_t 235 | # model_output -> e_θ(x_t, t) 236 | # prev_sample -> x_(t−δ) 237 | alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] 238 | alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] 239 | beta_prod_t = 1 - alpha_prod_t 240 | beta_prod_t_prev = 1 - alpha_prod_t_prev 241 | 242 | # corresponds to (α_(t−δ) - α_t) divided by 243 | # denominator of x_t in formula (9) and plus 1 244 | # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) = 245 | # sqrt(α_(t−δ)) / sqrt(α_t) 246 | sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) 247 | 248 | # corresponds to denominator of e_θ(x_t, t) in formula (9) 249 | model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( 250 | alpha_prod_t * beta_prod_t * alpha_prod_t_prev 251 | ) ** (0.5) 252 | 253 | # full formula (9) 254 | prev_sample = ( 255 | sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff 256 | ) 257 | 258 | return prev_sample 259 | 260 | def __len__(self): 261 | return self.config.num_train_timesteps 262 | --------------------------------------------------------------------------------
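For reference, here is a minimal end-to-end sketch of how the modules above fit together for text-to-image sampling. It is an illustration, not a copy of the repository's own example.py or run.py: the local checkpoint directory `sd-v1-4-jax/` is hypothetical, it assumes the diffusers weights were already converted with `convert_diffusers_to_jax.py` and saved in a `from_pretrained`-compatible layout, it assumes the CLIP text encoder used by the v1 checkpoints, and it assumes the VAE decoder returns images in [-1, 1].

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel

from stable_diffusion_jax.modeling_unet2d import UNet2D
from stable_diffusion_jax.modeling_vae import AutoencoderKL
from stable_diffusion_jax.pipeline_stable_diffusion import InferenceState, StableDiffusionPipeline
from stable_diffusion_jax.scheduling_pndm import PNDMScheduler

# Hypothetical local directory holding the converted JAX checkpoints.
ckpt_dir = "sd-v1-4-jax"

unet = UNet2D.from_pretrained(f"{ckpt_dir}/unet")
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae")
# Text encoder / tokenizer assumed to match the converted checkpoint.
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
scheduler = PNDMScheduler()

pipe = StableDiffusionPipeline(vae, text_encoder, tokenizer, unet, scheduler)
state = InferenceState(
    text_encoder_params=text_encoder.params,
    unet_params=unet.params,
    vae_params=vae.params,
)

prompt = "a photograph of an astronaut riding a horse"
input_ids = tokenizer(
    [prompt], padding="max_length", max_length=77, truncation=True, return_tensors="np"
).input_ids
uncond_input_ids = tokenizer(
    [""], padding="max_length", max_length=77, truncation=True, return_tensors="np"
).input_ids

# `sample` runs the classifier-free-guided PNDM loop and decodes the latents with the VAE.
images = pipe.sample(
    jnp.array(input_ids),
    jnp.array(uncond_input_ids),
    jax.random.PRNGKey(0),
    state,
    num_inference_steps=50,
    guidance_scale=7.5,
)

# Assuming the decoder output lives in [-1, 1], rescale to [0, 1] and convert to PIL.
images = np.asarray((images / 2 + 0.5).clip(0, 1))
pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images]
pil_images[0].save("astronaut.png")

Passing debug=True to `sample` swaps the jax.lax.fori_loop for a plain Python loop, which is convenient when stepping through the scheduler interactively.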