├── .DS_Store ├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── data └── benchmark.png ├── diffusionkit ├── __init__.py ├── mlx │ ├── __init__.py │ ├── clip.py │ ├── config.py │ ├── mmdit.py │ ├── model_io.py │ ├── sampler.py │ ├── scripts │ │ ├── __init__.py │ │ └── generate_images.py │ ├── t5.py │ ├── tokenizer.py │ └── vae.py ├── tests │ ├── __init__.py │ ├── mlx │ │ ├── __init__.py │ │ └── test_diffusion_pipeline.py │ └── torch2coreml │ │ ├── __init__.py │ │ ├── test_mmdit.py │ │ └── test_vae.py └── utils.py ├── pyproject.toml ├── requirements.txt └── workflows ├── .DS_Store └── basic_workflow.json /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | # if this is a forked repository. Skipping the workflow. 15 | if: github.event.repository.fork == false 16 | steps: 17 | - name: Check out code 18 | uses: actions/checkout@v4 19 | - name: Publish Custom Node 20 | uses: Comfy-Org/publish-node-action@main 21 | with: 22 | ## Add your own personal access token to your Github Repository secrets and reference it here. 23 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode 3 | *.venv 4 | *.ipynb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MY LITTLE PLANET, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI MLX Nodes 2 | 3 | Faster workflows for ComfyUI users on Mac with Apple Silicon 4 | 5 | ## Installation 6 | 7 | 1. Install the MLX nodes from the Custom Nodes Manager: 8 | 9 | - In ComfyUI, Manager > Custom Nodes Manager > Tap 'ComfyUI MLX' > Click Install 10 | 11 | OR 12 | 13 | - In ComfyUI, Manager > Install via Git URL > https://github.com/thoddnn/ComfyUI-MLX.git 14 | 15 | ## Performance 16 | 17 | ![ComfyUI-MLX benchmark](./data/benchmark.png) 18 | 19 | Given the following environment: 20 | 21 | - Device: MacBook M2 Max, 96 GB 22 | 23 | - Model: Flux 1.0 dev (not quantized) 24 | 25 | - Size: 512x512 26 | 27 | - Prompt: Photo of a cat 28 | 29 | - Steps: 10 30 | 31 | I get approximately: 32 | 33 | - 70% faster when the model needs to be loaded 34 | 35 | - 35% faster when the model is already loaded 36 | 37 | - 30% lower memory usage 38 | 39 | ## Getting Started 40 | 41 | A basic workflow is provided [here](./workflows/basic_workflow.json) to help you start experimenting with the nodes. 42 | 43 | ## Why ComfyUI MLX Nodes? 44 | 45 | I started building these nodes because image generation from Flux models was taking too much time on my MacBook. After discovering DiffusionKit on X, which showcased great performance for image generation on Apple Silicon, I decided to create a quick port of the library into ComfyUI. 46 | 47 | The goal is to collaborate with other contributors to build a full suite of custom nodes optimized for Apple Silicon. 48 | 49 | Additionally, we aim to minimize the reliance on torch to take full advantage of future MLX improvements and further enhance performance. 50 | 51 | This will allow ComfyUI users on Mac with Apple Silicon to experience faster workflows. 52 | 53 | ## Contributing 54 | 55 | Contributions are welcome! I'm open to best practices and suggestions, and you’re encouraged to submit a Pull Request to improve the project. 🙏 56 | 57 | ## Future Plans 58 | 59 | - Loading models from local files 60 | - SDXL models support 61 | - ControlNet support 62 | - LoRA support 63 | - LLM and VLM nodes 64 | - CogXVideo models support 65 | - Build more MLX-based nodes for common workflows (based on your requests) 66 | 67 | ## License 68 | 69 | ComfyUI MLX Nodes is released under the MIT License. See [LICENSE](LICENSE) for more details. 70 | 71 | ## Acknowledgements 72 | 73 | - [DiffusionKit](https://github.com/argmaxinc/DiffusionKit) 74 | 75 | ## Support 76 | 77 | If you encounter any problems or have any questions, please open an issue in this repository.
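Under the hood, each node is a plain Python class defined in `__init__.py`, so the data flow of the basic workflow can be sketched as direct calls. The snippet below is a minimal sketch, not a supported API: it assumes a ComfyUI environment (the package imports ComfyUI's `folder_paths` at import time), and the import path is hypothetical.

```python
# Illustrative sketch of how the four nodes hand data to each other,
# mirroring workflows/basic_workflow.json. Assumes a ComfyUI environment;
# the import path below is hypothetical.
import torch

from comfyui_mlx_nodes import MLXClipTextEncoder, MLXDecoder, MLXLoadFlux, MLXSampler  # hypothetical import

# Load the Flux pipeline from Hugging Face: the MMDiT model, the VAE decoder,
# and a bundle holding the CLIP-L / T5 encoders and tokenizers.
model, vae, clip = MLXLoadFlux().load_flux_model("argmaxinc/mlx-FLUX.1-schnell-4bit-quantized")

# Encode the prompt: T5 token-level embeddings plus the CLIP-L pooled output.
(conditioning,) = MLXClipTextEncoder().encode(mlx_conditioning=clip, text="Photo of a cat")

# Denoise an empty 64x64 latent (512x512 pixels at the VAE's 8x scale factor).
# latent_image follows ComfyUI's {"samples": tensor} convention, shape (B, C, H, W).
latent_image = {"samples": torch.zeros(1, 16, 64, 64)}
(latents,) = MLXSampler().generate_image(
    mlx_model=model,
    seed=0,
    steps=4,
    cfg=0.0,  # FLUX.1-schnell is typically run without CFG
    mlx_positive_conditioning=conditioning,
    latent_image=latent_image,
    denoise=1.0,
)

# Decode the MLX latents into a torch image tensor with values in [0, 1].
(image,) = MLXDecoder().decode(latent_image=latents, mlx_vae=vae)
```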
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL.Image 3 | import mlx.core as mx 4 | from typing import Optional, Tuple 5 | from PIL import Image 6 | from .diffusionkit.mlx.tokenizer import Tokenizer, T5Tokenizer 7 | from .diffusionkit.mlx.t5 import SD3T5Encoder 8 | from .diffusionkit.mlx import load_t5_encoder, load_t5_tokenizer, load_tokenizer, load_text_encoder 9 | from .diffusionkit.mlx.clip import CLIPTextModel 10 | from .diffusionkit.mlx.model_io import load_flux 11 | from .diffusionkit.mlx import FluxPipeline 12 | import folder_paths 13 | import torch 14 | import os 15 | import gc 16 | 17 | class MLXDecoder: 18 | 19 | @classmethod 20 | def INPUT_TYPES(s): 21 | return {"required": { "latent_image": ("LATENT", ), "mlx_vae": ("mlx_vae", )}} 22 | 23 | RETURN_TYPES = ("IMAGE",) 24 | FUNCTION = "decode" 25 | 26 | def decode(self, latent_image, mlx_vae): 27 | 28 | decoded = mlx_vae(latent_image) 29 | decoded = mx.clip(decoded / 2 + 0.5, 0, 1) 30 | 31 | mx.eval(decoded) 32 | 33 | # Convert MLX tensor to numpy array 34 | decoded_np = np.array(decoded.astype(mx.float16)) 35 | 36 | # Convert numpy array to PyTorch tensor 37 | decoded_torch = torch.from_numpy(decoded_np).float() 38 | 39 | # Ensure the tensor is in the correct format (B, C, H, W) 40 | if decoded_torch.dim() == 3: 41 | decoded_torch = decoded_torch.unsqueeze(0) 42 | 43 | # Ensure the values are in the range [0, 1] 44 | decoded_torch = torch.clamp(decoded_torch, 0, 1) 45 | 46 | return (decoded_torch,) 47 | 48 | 49 | class MLXSampler: 50 | @classmethod 51 | def INPUT_TYPES(s): 52 | return {"required": 53 | {"mlx_model": ("mlx_model",), 54 | "seed": ("INT", {"default": 0, "min": 0, "max": 2**32 - 1}), 55 | "steps": ("INT", {"default": 4, "min": 1, "max": 10000}), 56 | "cfg": ("FLOAT", {"default": 0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), 57 | "mlx_positive_conditioning": ("mlx_conditioning", ), 58 | "latent_image": ("LATENT", ), 59 | "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 60 | } 61 | } 62 | 63 | RETURN_TYPES = ("LATENT",) 64 | FUNCTION = "generate_image" 65 | 66 | def generate_image(self, mlx_model, seed, steps, cfg, mlx_positive_conditioning, latent_image, denoise): 67 | 68 | conditioning = mlx_positive_conditioning["conditioning"] 69 | pooled_conditioning = mlx_positive_conditioning["pooled_conditioning"] 70 | num_steps = steps 71 | cfg_weight = cfg 72 | 73 | batch, channels, height, width = latent_image["samples"].shape 74 | 75 | latent_size = (height, width) 76 | 77 | latents, iter_time = mlx_model.denoise_latents( 78 | conditioning, 79 | pooled_conditioning, 80 | num_steps=num_steps, 81 | cfg_weight=cfg_weight, 82 | latent_size=latent_size, 83 | seed=seed, 84 | image_path=None, 85 | denoise=denoise, 86 | ) 87 | 88 | mx.eval(latents) 89 | 90 | latents = latents.astype(mlx_model.activation_dtype) 91 | 92 | return (latents,) 93 | 94 | 95 | class MLXLoadFlux: 96 | @classmethod 97 | def INPUT_TYPES(s): 98 | return {"required": { 99 | "model_version": ([ 100 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized", 101 | "argmaxinc/mlx-FLUX.1-schnell", 102 | "argmaxinc/mlx-FLUX.1-dev" 103 | ],) 104 | }} 105 | 106 | RETURN_TYPES = ("mlx_model", "mlx_vae", "mlx_conditioning") 107 | FUNCTION = "load_flux_model" 108 | 109 | def check_model_folder(self, filename): 110 | 111 | home_dir = os.path.expanduser("~") 112 | 
formatted_filename = filename.replace("/", "--") 113 | folder_path = os.path.join(home_dir, ".cache/huggingface/hub/models--" + formatted_filename) 114 | 115 | if os.path.exists(folder_path): 116 | print("Found existing model folder, verifying download...") 117 | else: 118 | print("Model folder not found, downloading from HuggingFace... 🤗") 119 | 120 | def load_flux_model(self, model_version): 121 | 122 | self.check_model_folder(model_version) 123 | 124 | model = FluxPipeline(model_version=model_version, low_memory_mode=False, w16=True, a16=True) 125 | 126 | clip = { 127 | "model_name": model_version, 128 | "clip_l_model": model.clip_l, 129 | "clip_l_tokenizer": model.tokenizer_l, 130 | "t5_model": model.t5_encoder, 131 | "t5_tokenizer": model.t5_tokenizer 132 | } 133 | 134 | print("Model successfully loaded.") 135 | 136 | 137 | return (model, model.decoder, clip) 138 | 139 | 140 | 141 | 142 | class MLXClipTextEncoder: 143 | 144 | @classmethod 145 | def INPUT_TYPES(s): 146 | return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "mlx_conditioning": ("mlx_conditioning", {"forceInput":True})}} 147 | 148 | 149 | RETURN_TYPES = ("mlx_conditioning",) 150 | FUNCTION = "encode" 151 | 152 | 153 | def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None): 154 | if negative_text is None: 155 | negative_text = "" 156 | if tokenizer.pad_with_eos: 157 | pad_token = tokenizer.eos_token 158 | else: 159 | pad_token = 0 160 | 161 | text = text.replace('’', '\'') 162 | 163 | # Tokenize the text 164 | tokens = [tokenizer.tokenize(text)] 165 | if tokenizer.pad_to_max_length: 166 | tokens[0].extend([pad_token] * (tokenizer.max_length - len(tokens[0]))) 167 | if negative_text is not None: 168 | tokens += [tokenizer.tokenize(negative_text)] 169 | lengths = [len(t) for t in tokens] 170 | N = max(lengths) 171 | tokens = [t + [pad_token] * (N - len(t)) for t in tokens] 172 | tokens = mx.array(tokens) 173 | 174 | return tokens 175 | 176 | def encode(self, mlx_conditioning, text): 177 | 178 | T5_MAX_LENGTH = { 179 | "argmaxinc/mlx-stable-diffusion-3-medium": 512, 180 | "argmaxinc/mlx-FLUX.1-schnell": 256, 181 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256, 182 | "argmaxinc/mlx-FLUX.1-dev": 512, 183 | } 184 | 185 | model_name = mlx_conditioning["model_name"] 186 | clip_l_encoder:CLIPTextModel = mlx_conditioning["clip_l_model"] 187 | clip_l_tokenizer:Tokenizer = mlx_conditioning["clip_l_tokenizer"] 188 | t5_encoder:SD3T5Encoder = mlx_conditioning["t5_model"] 189 | t5_tokenizer:T5Tokenizer = mlx_conditioning["t5_tokenizer"] 190 | 191 | # CLIP processing 192 | clip_tokens = self._tokenize(tokenizer=clip_l_tokenizer, text=text) 193 | 194 | clip_l_embeddings = clip_l_encoder(clip_tokens[[0], :]) 195 | 196 | clip_last_hidden_state = clip_l_embeddings.last_hidden_state 197 | clip_pooled_output = clip_l_embeddings.pooled_output 198 | 199 | # T5 processing 200 | t5_tokens = self._tokenize(tokenizer=t5_tokenizer, text=text) 201 | 202 | padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[model_name])).astype( 203 | t5_tokens.dtype 204 | ) 205 | 206 | padded_tokens_t5[:, : t5_tokens.shape[1]] = t5_tokens[ 207 | [0], : 208 | ] # Ignore negative text 209 | 210 | t5_embeddings = t5_encoder(padded_tokens_t5) 211 | 212 | # Use T5 embeddings as main conditioning 213 | conditioning = t5_embeddings 214 | 215 | output = { 216 | "conditioning": t5_embeddings, 217 | "pooled_conditioning": clip_pooled_output 218 | } 219 | 220 | 221 | 222 | return (output, ) 223 | 224 | # Node class mappings 
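# Note: ComfyUI discovers custom nodes by importing this package's __init__.py and reading
# NODE_CLASS_MAPPINGS (node type name -> node class). NODE_DISPLAY_NAME_MAPPINGS only controls
# the labels shown in the UI, so its keys must match the keys used in NODE_CLASS_MAPPINGS.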
225 | NODE_CLASS_MAPPINGS = { 226 | "MLXClipTextEncoder": MLXClipTextEncoder, 227 | "MLXLoadFlux": MLXLoadFlux, 228 | "MLXSampler": MLXSampler, 229 | "MLXDecoder": MLXDecoder 230 | } 231 | 232 | # Node display name mappings 233 | NODE_DISPLAY_NAME_MAPPINGS = { 234 | "MLXClipTextEncoder": "MLX CLIP Text Encoder", 235 | "MLXLoadFlux": "MLX Load Flux Model from HF 🤗", 236 | "MLXSampler": "MLX Sampler", 237 | "MLXDecoder": "MLX Decoder" 238 | } 239 | -------------------------------------------------------------------------------- /data/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/data/benchmark.png -------------------------------------------------------------------------------- /diffusionkit/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | -------------------------------------------------------------------------------- /diffusionkit/mlx/__init__.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | # 4 | # For licensing see accompanying LICENSE.md file. 5 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 6 | # 7 | 8 | import gc 9 | import math 10 | import time 11 | from pprint import pprint 12 | from typing import Optional, Tuple 13 | 14 | import mlx.core as mx 15 | import mlx.nn as nn 16 | import numpy as np 17 | from argmaxtools.test_utils import AppleSiliconContextMixin, InferenceContextSpec 18 | from argmaxtools.utils import get_logger 19 | from ...diffusionkit.utils import bytes2gigabytes 20 | from PIL import Image 21 | 22 | from .model_io import ( 23 | _DEFAULT_MODEL, 24 | load_flux, 25 | load_mmdit, 26 | load_t5_encoder, 27 | load_t5_tokenizer, 28 | load_text_encoder, 29 | load_tokenizer, 30 | load_vae_decoder, 31 | load_vae_encoder, 32 | ) 33 | from .sampler import FluxSampler, ModelSamplingDiscreteFlow 34 | 35 | logger = get_logger(__name__) 36 | 37 | MMDIT_CKPT = { 38 | "argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium", 39 | "sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased 40 | "argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell", 41 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized", 42 | "argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev", 43 | } 44 | 45 | T5_MAX_LENGTH = { 46 | "argmaxinc/mlx-stable-diffusion-3-medium": 512, 47 | "argmaxinc/mlx-FLUX.1-schnell": 256, 48 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256, 49 | "argmaxinc/mlx-FLUX.1-dev": 512, 50 | } 51 | 52 | 53 | class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec): 54 | def code_spec(self): 55 | return {} 56 | 57 | def model_spec(self): 58 | return {} 59 | 60 | 61 | class DiffusionPipeline: 62 | def __init__( 63 | self, 64 | w16: bool = False, 65 | shift: float = 1.0, 66 | use_t5: bool = True, 67 | model_version: str = "argmaxinc/mlx-stable-diffusion-3-medium", 68 | low_memory_mode: bool = True, 69 | a16: bool = False, 70 | local_ckpt=None, 71 | ): 72 | model_io.LOCAl_SD3_CKPT = local_ckpt 73 | self.float16_dtype = mx.float16 74 | model_io._FLOAT16 = self.float16_dtype 75 | self.dtype = self.float16_dtype if w16 else mx.float32 76 | self.activation_dtype = self.float16_dtype if a16 
else mx.float32 77 | self.use_t5 = use_t5 78 | self.mmdit_ckpt = MMDIT_CKPT[model_version] 79 | self.low_memory_mode = low_memory_mode 80 | self.model = _DEFAULT_MODEL 81 | self.model_version = model_version 82 | self.sampler = ModelSamplingDiscreteFlow(shift=shift) 83 | self.latent_format = SD3LatentFormat() 84 | self.use_clip_g = True 85 | self.check_and_load_models() 86 | 87 | def load_mmdit(self, only_modulation_dict=False): 88 | if only_modulation_dict: 89 | return load_mmdit( 90 | float16=True if self.dtype == self.float16_dtype else False, 91 | key=self.mmdit_ckpt, 92 | model_key=self.model_version, 93 | low_memory_mode=self.low_memory_mode, 94 | only_modulation_dict=only_modulation_dict, 95 | ) 96 | self.mmdit = load_mmdit( 97 | float16=True if self.dtype == self.float16_dtype else False, 98 | key=self.mmdit_ckpt, 99 | model_key=self.model_version, 100 | low_memory_mode=self.low_memory_mode, 101 | only_modulation_dict=only_modulation_dict, 102 | ) 103 | 104 | def check_and_load_models(self): 105 | if not hasattr(self, "mmdit"): 106 | self.load_mmdit() 107 | if not hasattr(self, "decoder"): 108 | self.decoder = load_vae_decoder( 109 | float16=True if self.dtype == self.float16_dtype else False, 110 | key=self.mmdit_ckpt, 111 | ) 112 | if not hasattr(self, "encoder"): 113 | self.encoder = load_vae_encoder(float16=False, key=self.mmdit_ckpt) 114 | 115 | if not hasattr(self, "clip_l"): 116 | self.clip_l = load_text_encoder( 117 | self.model, 118 | float16=True if self.dtype == self.float16_dtype else False, 119 | model_key="clip_l", 120 | ) 121 | self.tokenizer_l = load_tokenizer( 122 | self.model, 123 | merges_key="tokenizer_l_merges", 124 | vocab_key="tokenizer_l_vocab", 125 | pad_with_eos=True, 126 | ) 127 | if self.use_clip_g and not hasattr(self, "clip_g"): 128 | self.clip_g = load_text_encoder( 129 | self.model, 130 | float16=True if self.dtype == self.float16_dtype else False, 131 | model_key="clip_g", 132 | ) 133 | self.tokenizer_g = load_tokenizer( 134 | self.model, 135 | merges_key="tokenizer_g_merges", 136 | vocab_key="tokenizer_g_vocab", 137 | pad_with_eos=False, 138 | ) 139 | if self.use_t5 and not hasattr(self, "t5_encoder"): 140 | self.set_up_t5() 141 | 142 | def set_up_t5(self): 143 | if not hasattr(self, "t5_encoder") or self.t5_encoder is None: 144 | self.t5_encoder = load_t5_encoder( 145 | float16=True if self.dtype == self.float16_dtype else False, 146 | low_memory_mode=self.low_memory_mode, 147 | ) 148 | if not hasattr(self, "t5_tokenizer") or self.t5_tokenizer is None: 149 | self.t5_tokenizer = load_t5_tokenizer( 150 | max_context_length=T5_MAX_LENGTH[self.model_version] 151 | ) 152 | self.use_t5 = True 153 | 154 | def unload_t5(self): 155 | if self.t5_encoder is not None: 156 | del self.t5_encoder 157 | self.t5_encoder = None 158 | if self.t5_tokenizer is not None: 159 | del self.t5_tokenizer 160 | self.t5_tokenizer = None 161 | gc.collect() 162 | self.use_t5 = False 163 | 164 | def ensure_models_are_loaded(self): 165 | mx.eval(self.mmdit.parameters()) 166 | mx.eval(self.clip_l.parameters()) 167 | mx.eval(self.decoder.parameters()) 168 | if hasattr(self, "clip_g"): 169 | mx.eval(self.clip_g.parameters()) 170 | if hasattr(self, "t5_encoder") and self.use_t5: 171 | mx.eval(self.t5_encoder.parameters()) 172 | 173 | def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None): 174 | if negative_text is None: 175 | negative_text = "" 176 | if tokenizer.pad_with_eos: 177 | pad_token = tokenizer.eos_token 178 | else: 179 | pad_token = 0 180 | 181 | # 
Tokenize the text 182 | tokens = [tokenizer.tokenize(text)] 183 | if tokenizer.pad_to_max_length: 184 | tokens[0].extend([pad_token] * (tokenizer.max_length - len(tokens[0]))) 185 | if negative_text is not None: 186 | tokens += [tokenizer.tokenize(negative_text)] 187 | lengths = [len(t) for t in tokens] 188 | N = max(lengths) 189 | tokens = [t + [pad_token] * (N - len(t)) for t in tokens] 190 | tokens = mx.array(tokens) 191 | 192 | return tokens 193 | 194 | def encode_text( 195 | self, 196 | text: str, 197 | cfg_weight: float = 7.5, 198 | negative_text: str = "", 199 | ): 200 | tokens_l = self._tokenize( 201 | self.tokenizer_l, 202 | text, 203 | (negative_text if cfg_weight > 1 else None), 204 | ) 205 | tokens_g = self._tokenize( 206 | self.tokenizer_g, 207 | text, 208 | (negative_text if cfg_weight > 1 else None), 209 | ) 210 | 211 | conditioning_l = self.clip_l(tokens_l) 212 | conditioning_g = self.clip_g(tokens_g) 213 | conditioning = mx.concatenate( 214 | [conditioning_l.hidden_states[-2], conditioning_g.hidden_states[-2]], 215 | axis=-1, 216 | ) 217 | pooled_conditioning = mx.concatenate( 218 | [conditioning_l.pooled_output, conditioning_g.pooled_output], 219 | axis=-1, 220 | ) 221 | 222 | conditioning = mx.concatenate( 223 | [ 224 | conditioning, 225 | mx.zeros( 226 | ( 227 | conditioning.shape[0], 228 | conditioning.shape[1], 229 | 4096 - conditioning.shape[2], 230 | ) 231 | ), 232 | ], 233 | axis=-1, 234 | ) 235 | 236 | if self.use_t5: 237 | tokens_t5 = self._tokenize( 238 | self.t5_tokenizer, 239 | text, 240 | (negative_text if cfg_weight > 1 else None), 241 | ) 242 | t5_conditioning = self.t5_encoder(tokens_t5) 243 | mx.eval(t5_conditioning) 244 | else: 245 | t5_conditioning = mx.zeros_like(conditioning) 246 | conditioning = mx.concatenate([conditioning, t5_conditioning], axis=1) 247 | 248 | return conditioning, pooled_conditioning 249 | 250 | def denoise_latents( 251 | self, 252 | conditioning, 253 | pooled_conditioning, 254 | num_steps: int = 2, 255 | cfg_weight: float = 0.0, 256 | latent_size: Tuple[int] = (64, 64), 257 | seed=None, 258 | image_path: Optional[str] = None, 259 | denoise: float = 1.0, 260 | ): 261 | # Set the PRNG state 262 | seed = int(time.time()) if seed is None else seed 263 | logger.info(f"Seed: {seed}") 264 | mx.random.seed(seed) 265 | 266 | x_T = self.get_empty_latent(*latent_size) 267 | if image_path is None: 268 | denoise = 1.0 269 | else: 270 | x_T = self.encode_image_to_latents(image_path, seed=seed) 271 | x_T = self.latent_format.process_in(x_T) 272 | noise = self.get_noise(seed, x_T) 273 | sigmas = self.get_sigmas(self.sampler, num_steps) 274 | sigmas = sigmas[int(num_steps * (1 - denoise)) :] 275 | extra_args = { 276 | "conditioning": conditioning, 277 | "cfg_weight": cfg_weight, 278 | "pooled_conditioning": pooled_conditioning, 279 | } 280 | noise_scaled = self.sampler.noise_scaling( 281 | sigmas[0], noise, x_T, self.max_denoise(sigmas) 282 | ) 283 | latent, iter_time = sample_euler( 284 | CFGDenoiser(self), noise_scaled, sigmas, extra_args=extra_args 285 | ) 286 | 287 | latent = self.latent_format.process_out(latent) 288 | 289 | return latent, iter_time 290 | 291 | def generate_image( 292 | self, 293 | text: str, 294 | num_steps: int = 2, 295 | cfg_weight: float = 0.0, 296 | negative_text: str = "", 297 | latent_size: Tuple[int] = (64, 64), 298 | seed=None, 299 | verbose: bool = True, 300 | image_path: Optional[str] = None, 301 | denoise: float = 1.0, 302 | ): 303 | # Check latent size is divisible by 2 304 | assert ( 305 | latent_size[0] % 2 == 
0 306 | ), f"Height must be divisible by 16 ({latent_size[0]*8}/16={latent_size[0]/2})" 307 | assert ( 308 | latent_size[1] % 2 == 0 309 | ), f"Width must be divisible by 16 ({latent_size[1]*8}/16={latent_size[1]/2})" 310 | self.check_and_load_models() 311 | # Start timing 312 | start_time = time.time() 313 | 314 | # Initialize the memory log 315 | log = { 316 | "text_encoding": { 317 | "pre": { 318 | "peak_memory": round( 319 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 320 | ), 321 | "active_memory": round( 322 | bytes2gigabytes(mx.metal.get_active_memory()), 3 323 | ), 324 | }, 325 | "post": {"peak_memory": None, "active_memory": None}, 326 | }, 327 | "denoising": { 328 | "pre": {"peak_memory": None, "active_memory": None}, 329 | "post": {"peak_memory": None, "active_memory": None}, 330 | }, 331 | "decoding": { 332 | "pre": {"peak_memory": None, "active_memory": None}, 333 | "post": {"peak_memory": None, "active_memory": None}, 334 | }, 335 | "peak_memory": 0.0, 336 | } 337 | 338 | # Get the text conditioning 339 | text_encoding_start_time = time.time() 340 | if verbose: 341 | logger.info( 342 | f"Pre text encoding peak memory: {log['text_encoding']['pre']['peak_memory']}GB" 343 | ) 344 | logger.info( 345 | f"Pre text encoding active memory: {log['text_encoding']['pre']['active_memory']}GB" 346 | ) 347 | 348 | # FIXME(arda): Need the same for CLIP models (low memory mode will not succeed a second time otherwise) 349 | if not hasattr(self, "t5"): 350 | self.set_up_t5() 351 | 352 | conditioning, pooled_conditioning = self.encode_text( 353 | text, cfg_weight, negative_text 354 | ) 355 | mx.eval(conditioning) 356 | mx.eval(pooled_conditioning) 357 | log["text_encoding"]["post"]["peak_memory"] = round( 358 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 359 | ) 360 | log["text_encoding"]["post"]["active_memory"] = round( 361 | bytes2gigabytes(mx.metal.get_active_memory()), 3 362 | ) 363 | log["peak_memory"] = max( 364 | log["peak_memory"], log["text_encoding"]["post"]["peak_memory"] 365 | ) 366 | log["text_encoding"]["time"] = round(time.time() - text_encoding_start_time, 3) 367 | if verbose: 368 | logger.info( 369 | f"Post text encoding peak memory: {log['text_encoding']['post']['peak_memory']}GB" 370 | ) 371 | logger.info( 372 | f"Post text encoding active memory: {log['text_encoding']['post']['active_memory']}GB" 373 | ) 374 | logger.info(f"Text encoding time: {log['text_encoding']['time']}s") 375 | 376 | # unload T5 and CLIP models after obtaining conditioning in low memory mode 377 | if self.low_memory_mode: 378 | if hasattr(self, "t5_encoder"): 379 | del self.t5_encoder 380 | if hasattr(self, "clip_g"): 381 | del self.clip_g 382 | del self.clip_l 383 | gc.collect() 384 | 385 | logger.debug(f"Conditioning dtype before casting: {conditioning.dtype}") 386 | logger.debug( 387 | f"Pooled Conditioning dtype before casting: {pooled_conditioning.dtype}" 388 | ) 389 | conditioning = conditioning.astype(self.activation_dtype) 390 | pooled_conditioning = pooled_conditioning.astype(self.activation_dtype) 391 | logger.debug(f"Conditioning dtype after casting: {conditioning.dtype}") 392 | logger.debug( 393 | f"Pooled Conditioning dtype after casting: {pooled_conditioning.dtype}" 394 | ) 395 | 396 | # Reset peak memory info 397 | mx.metal.reset_peak_memory() 398 | 399 | # Generate the latents 400 | denoising_start_time = time.time() 401 | log["denoising"]["pre"]["peak_memory"] = round( 402 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 403 | ) 404 | log["denoising"]["pre"]["active_memory"] = 
round( 405 | bytes2gigabytes(mx.metal.get_active_memory()), 3 406 | ) 407 | log["peak_memory"] = max( 408 | log["peak_memory"], log["denoising"]["pre"]["peak_memory"] 409 | ) 410 | if verbose: 411 | logger.info( 412 | f"Pre denoise peak memory: {log['denoising']['pre']['peak_memory']}GB" 413 | ) 414 | logger.info( 415 | f"Pre denoise active memory: {log['denoising']['pre']['active_memory']}GB" 416 | ) 417 | 418 | latents, iter_time = self.denoise_latents( 419 | conditioning, 420 | pooled_conditioning, 421 | num_steps=num_steps, 422 | cfg_weight=cfg_weight, 423 | latent_size=latent_size, 424 | seed=seed, 425 | image_path=image_path, 426 | denoise=denoise, 427 | ) 428 | mx.eval(latents) 429 | 430 | log["denoising"]["post"]["peak_memory"] = round( 431 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 432 | ) 433 | log["denoising"]["post"]["active_memory"] = round( 434 | bytes2gigabytes(mx.metal.get_active_memory()), 3 435 | ) 436 | log["peak_memory"] = max( 437 | log["peak_memory"], log["denoising"]["post"]["peak_memory"] 438 | ) 439 | log["denoising"]["time"] = round(time.time() - denoising_start_time, 3) 440 | log["denoising"]["iter_time"] = iter_time 441 | if verbose: 442 | logger.info( 443 | f"Post denoise peak memory: {log['denoising']['post']['peak_memory']}GB" 444 | ) 445 | logger.info( 446 | f"Post denoise active memory: {log['denoising']['post']['active_memory']}GB" 447 | ) 448 | logger.info(f"Denoising time: {log['denoising']['time']}s") 449 | 450 | # unload MMDIT model after obtaining latents in low memory mode 451 | if self.low_memory_mode: 452 | del self.mmdit 453 | gc.collect() 454 | 455 | logger.debug(f"Latents dtype before casting: {latents.dtype}") 456 | latents = latents.astype(self.activation_dtype) 457 | logger.debug(f"Latents dtype after casting: {latents.dtype}") 458 | 459 | # Reset peak memory info 460 | mx.metal.reset_peak_memory() 461 | 462 | # Decode the latents 463 | decoding_start_time = time.time() 464 | log["decoding"]["pre"]["peak_memory"] = round( 465 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 466 | ) 467 | log["decoding"]["pre"]["active_memory"] = round( 468 | bytes2gigabytes(mx.metal.get_active_memory()), 3 469 | ) 470 | log["peak_memory"] = max( 471 | log["peak_memory"], log["decoding"]["pre"]["peak_memory"] 472 | ) 473 | if verbose: 474 | logger.info( 475 | f"Pre decode peak memory: {log['decoding']['pre']['peak_memory']}GB" 476 | ) 477 | logger.info( 478 | f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB" 479 | ) 480 | latents = latents.astype(self.activation_dtype) 481 | decoded = self.decode_latents_to_image(latents) 482 | mx.eval(decoded) 483 | 484 | log["decoding"]["post"]["peak_memory"] = round( 485 | bytes2gigabytes(mx.metal.get_peak_memory()), 3 486 | ) 487 | log["decoding"]["post"]["active_memory"] = round( 488 | bytes2gigabytes(mx.metal.get_active_memory()), 3 489 | ) 490 | log["peak_memory"] = max( 491 | log["peak_memory"], log["decoding"]["post"]["peak_memory"] 492 | ) 493 | log["decoding"]["time"] = round(time.time() - decoding_start_time, 3) 494 | if verbose: 495 | logger.info( 496 | f"Post decode peak memory: {log['decoding']['post']['peak_memory']}GB" 497 | ) 498 | logger.info( 499 | f"Post decode active memory: {log['decoding']['post']['active_memory']}GB" 500 | ) 501 | 502 | if verbose: 503 | logger.info("============= Summary =============") 504 | logger.info(f"Text encoder: {log['text_encoding']['time']:.1f}s") 505 | logger.info(f"Denoising: {log['denoising']['time']:.1f}s") 506 | logger.info(f"Image decoder: 
{log['decoding']['time']:.1f}s") 507 | logger.info(f"Peak memory: {log['peak_memory']:.1f}GB") 508 | 509 | logger.info("============= Inference Context =============") 510 | ic = DiffusionKitInferenceContext() 511 | logger.info("Operating System:") 512 | pprint(ic.os_spec()) 513 | logger.info("Device:") 514 | pprint(ic.device_spec()) 515 | 516 | # unload VAE Decoder model after decoding in low memory mode 517 | if self.low_memory_mode: 518 | del self.decoder 519 | gc.collect() 520 | 521 | # Convert the decoded images to uint8 522 | x = mx.concatenate(decoded, axis=0) 523 | x = (x * 255).astype(mx.uint8) 524 | 525 | # End timing 526 | end_time = time.time() 527 | log["total_time"] = round(end_time - start_time, 3) 528 | if verbose: 529 | logger.info(f"Total time: {log['total_time']}s") 530 | 531 | return Image.fromarray(np.array(x)), log 532 | 533 | def read_image(self, image_path: str): 534 | # Read the image 535 | img = Image.open(image_path) 536 | 537 | # Make sure image shape is divisible by 64 538 | W, H = (dim - dim % 64 for dim in (img.width, img.height)) 539 | if W != img.width or H != img.height: 540 | logger.warning( 541 | f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}" 542 | ) 543 | img = img.resize((W, H), Image.LANCZOS) # use desired downsampling filter 544 | 545 | img = mx.array(np.array(img)) 546 | img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1.0 547 | 548 | return mx.expand_dims(img, axis=0) 549 | 550 | def get_noise(self, seed, x_T): 551 | np.random.seed(seed) 552 | noise = np.random.randn(*x_T.transpose(0, 3, 1, 2).shape) 553 | noise = mx.array(noise).transpose(0, 2, 3, 1) 554 | return noise 555 | 556 | def get_sigmas(self, sampler, num_steps: int): 557 | start = sampler.timestep(sampler.sigma_max).item() 558 | end = sampler.timestep(sampler.sigma_min).item() 559 | if isinstance(sampler, FluxSampler): 560 | num_steps += 1 561 | timesteps = mx.linspace(start, end, num_steps) 562 | sigs = [] 563 | for x in range(len(timesteps)): 564 | ts = timesteps[x] 565 | sigs.append(sampler.sigma(ts)) 566 | if not isinstance(sampler, FluxSampler): 567 | sigs += [0.0] 568 | return mx.array(sigs) 569 | 570 | def get_empty_latent(self, *shape): 571 | return mx.ones([1, *shape, 16]) * 0.0609 572 | 573 | def max_denoise(self, sigmas): 574 | max_sigma = float(self.sampler.sigma_max.item()) 575 | sigma = float(sigmas[0].item()) 576 | return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma 577 | 578 | def decode_latents_to_image(self, x_t): 579 | x = self.decoder(x_t) 580 | x = mx.clip(x / 2 + 0.5, 0, 1) 581 | return x 582 | 583 | def encode_image_to_latents(self, image_path: str, seed): 584 | image = self.read_image(image_path) 585 | hidden = self.encoder(image) 586 | mean, logvar = hidden.split(2, axis=-1) 587 | logvar = mx.clip(logvar, -30.0, 20.0) 588 | std = mx.exp(0.5 * logvar) 589 | noise = self.get_noise(seed, mean) 590 | 591 | return mean + std * noise 592 | 593 | 594 | class FluxPipeline(DiffusionPipeline): 595 | def __init__( 596 | self, 597 | w16: bool = False, 598 | shift: float = 1.0, 599 | use_t5: bool = True, 600 | model_version: str = "argmaxinc/mlx-FLUX.1-schnell", 601 | low_memory_mode: bool = True, 602 | a16: bool = False, 603 | local_ckpt=None, 604 | quantize_mmdit: bool = False, 605 | ): 606 | model_io.LOCAl_SD3_CKPT = local_ckpt 607 | self.float16_dtype = mx.bfloat16 608 | model_io._FLOAT16 = self.float16_dtype 609 | self.dtype = self.float16_dtype if w16 else mx.float32 610 | self.activation_dtype = self.float16_dtype if 
a16 else mx.float32 611 | self.mmdit_ckpt = MMDIT_CKPT[model_version] 612 | self.low_memory_mode = low_memory_mode 613 | self.model = _DEFAULT_MODEL 614 | self.model_version = model_version 615 | self.sampler = FluxSampler(shift=shift) 616 | self.latent_format = FluxLatentFormat() 617 | self.use_t5 = True 618 | self.use_clip_g = False 619 | self.quantize_mmdit = quantize_mmdit 620 | self.check_and_load_models() 621 | 622 | def load_mmdit(self, only_modulation_dict=False): 623 | if only_modulation_dict: 624 | return load_flux( 625 | key=self.mmdit_ckpt, 626 | model_key=self.model_version, 627 | float16=True if self.dtype == self.float16_dtype else False, 628 | low_memory_mode=self.low_memory_mode, 629 | only_modulation_dict=only_modulation_dict, 630 | ) 631 | self.mmdit = load_flux( 632 | key=self.mmdit_ckpt, 633 | model_key=self.model_version, 634 | float16=True if self.dtype == self.float16_dtype else False, 635 | low_memory_mode=self.low_memory_mode, 636 | only_modulation_dict=only_modulation_dict, 637 | ) 638 | 639 | def encode_text( 640 | self, 641 | text: str, 642 | cfg_weight: float = 7.5, 643 | negative_text: str = "", 644 | ): 645 | tokens_l = self._tokenize( 646 | self.tokenizer_l, 647 | text, 648 | (negative_text if cfg_weight > 1 else None), 649 | ) 650 | conditioning_l = self.clip_l(tokens_l[[0], :]) # Ignore negative text 651 | pooled_conditioning = conditioning_l.pooled_output 652 | 653 | tokens_t5 = self._tokenize( 654 | self.t5_tokenizer, 655 | text, 656 | (negative_text if cfg_weight > 1 else None), 657 | ) 658 | padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[self.model_version])).astype( 659 | tokens_t5.dtype 660 | ) 661 | padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[ 662 | [0], : 663 | ] # Ignore negative text 664 | t5_conditioning = self.t5_encoder(padded_tokens_t5) 665 | mx.eval(t5_conditioning) 666 | conditioning = t5_conditioning 667 | 668 | return conditioning, pooled_conditioning 669 | 670 | 671 | class CFGDenoiser(nn.Module): 672 | """Helper for applying CFG Scaling to diffusion outputs""" 673 | 674 | def __init__(self, model: DiffusionPipeline): 675 | super().__init__() 676 | self.model = model 677 | 678 | def cache_modulation_params(self, pooled_text_embeddings, sigmas): 679 | self.model.mmdit.cache_modulation_params( 680 | pooled_text_embeddings, sigmas.astype(self.model.activation_dtype) 681 | ) 682 | 683 | def clear_cache(self): 684 | self.model.mmdit.load_weights( 685 | self.model.load_mmdit(only_modulation_dict=True), strict=False 686 | ) 687 | 688 | def __call__( 689 | self, 690 | x_t, 691 | timestep, 692 | sigma, 693 | conditioning, 694 | cfg_weight: float = 7.5, 695 | pooled_conditioning=None, 696 | ): 697 | if cfg_weight <= 0: 698 | logger.debug("CFG Weight disabled") 699 | x_t_mmdit = x_t.astype(self.model.activation_dtype) 700 | else: 701 | x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype( 702 | self.model.activation_dtype 703 | ) 704 | mmdit_input = { 705 | "latent_image_embeddings": x_t_mmdit, 706 | "token_level_text_embeddings": mx.expand_dims(conditioning, 2), 707 | "timestep": mx.broadcast_to(timestep, [len(x_t_mmdit)]), 708 | } 709 | 710 | mmdit_output = self.model.mmdit(**mmdit_input) 711 | eps_pred = self.model.sampler.calculate_denoised(sigma, mmdit_output, x_t_mmdit) 712 | if cfg_weight <= 0: 713 | return eps_pred 714 | else: 715 | eps_text, eps_neg = eps_pred.split(2) 716 | return eps_neg + cfg_weight * (eps_text - eps_neg) 717 | 718 | 719 | class LatentFormat: 720 | """Base class for latent format conversion""" 721 | 722 | def 
__init__(self): 723 | self.scale_factor = 1.0 724 | self.shift_factor = 0.0 725 | 726 | def process_in(self, latent): 727 | return (latent - self.shift_factor) * self.scale_factor 728 | 729 | def process_out(self, latent): 730 | return (latent / self.scale_factor) + self.shift_factor 731 | 732 | 733 | class SD3LatentFormat(LatentFormat): 734 | def __init__(self): 735 | super().__init__() 736 | self.scale_factor = 1.5305 737 | self.shift_factor = 0.0609 738 | 739 | 740 | class FluxLatentFormat(LatentFormat): 741 | def __init__(self): 742 | super().__init__() 743 | self.scale_factor = 0.3611 744 | self.shift_factor = 0.1159 745 | 746 | 747 | def append_dims(x, target_dims): 748 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 749 | dims_to_append = target_dims - x.ndim 750 | return x[(...,) + (None,) * dims_to_append] 751 | 752 | 753 | def to_d(x, sigma, denoised): 754 | """Converts a denoiser output to a Karras ODE derivative.""" 755 | return (x - denoised) / append_dims(sigma, x.ndim) 756 | 757 | 758 | def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None): 759 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" 760 | extra_args = {} if extra_args is None else extra_args 761 | 762 | from tqdm import trange 763 | 764 | t = trange(len(sigmas) - 1) 765 | 766 | timesteps = model.model.sampler.timestep(sigmas).astype( 767 | model.model.activation_dtype 768 | ) 769 | model.cache_modulation_params(extra_args.pop("pooled_conditioning"), timesteps) 770 | 771 | iter_time = [] 772 | for i in t: 773 | start_time = t.format_dict["elapsed"] 774 | denoised = model(x, timesteps[i], sigmas[i], **extra_args) 775 | d = to_d(x, sigmas[i], denoised) 776 | dt = sigmas[i + 1] - sigmas[i] 777 | # Euler method 778 | x = x + d * dt 779 | mx.eval(x) 780 | end_time = t.format_dict["elapsed"] 781 | iter_time.append(round((end_time - start_time), 3)) 782 | 783 | model.clear_cache() 784 | 785 | return x, iter_time 786 | -------------------------------------------------------------------------------- /diffusionkit/mlx/clip.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .config import CLIPTextModelConfig 10 | 11 | _ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} 12 | 13 | 14 | @dataclass 15 | class CLIPOutput: 16 | # The last_hidden_state indexed at the EOS token and possibly projected if 17 | # the model has a projection layer 18 | pooled_output: Optional[mx.array] = None 19 | 20 | # The full sequence output of the transformer after the final layernorm 21 | last_hidden_state: Optional[mx.array] = None 22 | 23 | # A list of hidden states corresponding to the outputs of the transformer layers 24 | hidden_states: Optional[List[mx.array]] = None 25 | 26 | 27 | class CLIPEncoderLayer(nn.Module): 28 | """The transformer encoder layer from CLIP.""" 29 | 30 | def __init__(self, model_dims: int, num_heads: int, activation: str): 31 | super().__init__() 32 | 33 | self.layer_norm1 = nn.LayerNorm(model_dims) 34 | self.layer_norm2 = nn.LayerNorm(model_dims) 35 | 36 | self.attention = nn.MultiHeadAttention(model_dims, num_heads) 37 | # Add biases to the attention projections to match CLIP 38 | self.attention.query_proj.bias = mx.zeros(model_dims) 39 | self.attention.key_proj.bias = 
mx.zeros(model_dims) 40 | self.attention.value_proj.bias = mx.zeros(model_dims) 41 | self.attention.out_proj.bias = mx.zeros(model_dims) 42 | 43 | self.linear1 = nn.Linear(model_dims, 4 * model_dims) 44 | self.linear2 = nn.Linear(4 * model_dims, model_dims) 45 | 46 | self.act = _ACTIVATIONS[activation] 47 | 48 | def __call__(self, x, attn_mask=None): 49 | y = self.layer_norm1(x) 50 | y = self.attention(y, y, y, attn_mask) 51 | x = y + x 52 | 53 | y = self.layer_norm2(x) 54 | y = self.linear1(y) 55 | y = self.act(y) 56 | y = self.linear2(y) 57 | x = y + x 58 | 59 | return x 60 | 61 | 62 | class CLIPTextModel(nn.Module): 63 | """Implements the text encoder transformer from CLIP.""" 64 | 65 | def __init__(self, config: CLIPTextModelConfig): 66 | super().__init__() 67 | 68 | self.max_length = config.max_length 69 | 70 | self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) 71 | self.position_embedding = nn.Embedding(config.max_length, config.model_dims) 72 | self.layers = [ 73 | CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act) 74 | for i in range(config.num_layers) 75 | ] 76 | self.final_layer_norm = nn.LayerNorm(config.model_dims) 77 | 78 | if config.projection_dim is not None: 79 | self.text_projection = nn.Linear( 80 | config.model_dims, config.projection_dim, bias=False 81 | ) 82 | 83 | def _get_mask(self, N, dtype): 84 | indices = mx.arange(N) 85 | mask = indices[:, None] < indices[None] 86 | mask = mask.astype(dtype) * ( 87 | -6e4 if (dtype == mx.bfloat16 or dtype == mx.float16) else -1e9 88 | ) 89 | return mask 90 | 91 | def __call__(self, x): 92 | # Extract some shapes 93 | B, N = x.shape 94 | eos_tokens = x.argmax(-1) 95 | 96 | # Compute the embeddings 97 | x = self.token_embedding(x) 98 | x = x + self.position_embedding.weight[:N] 99 | 100 | # Compute the features from the transformer 101 | mask = self._get_mask(N, x.dtype) 102 | hidden_states = [] 103 | for l in self.layers: 104 | x = l(x, mask) 105 | hidden_states.append(x) 106 | 107 | # Apply the final layernorm and return 108 | x = self.final_layer_norm(x) 109 | last_hidden_state = x 110 | 111 | # Select the EOS token 112 | pooled_output = x[mx.arange(len(x)), eos_tokens] 113 | if "text_projection" in self: 114 | pooled_output = self.text_projection(pooled_output) 115 | 116 | return CLIPOutput( 117 | pooled_output=pooled_output, 118 | last_hidden_state=last_hidden_state, 119 | hidden_states=hidden_states, 120 | ) 121 | -------------------------------------------------------------------------------- /diffusionkit/mlx/config.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | # 4 | # For licensing see accompanying LICENSE.md file. 5 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 6 | # 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from typing import List, Optional, Tuple 10 | 11 | import mlx.core as mx 12 | 13 | 14 | class PositionalEncoding(Enum): 15 | LearnedInputEmbedding = 1 16 | PreSDPARope = 2 17 | 18 | 19 | @dataclass 20 | class MMDiTConfig: 21 | """Multi-modal Diffusion Transformer Configuration""" 22 | 23 | # Transformer spec 24 | num_heads: int = 24 25 | depth_multimodal: int = 24 # e.g. SD3: 24 (2b) or 38 (8b), FLUX.1: 19 26 | depth_unified: int = 0 # e.g. SD3: 0 (2b and 8b), FLUX.1: 38 27 | parallel_mlp_for_unified_blocks: bool = ( 28 | True # e.g. 
FLUX.1 unified blocks, https://arxiv.org/pdf/2302.05442 29 | ) 30 | mlp_ratio: int = 4 31 | vae_latent_dim: int = 16 # = in_channels = out_channels 32 | layer_norm_eps: float = 1e-6 33 | pos_embed_type: PositionalEncoding = PositionalEncoding.LearnedInputEmbedding 34 | rope_axes_dim: Optional[Tuple[int]] = None 35 | # FLUX uses RMSNorm post-QK projection 36 | use_qk_norm: bool = False 37 | upcast_multimodal_blocks: Optional[List[int]] = None 38 | upcast_unified_blocks: Optional[List[int]] = None 39 | 40 | # 64 * self.depth_multimodal is the SD3 convention, but can be overridden 41 | hidden_size_override: Optional[int] = None 42 | 43 | @property 44 | def hidden_size(self) -> int: 45 | return self.hidden_size_override or (64 * self.depth_multimodal) 46 | 47 | # x: Latent image input spec 48 | max_latent_resolution: int = 192 49 | patch_size: int = 2 50 | # If true, reshapes input to enact (patch_size, patch_size) space-to-depth operation 51 | # If false, uses 2D convolution with kernel_size=patch_size and stride=patch_size 52 | patchify_via_reshape: bool = False 53 | 54 | # y: Text input spec 55 | # e.g. SD3: 768 (CLIP-L/14) + 1280 (CLIP-G/14) = 2048 56 | # FLUX.1: 768 (CLIP-L/14) = 768 57 | pooled_text_embed_dim: int = 2048 58 | # e.g. SD3: 4096 (T5-XXL) = 768 (CLIP-L/14) + 1280 (CLIP-G/14) + 2048 (zero padding) 59 | # FLUX: 4096 (T5-XXL) 60 | token_level_text_embed_dim: int = 4096 61 | 62 | # t: Timestep input spec 63 | frequency_embed_dim: int = 256 64 | max_period: int = 10000 65 | 66 | dtype: mx.Dtype = mx.bfloat16 67 | float16_dtype: mx.Dtype = mx.bfloat16 68 | 69 | low_memory_mode: bool = True 70 | 71 | guidance_embed: bool = False 72 | 73 | 74 | SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35]) 75 | 76 | SD3_2b = MMDiTConfig( 77 | depth_multimodal=24, num_heads=24, float16_dtype=mx.float16, dtype=mx.float16 78 | ) 79 | 80 | FLUX_SCHNELL = MMDiTConfig( 81 | num_heads=24, 82 | depth_multimodal=19, 83 | depth_unified=38, 84 | parallel_mlp_for_unified_blocks=True, 85 | hidden_size_override=3072, 86 | patchify_via_reshape=True, 87 | pos_embed_type=PositionalEncoding.PreSDPARope, 88 | rope_axes_dim=(16, 56, 56), 89 | pooled_text_embed_dim=768, # CLIP-L/14 only 90 | use_qk_norm=True, 91 | float16_dtype=mx.bfloat16, 92 | dtype=mx.bfloat16, 93 | ) 94 | 95 | FLUX_DEV = MMDiTConfig( 96 | num_heads=24, 97 | depth_multimodal=19, 98 | depth_unified=38, 99 | parallel_mlp_for_unified_blocks=True, 100 | hidden_size_override=3072, 101 | patchify_via_reshape=True, 102 | pos_embed_type=PositionalEncoding.PreSDPARope, 103 | rope_axes_dim=(16, 56, 56), 104 | pooled_text_embed_dim=768, # CLIP-L/14 only 105 | use_qk_norm=True, 106 | float16_dtype=mx.bfloat16, 107 | guidance_embed=True, 108 | dtype=mx.bfloat16, 109 | ) 110 | 111 | 112 | @dataclass 113 | class AutoencoderConfig: 114 | in_channels: int = 3 115 | out_channels: int = 3 116 | latent_channels_out: int = 8 117 | latent_channels_in: int = 4 118 | block_out_channels: Tuple[int] = (128, 256, 512, 512) 119 | layers_per_block: int = 2 120 | norm_num_groups: int = 32 121 | scaling_factor: float = 0.18215 122 | 123 | 124 | @dataclass 125 | class VAEDecoderConfig: 126 | in_channels: int = 16 127 | out_channels: int = 3 128 | block_out_channels: Tuple[int] = (128, 256, 512, 512) 129 | layers_per_block: int = 3 130 | resnet_groups: int = 32 131 | 132 | 133 | @dataclass 134 | class VAEEncoderConfig: 135 | in_channels: int = 3 136 | out_channels: int = 32 137 | block_out_channels: Tuple[int] = (128, 256, 512, 512) 138 | 
layers_per_block: int = 2 139 | resnet_groups: int = 32 140 | 141 | 142 | @dataclass 143 | class CLIPTextModelConfig: 144 | num_layers: int = 23 145 | model_dims: int = 1024 146 | num_heads: int = 16 147 | max_length: int = 77 148 | vocab_size: int = 49408 149 | projection_dim: Optional[int] = None 150 | hidden_act: str = "quick_gelu" 151 | -------------------------------------------------------------------------------- /diffusionkit/mlx/mmdit.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 4 | # 5 | 6 | from functools import partial 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from argmaxtools.utils import get_logger 12 | from beartype.typing import Dict, List, Optional, Tuple 13 | from mlx.utils import tree_map 14 | 15 | from .config import MMDiTConfig, PositionalEncoding 16 | 17 | logger = get_logger(__name__) 18 | 19 | SDPA_FLASH_ATTN_THRESHOLD = 1024 20 | 21 | 22 | class MMDiT(nn.Module): 23 | """Multi-modal Diffusion Transformer Architecture 24 | as described in https://arxiv.org/abs/2403.03206 25 | """ 26 | 27 | def __init__(self, config: MMDiTConfig): 28 | super().__init__() 29 | self.config = config 30 | 31 | if config.guidance_embed: 32 | self.guidance_in = MLPEmbedder( 33 | in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size 34 | ) 35 | else: 36 | self.guidance_in = nn.Identity() 37 | 38 | # Input adapters and embeddings 39 | self.x_embedder = LatentImageAdapter(config) 40 | 41 | if config.pos_embed_type == PositionalEncoding.LearnedInputEmbedding: 42 | self.x_pos_embedder = LatentImagePositionalEmbedding(config) 43 | self.pre_sdpa_rope = nn.Identity() 44 | elif config.pos_embed_type == PositionalEncoding.PreSDPARope: 45 | self.pre_sdpa_rope = RoPE( 46 | theta=10000, 47 | axes_dim=config.rope_axes_dim, 48 | ) 49 | else: 50 | raise ValueError( 51 | f"Unsupported positional encoding type: {config.pos_embed_type}" 52 | ) 53 | 54 | self.y_embedder = PooledTextEmbeddingAdapter(config) 55 | self.t_embedder = TimestepAdapter(config) 56 | self.context_embedder = nn.Linear( 57 | config.token_level_text_embed_dim, 58 | config.hidden_size, 59 | ) 60 | 61 | self.multimodal_transformer_blocks = [ 62 | MultiModalTransformerBlock( 63 | config, 64 | skip_text_post_sdpa=(i == config.depth_multimodal - 1) 65 | and (config.depth_unified < 1), 66 | ) 67 | for i in range(config.depth_multimodal) 68 | ] 69 | 70 | if config.depth_unified > 0: 71 | self.unified_transformer_blocks = [ 72 | UnifiedTransformerBlock(config) for _ in range(config.depth_unified) 73 | ] 74 | 75 | self.final_layer = FinalLayer(config) 76 | 77 | def cache_modulation_params( 78 | self, 79 | pooled_text_embeddings: mx.array, 80 | timesteps: mx.array, 81 | ): 82 | """Compute modulation parameters ahead of time to reduce peak memory load during MMDiT inference 83 | by offloading all adaLN_modulation parameters 84 | """ 85 | y_embed = self.y_embedder(pooled_text_embeddings) 86 | batch_size = pooled_text_embeddings.shape[0] 87 | 88 | offload_size = 0 89 | to_offload = [] 90 | 91 | for timestep in timesteps: 92 | final_timestep = timestep.item() == timesteps[-1].item() 93 | timestep_key = timestep.item() 94 | modulation_inputs = y_embed[:, None, None, :] + self.t_embedder( 95 | mx.repeat(timestep[None], batch_size, axis=0) 96 | ) 97 | 98 | for block in self.multimodal_transformer_blocks: 99 | if not 
hasattr(block.image_transformer_block, "_modulation_params"): 100 | block.image_transformer_block._modulation_params = dict() 101 | block.text_transformer_block._modulation_params = dict() 102 | 103 | block.image_transformer_block._modulation_params[ 104 | timestep_key 105 | ] = block.image_transformer_block.adaLN_modulation(modulation_inputs) 106 | block.text_transformer_block._modulation_params[ 107 | timestep_key 108 | ] = block.text_transformer_block.adaLN_modulation(modulation_inputs) 109 | mx.eval(block.image_transformer_block._modulation_params[timestep_key]) 110 | mx.eval(block.text_transformer_block._modulation_params[timestep_key]) 111 | 112 | if final_timestep: 113 | offload_size += ( 114 | block.image_transformer_block.adaLN_modulation.layers[ 115 | 1 116 | ].weight.size 117 | * block.image_transformer_block.adaLN_modulation.layers[ 118 | 1 119 | ].weight.dtype.size 120 | ) 121 | offload_size += ( 122 | block.text_transformer_block.adaLN_modulation.layers[ 123 | 1 124 | ].weight.size 125 | * block.text_transformer_block.adaLN_modulation.layers[ 126 | 1 127 | ].weight.dtype.size 128 | ) 129 | to_offload.extend( 130 | [ 131 | block.image_transformer_block.adaLN_modulation.layers[1], 132 | block.text_transformer_block.adaLN_modulation.layers[1], 133 | ] 134 | ) 135 | 136 | if self.config.depth_unified > 0: 137 | for block in self.unified_transformer_blocks: 138 | if not hasattr(block.transformer_block, "_modulation_params"): 139 | block.transformer_block._modulation_params = dict() 140 | block.transformer_block._modulation_params[ 141 | timestep_key 142 | ] = block.transformer_block.adaLN_modulation(modulation_inputs) 143 | mx.eval(block.transformer_block._modulation_params[timestep_key]) 144 | 145 | if final_timestep: 146 | offload_size += ( 147 | block.transformer_block.adaLN_modulation.layers[ 148 | 1 149 | ].weight.size 150 | * block.transformer_block.adaLN_modulation.layers[ 151 | 1 152 | ].weight.dtype.size 153 | ) 154 | to_offload.extend( 155 | [block.transformer_block.adaLN_modulation.layers[1]] 156 | ) 157 | 158 | if not hasattr(self.final_layer, "_modulation_params"): 159 | self.final_layer._modulation_params = dict() 160 | self.final_layer._modulation_params[ 161 | timestep_key 162 | ] = self.final_layer.adaLN_modulation(modulation_inputs) 163 | mx.eval(self.final_layer._modulation_params[timestep_key]) 164 | 165 | if final_timestep: 166 | offload_size += ( 167 | self.final_layer.adaLN_modulation.layers[1].weight.size 168 | * self.final_layer.adaLN_modulation.layers[1].weight.dtype.size 169 | ) 170 | to_offload.extend([self.final_layer.adaLN_modulation.layers[1]]) 171 | 172 | self.to_offload = to_offload 173 | for x in self.to_offload: 174 | x.update(tree_map(lambda _: mx.array([]), x.parameters())) 175 | # x.clear() 176 | 177 | logger.info(f"Cached modulation_params for timesteps={timesteps}") 178 | logger.info( 179 | f"Cached modulation_params will reduce peak memory by {(offload_size) / 1e9:.1f} GB" 180 | ) 181 | 182 | def clear_modulation_params_cache(self): 183 | for name, module in self.named_modules(): 184 | if hasattr(module, "_modulation_params"): 185 | delattr(module, "_modulation_params") 186 | logger.info("Cleared modulation_params cache") 187 | 188 | def __call__( 189 | self, 190 | latent_image_embeddings: mx.array, 191 | token_level_text_embeddings: mx.array, 192 | timestep: mx.array, 193 | ) -> mx.array: 194 | batch, latent_height, latent_width, _ = latent_image_embeddings.shape 195 | token_level_text_embeddings = 
self.context_embedder(token_level_text_embeddings) 196 | 197 | if hasattr(self, "x_pos_embedder"): 198 | latent_image_embeddings = self.x_embedder( 199 | latent_image_embeddings 200 | ) + self.x_pos_embedder(latent_image_embeddings) 201 | else: 202 | latent_image_embeddings = self.x_embedder(latent_image_embeddings) 203 | 204 | latent_image_embeddings = latent_image_embeddings.reshape( 205 | batch, -1, 1, self.config.hidden_size 206 | ) 207 | 208 | if self.config.pos_embed_type == PositionalEncoding.PreSDPARope: 209 | positional_encodings = self.pre_sdpa_rope( 210 | text_sequence_length=token_level_text_embeddings.shape[1], 211 | latent_image_resolution=( 212 | latent_height // self.config.patch_size, 213 | latent_width // self.config.patch_size, 214 | ), 215 | ) 216 | else: 217 | positional_encodings = None 218 | 219 | if self.config.guidance_embed: 220 | timestep = self.guidance_in(self.t_embedder(timestep)) 221 | 222 | # MultiModalTransformer layers 223 | if self.config.depth_multimodal > 0: 224 | for bidx, block in enumerate(self.multimodal_transformer_blocks): 225 | latent_image_embeddings, token_level_text_embeddings = block( 226 | latent_image_embeddings, 227 | token_level_text_embeddings, 228 | timestep, 229 | positional_encodings=positional_encodings, 230 | ) 231 | 232 | # UnifiedTransformerBlock layers 233 | if self.config.depth_unified > 0: 234 | latent_unified_embeddings = mx.concatenate( 235 | (token_level_text_embeddings, latent_image_embeddings), axis=1 236 | ) 237 | 238 | for bidx, block in enumerate(self.unified_transformer_blocks): 239 | latent_unified_embeddings = block( 240 | latent_unified_embeddings, 241 | timestep, 242 | positional_encodings=positional_encodings, 243 | ) 244 | 245 | latent_image_embeddings = latent_unified_embeddings[ 246 | :, token_level_text_embeddings.shape[1] :, ... 
247 | ] 248 | 249 | latent_image_embeddings = self.final_layer( 250 | latent_image_embeddings, 251 | timestep, 252 | ) 253 | 254 | if self.config.patchify_via_reshape: 255 | latent_image_embeddings = self.x_embedder.unpack( 256 | latent_image_embeddings, (latent_height, latent_width) 257 | ) 258 | else: 259 | latent_image_embeddings = unpatchify( 260 | latent_image_embeddings, 261 | patch_size=self.config.patch_size, 262 | target_height=latent_height, 263 | target_width=latent_width, 264 | vae_latent_dim=self.config.vae_latent_dim, 265 | ) 266 | return latent_image_embeddings 267 | 268 | 269 | class LatentImageAdapter(nn.Module): 270 | """Adapts the latent image input by: 271 | - Patchifying to reduce sequence length by `config.patch_size ** 2` 272 | - Projecting to `hidden_size` 273 | """ 274 | 275 | def __init__(self, config: MMDiTConfig): 276 | super().__init__() 277 | self.config = config 278 | in_dim = config.vae_latent_dim 279 | kernel_size = stride = config.patch_size 280 | 281 | if config.patchify_via_reshape: 282 | in_dim *= config.patch_size**2 283 | kernel_size = stride = 1 284 | 285 | self.proj = nn.Conv2d( 286 | in_dim, 287 | config.hidden_size, 288 | kernel_size, 289 | stride, 290 | ) 291 | 292 | def __call__(self, x: mx.array) -> mx.array: 293 | if self.config.patchify_via_reshape: 294 | b, h_latent, w_latent, c = x.shape 295 | p = self.config.patch_size 296 | x = ( 297 | x.reshape(b, h_latent // p, p, w_latent // p, p, c) 298 | .transpose(0, 1, 3, 5, 2, 4) 299 | .reshape(b, h_latent // p, w_latent // p, -1) 300 | ) 301 | 302 | return self.proj(x) 303 | 304 | def unpack(self, x: mx.array, latent_image_resolution: Tuple[int]) -> mx.array: 305 | """Unpacks the latent image embeddings to the original resolution 306 | for `config.patchify_via_reshape` models 307 | """ 308 | assert self.config.patchify_via_reshape 309 | 310 | b = x.shape[0] 311 | p = self.config.patch_size 312 | h = latent_image_resolution[0] // p 313 | w = latent_image_resolution[1] // p 314 | x = ( 315 | x.reshape( 316 | b, h, w, -1, p, p 317 | ) # (b, hw, 1, (c*ph*pw)) -> (b, h, w, c, ph, pw) 318 | .transpose(0, 1, 4, 2, 5, 3) # (b, h, w, c, ph, pw) -> (b, h, ph, w, pw, c) 319 | .reshape(b, h * p, w * p, -1) # (b, h, ph, w, pw, c) -> (b, h*ph, w*pw, c) 320 | ) 321 | return x 322 | 323 | 324 | class LatentImagePositionalEmbedding(nn.Module): 325 | def __init__(self, config: MMDiTConfig): 326 | super().__init__() 327 | self.pos_embed = nn.Embedding( 328 | num_embeddings=config.max_latent_resolution**2, dims=config.hidden_size 329 | ) 330 | self.max_hw = config.max_latent_resolution 331 | self.patch_size = config.patch_size 332 | self.weight_shape = (1, self.max_hw, self.max_hw, config.hidden_size) 333 | 334 | def __call__(self, x: mx.array) -> mx.array: 335 | b, h, w, _ = x.shape 336 | assert h <= self.max_hw and w <= self.max_hw 337 | 338 | h = h // self.patch_size 339 | w = w // self.patch_size 340 | 341 | # Center crop the positional embedding to match the input resolution 342 | y0 = (self.max_hw - h) // 2 343 | y1 = y0 + h 344 | x0 = (self.max_hw - w) // 2 345 | x1 = x0 + w 346 | 347 | w = self.pos_embed.weight.reshape(*self.weight_shape) 348 | w = w[:, y0:y1, x0:x1, :] 349 | return mx.repeat(w, repeats=b, axis=0) 350 | 351 | 352 | class PooledTextEmbeddingAdapter(nn.Module): 353 | def __init__(self, config: MMDiTConfig): 354 | super().__init__() 355 | 356 | d1, d2 = config.pooled_text_embed_dim, config.hidden_size 357 | self.mlp = nn.Sequential( 358 | nn.Linear(d1, d2), 359 | nn.SiLU(), 360 | nn.Linear(d2, 
d2), 361 | ) 362 | 363 | def __call__(self, y: mx.array) -> mx.array: 364 | return self.mlp(y) 365 | 366 | 367 | class TimestepAdapter(nn.Module): 368 | def __init__(self, config: MMDiTConfig): 369 | super().__init__() 370 | 371 | d1, d2 = config.frequency_embed_dim, config.hidden_size 372 | self.mlp = nn.Sequential( 373 | nn.Linear(d1, d2), 374 | nn.SiLU(), 375 | nn.Linear(d2, d2), 376 | ) 377 | self.config = config 378 | 379 | def timestep_embedding(self, t: mx.array) -> mx.array: 380 | half = self.config.frequency_embed_dim // 2 381 | 382 | frequencies = mx.exp( 383 | -mx.log(mx.array(self.config.max_period)) 384 | * mx.arange(start=0, stop=half, dtype=self.config.dtype) 385 | / half 386 | ).astype(self.config.dtype) 387 | 388 | args = t[:, None].astype(self.config.dtype) * frequencies[None] 389 | return mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1) 390 | 391 | def __call__(self, t: mx.array) -> mx.array: 392 | return self.mlp(self.timestep_embedding(t)[:, None, None, :]) 393 | 394 | 395 | class TransformerBlock(nn.Module): 396 | def __init__( 397 | self, 398 | config: MMDiTConfig, 399 | skip_post_sdpa: bool = False, 400 | parallel_mlp: bool = False, 401 | num_modulation_params: Optional[int] = None, 402 | ): 403 | super().__init__() 404 | self.config = config 405 | self.parallel_mlp = parallel_mlp 406 | self.skip_post_sdpa = skip_post_sdpa 407 | self.per_head_dim = config.hidden_size // config.num_heads 408 | 409 | self.norm1 = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 410 | self.attn = Attention(config.hidden_size, config.num_heads) 411 | if not self.parallel_mlp: 412 | # If parallel, reuse norm1 across attention and mlp 413 | self.norm2 = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 414 | 415 | if skip_post_sdpa: 416 | self.attn.o_proj = nn.Identity() 417 | else: 418 | self.mlp = FFN( 419 | embed_dim=config.hidden_size, 420 | expansion_factor=config.mlp_ratio, 421 | activation_fn=nn.GELU(), 422 | ) 423 | 424 | if num_modulation_params is None: 425 | num_modulation_params = 6 426 | if skip_post_sdpa: 427 | num_modulation_params = 2 428 | 429 | self.num_modulation_params = num_modulation_params 430 | self.adaLN_modulation = nn.Sequential( 431 | nn.SiLU(), 432 | nn.Linear( 433 | config.hidden_size, self.num_modulation_params * config.hidden_size 434 | ), 435 | ) 436 | 437 | if config.use_qk_norm: 438 | self.qk_norm = QKNorm(config.hidden_size // config.num_heads) 439 | 440 | def pre_sdpa( 441 | self, 442 | tensor: mx.array, 443 | timestep: mx.array, 444 | ) -> Dict[str, mx.array]: 445 | if timestep.size > 1: 446 | timestep = timestep[0] 447 | modulation_params = self._modulation_params[timestep.item()] 448 | 449 | modulation_params = mx.split( 450 | modulation_params, self.num_modulation_params, axis=-1 451 | ) 452 | 453 | post_norm1_shift = modulation_params[0] 454 | post_norm1_residual_scale = modulation_params[1] 455 | 456 | # LayerNorm and modulate before SDPA 457 | try: 458 | modulated_pre_attention = affine_transform( 459 | tensor, 460 | shift=post_norm1_shift, 461 | residual_scale=post_norm1_residual_scale, 462 | norm_module=self.norm1, 463 | ) 464 | except Exception as e: 465 | logger.error( 466 | f"Error in pre_sdpa: {e}", 467 | exc_info=True, 468 | ) 469 | raise e 470 | 471 | q = self.attn.q_proj(modulated_pre_attention) 472 | k = self.attn.k_proj(modulated_pre_attention) 473 | v = self.attn.v_proj(modulated_pre_attention) 474 | 475 | batch = tensor.shape[0] 476 | 477 | def rearrange_for_norm(t): 478 | # Target data layout: (batch, head, 
seq_len, channel) 479 | return t.reshape( 480 | batch, -1, self.config.num_heads, self.per_head_dim 481 | ).transpose(0, 2, 1, 3) 482 | 483 | q = rearrange_for_norm(q) 484 | k = rearrange_for_norm(k) 485 | v = rearrange_for_norm(v) 486 | 487 | if self.config.use_qk_norm: 488 | q, k = self.qk_norm(q, k) 489 | 490 | if self.config.depth_unified == 0: 491 | q = q.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size) 492 | k = k.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size) 493 | v = v.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size) 494 | 495 | results = {"q": q, "k": k, "v": v} 496 | 497 | results["modulated_pre_attention"] = modulated_pre_attention 498 | 499 | assert len(modulation_params) in [2, 3, 6] 500 | results.update( 501 | { 502 | "post_norm1_shift": post_norm1_shift, 503 | "post_norm1_residual_scale": post_norm1_residual_scale, 504 | } 505 | ) 506 | 507 | if len(modulation_params) > 2: 508 | results.update({"post_attn_scale": modulation_params[2]}) 509 | 510 | if len(modulation_params) > 3: 511 | results.update( 512 | { 513 | "post_norm2_shift": modulation_params[3], 514 | "post_norm2_residual_scale": modulation_params[4], 515 | "post_mlp_scale": modulation_params[5], 516 | } 517 | ) 518 | 519 | return results 520 | 521 | def post_sdpa( 522 | self, 523 | residual: mx.array, 524 | sdpa_output: mx.array, 525 | modulated_pre_attention: mx.array, 526 | post_attn_scale: Optional[mx.array] = None, 527 | post_norm2_shift: Optional[mx.array] = None, 528 | post_norm2_residual_scale: Optional[mx.array] = None, 529 | post_mlp_scale: Optional[mx.array] = None, 530 | **kwargs, 531 | ): 532 | attention_out = self.attn.o_proj(sdpa_output) 533 | if self.parallel_mlp: 534 | # Reuse the modulation parameters and self.norm1 across attn and mlp 535 | mlp_out = self.mlp(modulated_pre_attention) 536 | return residual + post_attn_scale * (attention_out + mlp_out) 537 | else: 538 | residual = residual + attention_out * post_attn_scale 539 | # Apply separate modulation parameters and LayerNorm across attn and mlp 540 | mlp_out = self.mlp( 541 | affine_transform( 542 | residual, 543 | shift=post_norm2_shift, 544 | residual_scale=post_norm2_residual_scale, 545 | norm_module=self.norm2, 546 | ) 547 | ) 548 | return residual + post_mlp_scale * mlp_out 549 | 550 | def __call__(self): 551 | raise NotImplementedError("This module is not intended to be used directly") 552 | 553 | 554 | class MultiModalTransformerBlock(nn.Module): 555 | def __init__(self, config: MMDiTConfig, skip_text_post_sdpa: bool = False): 556 | super().__init__() 557 | self.image_transformer_block = TransformerBlock(config) 558 | self.text_transformer_block = TransformerBlock( 559 | config, skip_post_sdpa=skip_text_post_sdpa 560 | ) 561 | 562 | sdpa_impl = mx.fast.scaled_dot_product_attention 563 | self.sdpa = partial(sdpa_impl) 564 | 565 | self.config = config 566 | self.per_head_dim = config.hidden_size // config.num_heads 567 | 568 | def __call__( 569 | self, 570 | latent_image_embeddings: mx.array, # latent image embeddings 571 | token_level_text_embeddings: mx.array, # token-level text embeddings 572 | timestep: mx.array, # pooled text embeddings + timestep embeddings 573 | positional_encodings: mx.array = None, # positional encodings for rope 574 | ): 575 | # Prepare multi-modal SDPA inputs 576 | image_intermediates = self.image_transformer_block.pre_sdpa( 577 | latent_image_embeddings, 578 | timestep=timestep, 579 | ) 580 | 581 | text_intermediates = 
self.text_transformer_block.pre_sdpa( 582 | token_level_text_embeddings, 583 | timestep=timestep, 584 | ) 585 | 586 | batch = latent_image_embeddings.shape[0] 587 | 588 | def rearrange_for_sdpa(t): 589 | # Target data layout: (batch, head, seq_len, channel) 590 | return t.reshape( 591 | batch, -1, self.config.num_heads, self.per_head_dim 592 | ).transpose(0, 2, 1, 3) 593 | 594 | if self.config.depth_unified > 0: 595 | multimodal_sdpa_inputs = { 596 | "q": mx.concatenate( 597 | [text_intermediates["q"], image_intermediates["q"]], axis=2 598 | ), 599 | "k": mx.concatenate( 600 | [text_intermediates["k"], image_intermediates["k"]], axis=2 601 | ), 602 | "v": mx.concatenate( 603 | [text_intermediates["v"], image_intermediates["v"]], axis=2 604 | ), 605 | "scale": 1.0 / np.sqrt(self.per_head_dim), 606 | } 607 | else: 608 | multimodal_sdpa_inputs = { 609 | "q": rearrange_for_sdpa( 610 | mx.concatenate( 611 | [image_intermediates["q"], text_intermediates["q"]], axis=1 612 | ) 613 | ), 614 | "k": rearrange_for_sdpa( 615 | mx.concatenate( 616 | [image_intermediates["k"], text_intermediates["k"]], axis=1 617 | ) 618 | ), 619 | "v": rearrange_for_sdpa( 620 | mx.concatenate( 621 | [image_intermediates["v"], text_intermediates["v"]], axis=1 622 | ) 623 | ), 624 | "scale": 1.0 / np.sqrt(self.per_head_dim), 625 | } 626 | 627 | if self.config.pos_embed_type == PositionalEncoding.PreSDPARope: 628 | assert positional_encodings is not None 629 | multimodal_sdpa_inputs["q"] = RoPE.apply( 630 | multimodal_sdpa_inputs["q"], positional_encodings 631 | ) 632 | multimodal_sdpa_inputs["k"] = RoPE.apply( 633 | multimodal_sdpa_inputs["k"], positional_encodings 634 | ) 635 | 636 | if self.config.low_memory_mode: 637 | multimodal_sdpa_inputs[ 638 | "memory_efficient_threshold" 639 | ] = SDPA_FLASH_ATTN_THRESHOLD 640 | 641 | # Compute multi-modal SDPA 642 | sdpa_outputs = ( 643 | self.sdpa(**multimodal_sdpa_inputs) 644 | .transpose(0, 2, 1, 3) 645 | .reshape(batch, -1, 1, self.config.hidden_size) 646 | ) 647 | 648 | # Split into image-text sequences for post-SDPA layers 649 | img_seq_len = latent_image_embeddings.shape[1] 650 | txt_seq_len = token_level_text_embeddings.shape[1] 651 | 652 | if self.config.depth_unified > 0: 653 | text_sdpa_output = sdpa_outputs[:, :txt_seq_len, :, :] 654 | image_sdpa_output = sdpa_outputs[:, txt_seq_len:, :, :] 655 | else: 656 | image_sdpa_output = sdpa_outputs[:, :img_seq_len, :, :] 657 | text_sdpa_output = sdpa_outputs[:, -txt_seq_len:, :, :] 658 | 659 | # Post-SDPA layers 660 | latent_image_embeddings = self.image_transformer_block.post_sdpa( 661 | residual=latent_image_embeddings, 662 | sdpa_output=image_sdpa_output, 663 | **image_intermediates, 664 | ) 665 | if self.text_transformer_block.skip_post_sdpa: 666 | # Text token related outputs from the final layer do not impact the model output 667 | token_level_text_embeddings = None 668 | else: 669 | token_level_text_embeddings = self.text_transformer_block.post_sdpa( 670 | residual=token_level_text_embeddings, 671 | sdpa_output=text_sdpa_output, 672 | **text_intermediates, 673 | ) 674 | 675 | return latent_image_embeddings, token_level_text_embeddings 676 | 677 | 678 | class UnifiedTransformerBlock(nn.Module): 679 | def __init__(self, config: MMDiTConfig): 680 | super().__init__() 681 | self.transformer_block = TransformerBlock( 682 | config, 683 | num_modulation_params=3 if config.parallel_mlp_for_unified_blocks else 6, 684 | parallel_mlp=config.parallel_mlp_for_unified_blocks, 685 | ) 686 | 687 | sdpa_impl = 
mx.fast.scaled_dot_product_attention 688 | self.sdpa = partial(sdpa_impl) 689 | 690 | self.config = config 691 | self.per_head_dim = config.hidden_size // config.num_heads 692 | 693 | def __call__( 694 | self, 695 | latent_unified_embeddings: mx.array, # latent image embeddings 696 | timestep: mx.array, # pooled text embeddings + timestep embeddings 697 | positional_encodings: mx.array = None, # positional encodings for rope 698 | ): 699 | # Prepare multi-modal SDPA inputs 700 | intermediates = self.transformer_block.pre_sdpa( 701 | latent_unified_embeddings, 702 | timestep=timestep, 703 | ) 704 | 705 | batch = latent_unified_embeddings.shape[0] 706 | 707 | def rearrange_for_sdpa(t): 708 | # Target data layout: (batch, head, seq_len, channel) 709 | return t.reshape( 710 | batch, -1, self.config.num_heads, self.per_head_dim 711 | ).transpose(0, 2, 1, 3) 712 | 713 | multimodal_sdpa_inputs = { 714 | "q": intermediates["q"], 715 | "k": intermediates["k"], 716 | "v": intermediates["v"], 717 | "scale": 1.0 / np.sqrt(self.per_head_dim), 718 | } 719 | 720 | if self.config.pos_embed_type == PositionalEncoding.PreSDPARope: 721 | assert positional_encodings is not None 722 | multimodal_sdpa_inputs["q"] = RoPE.apply( 723 | multimodal_sdpa_inputs["q"], positional_encodings 724 | ) 725 | multimodal_sdpa_inputs["k"] = RoPE.apply( 726 | multimodal_sdpa_inputs["k"], positional_encodings 727 | ) 728 | 729 | if self.config.low_memory_mode: 730 | multimodal_sdpa_inputs[ 731 | "memory_efficient_threshold" 732 | ] = SDPA_FLASH_ATTN_THRESHOLD 733 | 734 | # Compute multi-modal SDPA 735 | sdpa_outputs = ( 736 | self.sdpa(**multimodal_sdpa_inputs) 737 | .transpose(0, 2, 1, 3) 738 | .reshape(batch, -1, 1, self.config.hidden_size) 739 | ) 740 | 741 | # o_proj and mlp.fc2 uses the same bias, remove mlp.fc2 bias 742 | self.transformer_block.mlp.fc2.bias = self.transformer_block.mlp.fc2.bias * 0.0 743 | 744 | # Post-SDPA layers 745 | latent_unified_embeddings = self.transformer_block.post_sdpa( 746 | residual=latent_unified_embeddings, 747 | sdpa_output=sdpa_outputs, 748 | **intermediates, 749 | ) 750 | 751 | return latent_unified_embeddings 752 | 753 | 754 | class QKNorm(nn.Module): 755 | def __init__(self, head_dim): 756 | super().__init__() 757 | self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) 758 | self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) 759 | 760 | def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]: 761 | # Note: mlx.nn.RMSNorm has high precision accumulation (does not require upcasting) 762 | q = self.q_norm(q) 763 | k = self.k_norm(k) 764 | return q, k 765 | 766 | 767 | class FinalLayer(nn.Module): 768 | def __init__(self, config: MMDiTConfig): 769 | super().__init__() 770 | self.norm_final = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 771 | self.linear = nn.Linear( 772 | config.hidden_size, 773 | (config.patch_size**2) * config.vae_latent_dim, 774 | ) 775 | self.adaLN_modulation = nn.Sequential( 776 | nn.SiLU(), 777 | nn.Linear(config.hidden_size, 2 * config.hidden_size), 778 | ) 779 | 780 | def __call__( 781 | self, 782 | latent_image_embeddings: mx.array, 783 | timestep: mx.array, 784 | ) -> mx.array: 785 | if timestep.size > 1: 786 | timestep = timestep[0] 787 | modulation_params = self._modulation_params[timestep.item()] 788 | 789 | shift, residual_scale = mx.split(modulation_params, 2, axis=-1) 790 | latent_image_embeddings = affine_transform( 791 | latent_image_embeddings, 792 | shift=shift, 793 | residual_scale=residual_scale, 794 | norm_module=self.norm_final, 795 | 
) 796 | return self.linear(latent_image_embeddings) 797 | 798 | 799 | class Attention(nn.Module): 800 | def __init__( 801 | self, 802 | embed_dim: int, 803 | n_heads: int, 804 | ): 805 | super().__init__() 806 | 807 | # Configure dimensions for SDPA 808 | self.embed_dim = embed_dim 809 | self.n_heads = n_heads 810 | assert ( 811 | self.embed_dim % self.n_heads == 0 812 | ), "Embedding dimension must be divisible by number of heads" 813 | 814 | self._sdpa_implementation = mx.fast.scaled_dot_product_attention 815 | 816 | # Initialize layers 817 | self.per_head_dim = self.embed_dim // self.n_heads 818 | self.kv_proj_embed_dim = self.per_head_dim * n_heads 819 | 820 | # Note: key bias is redundant due to softmax invariance 821 | self.k_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim) 822 | self.q_proj = nn.Linear(embed_dim, embed_dim) 823 | self.v_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim) 824 | self.o_proj = nn.Linear(embed_dim, embed_dim) 825 | 826 | 827 | class FFN(nn.Module): 828 | def __init__(self, embed_dim, expansion_factor, activation_fn): 829 | super().__init__() 830 | self.fc1 = nn.Linear(embed_dim, embed_dim * expansion_factor) 831 | self.act_fn = activation_fn 832 | self.fc2 = nn.Linear(embed_dim * expansion_factor, embed_dim) 833 | 834 | def __call__(self, x: mx.array) -> mx.array: 835 | return self.fc2(self.act_fn(self.fc1(x))) 836 | 837 | 838 | class LayerNorm(nn.Module): 839 | def __init__(self, num_channels, eps=1e-5): 840 | super().__init__() 841 | self.num_channels = num_channels 842 | self.eps = eps 843 | 844 | def __call__(self, inputs: mx.array) -> mx.array: 845 | input_rank = len(inputs.shape) 846 | if input_rank != 4: 847 | raise ValueError(f"Input tensor must have rank 4, got {input_rank}") 848 | 849 | return mx.fast.layer_norm(inputs, weight=None, bias=None, eps=self.eps) 850 | 851 | 852 | class RoPE(nn.Module): 853 | """Custom RoPE implementation for FLUX""" 854 | 855 | def __init__(self, theta: int, axes_dim: List[int]) -> None: 856 | super().__init__() 857 | self.theta = theta 858 | self.axes_dim = axes_dim 859 | 860 | # Cache for consecutive identical calls 861 | self.rope_embeddings = None 862 | self.last_image_resolution = None 863 | self.last_text_sequence_length = None 864 | 865 | def _get_positions( 866 | self, latent_image_resolution: Tuple[int], text_sequence_length: int 867 | ) -> mx.array: 868 | h, w = latent_image_resolution 869 | image_positions = mx.stack( 870 | [ 871 | mx.zeros((h, w)), 872 | mx.repeat(mx.arange(h)[:, None], w, axis=1), 873 | mx.repeat(mx.arange(w)[None, :], h, axis=0), 874 | ], 875 | axis=-1, 876 | ).flatten( 877 | 0, 1 878 | ) # (h * w, 3) 879 | 880 | text_and_image_positions = mx.concatenate( 881 | [ 882 | mx.zeros((text_sequence_length, 3)), 883 | image_positions, 884 | ], 885 | axis=0, 886 | )[ 887 | None 888 | ] # (text_sequence_length + h * w, 3) 889 | 890 | return text_and_image_positions 891 | 892 | def rope(self, positions: mx.array, dim: int, theta: int = 10_000) -> mx.array: 893 | def _rope_per_dim(positions, dim, theta): 894 | scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim 895 | omega = 1.0 / (theta**scale) 896 | out = ( 897 | positions[..., None] * omega[None, None, :] 898 | ) # mx.einsum("bn,d->bnd", positions, omega) 899 | return mx.stack( 900 | [mx.cos(out), -mx.sin(out), mx.sin(out), mx.cos(out)], axis=-1 901 | ).reshape(*positions.shape, dim // 2, 2, 2) 902 | 903 | return mx.concatenate( 904 | [ 905 | _rope_per_dim( 906 | positions=positions[..., i], dim=self.axes_dim[i], theta=self.theta 907 | ) 
908 | for i in range(len(self.axes_dim)) 909 | ], 910 | axis=-3, 911 | ).astype(positions.dtype) 912 | 913 | def __call__( 914 | self, latent_image_resolution: Tuple[int], text_sequence_length: int 915 | ) -> mx.array: 916 | identical_to_last_call = ( 917 | latent_image_resolution == self.last_image_resolution 918 | and text_sequence_length == self.last_text_sequence_length 919 | ) 920 | 921 | if self.rope_embeddings is None or not identical_to_last_call: 922 | self.last_image_resolution = latent_image_resolution 923 | self.last_text_sequence_length = text_sequence_length 924 | positions = self._get_positions( 925 | latent_image_resolution, text_sequence_length 926 | ) 927 | self.rope_embeddings = self.rope(positions, self.theta) 928 | self.rope_embeddings = mx.expand_dims(self.rope_embeddings, axis=1) 929 | else: 930 | logger.debug("Returning cached RoPE embeddings") 931 | 932 | return self.rope_embeddings 933 | 934 | @staticmethod 935 | def apply(q_or_k: mx.array, rope: mx.array) -> mx.array: 936 | in_dtype = q_or_k.dtype 937 | q_or_k = q_or_k.astype(mx.float32).reshape(*q_or_k.shape[:-1], -1, 1, 2) 938 | return ( 939 | (rope[..., 0] * q_or_k[..., 0] + rope[..., 1] * q_or_k[..., 1]) 940 | .astype(in_dtype) 941 | .flatten(-2) 942 | ) 943 | 944 | 945 | class MLPEmbedder(nn.Module): 946 | def __init__(self, in_dim: int, hidden_dim: int): 947 | super().__init__() 948 | self.mlp = nn.Sequential( 949 | nn.Linear(in_dim, hidden_dim), 950 | nn.SiLU(), 951 | nn.Linear(hidden_dim, hidden_dim), 952 | ) 953 | 954 | def __call__(self, x): 955 | return self.mlp(x) 956 | 957 | 958 | def affine_transform( 959 | x: mx.array, 960 | shift: mx.array, 961 | residual_scale: mx.array, 962 | norm_module: nn.Module = None, 963 | ) -> mx.array: 964 | """Affine transformation (Used for Adaptive LayerNorm Modulation)""" 965 | if x.shape[0] == 1 and norm_module is not None: 966 | return mx.fast.layer_norm( 967 | x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps 968 | ) 969 | elif norm_module is not None: 970 | return norm_module(x) * (1.0 + residual_scale) + shift 971 | else: 972 | return x * (1.0 + residual_scale) + shift 973 | 974 | 975 | def unpatchify( 976 | x: mx.array, 977 | patch_size: int, 978 | target_height: int, 979 | target_width: int, 980 | vae_latent_dim: int, 981 | ) -> mx.array: 982 | """Unpatchify to restore VAE latent space compatible data format""" 983 | h, w = target_height // patch_size, target_width // patch_size 984 | x = x.reshape(x.shape[0], h, w, patch_size, patch_size, vae_latent_dim) 985 | x = x.transpose(0, 5, 1, 3, 2, 4) # x = mx.einsum("bhwpqc->bchpwq", x) 986 | return x.reshape(x.shape[0], vae_latent_dim, target_height, target_width).transpose( 987 | 0, 2, 3, 1 988 | ) 989 | -------------------------------------------------------------------------------- /diffusionkit/mlx/model_io.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | # 4 | # For licensing see accompanying LICENSE.md file. 5 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 
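# Overview: the loaders in this module download checkpoints from the Hugging Face Hub
# and remap their state-dict keys onto the module names defined in mmdit.py, clip.py,
# t5.py and vae.py. Most of the remapping is done with repeated dict comprehensions of
# the form sketched here (a minimal illustration of the pattern, not the full mapping):
#
#   state_dict = {
#       k.replace("double_blocks", "multimodal_transformer_blocks"): v
#       for k, v in state_dict.items()
#   }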
6 | # 7 | 8 | import json 9 | from typing import Optional 10 | 11 | import mlx.core as mx 12 | from huggingface_hub import hf_hub_download 13 | from mlx import nn 14 | from mlx.utils import tree_flatten, tree_unflatten 15 | from transformers import T5Config 16 | 17 | from .clip import CLIPTextModel 18 | from .config import ( 19 | FLUX_SCHNELL, 20 | AutoencoderConfig, 21 | CLIPTextModelConfig, 22 | SD3_2b, 23 | VAEDecoderConfig, 24 | VAEEncoderConfig, 25 | ) 26 | from .mmdit import MMDiT 27 | from .t5 import SD3T5Encoder 28 | from .tokenizer import T5Tokenizer, Tokenizer 29 | from .vae import Autoencoder, VAEDecoder, VAEEncoder 30 | 31 | # import argmaxtools.mlx.utils as axu 32 | 33 | 34 | RANK = 32 35 | _DEFAULT_MMDIT = "argmaxinc/mlx-stable-diffusion-3-medium" 36 | _MMDIT = { 37 | "argmaxinc/mlx-stable-diffusion-3-medium": { 38 | "argmaxinc/mlx-stable-diffusion-3-medium": "sd3_medium.safetensors", 39 | "vae": "sd3_medium.safetensors", 40 | }, 41 | "argmaxinc/mlx-FLUX.1-schnell": { 42 | "argmaxinc/mlx-FLUX.1-schnell": "flux-schnell.safetensors", 43 | "vae": "ae.safetensors", 44 | }, 45 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": { 46 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors", 47 | "vae": "ae.safetensors", 48 | }, 49 | "argmaxinc/mlx-FLUX.1-dev": { 50 | "argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors", 51 | "vae": "ae.safetensors", 52 | }, 53 | } 54 | _DEFAULT_MODEL = "argmaxinc/stable-diffusion" 55 | _MODELS = { 56 | "argmaxinc/stable-diffusion": { 57 | "clip_l_config": "clip_l/config.json", 58 | "clip_l": "clip_l/model.fp16.safetensors", 59 | "clip_g_config": "clip_g/config.json", 60 | "clip_g": "clip_g/model.fp16.safetensors", 61 | "tokenizer_l_vocab": "tokenizer_l/vocab.json", 62 | "tokenizer_l_merges": "tokenizer_l/merges.txt", 63 | "tokenizer_g_vocab": "tokenizer_g/vocab.json", 64 | "tokenizer_g_merges": "tokenizer_g/merges.txt", 65 | "t5": "t5/t5xxl.safetensors", 66 | }, 67 | } 68 | 69 | _PREFIX = { 70 | "argmaxinc/mlx-stable-diffusion-3-medium": { 71 | "vae_encoder": "first_stage_model.encoder.", 72 | "vae_decoder": "first_stage_model.decoder.", 73 | }, 74 | "argmaxinc/mlx-FLUX.1-schnell": { 75 | "vae_encoder": "encoder.", 76 | "vae_decoder": "decoder.", 77 | }, 78 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": { 79 | "vae_encoder": "encoder.", 80 | "vae_decoder": "decoder.", 81 | }, 82 | "argmaxinc/mlx-FLUX.1-dev": { 83 | "vae_encoder": "encoder.", 84 | "vae_decoder": "decoder.", 85 | }, 86 | } 87 | 88 | _FLOAT16 = mx.bfloat16 89 | 90 | DEPTH = { 91 | "argmaxinc/mlx-stable-diffusion-3-medium": 24, 92 | "sd3-8b-unreleased": 38, 93 | } 94 | MAX_LATENT_RESOLUTION = { 95 | "argmaxinc/mlx-stable-diffusion-3-medium": 96, 96 | "sd3-8b-unreleased": 192, 97 | } 98 | 99 | LOCAl_SD3_CKPT = None 100 | 101 | 102 | def flux_state_dict_adjustments(state_dict, prefix="", hidden_size=3072, mlp_ratio=4): 103 | state_dict = { 104 | k.replace("double_blocks", "multimodal_transformer_blocks"): v 105 | for k, v in state_dict.items() 106 | } 107 | state_dict = { 108 | k.replace("single_blocks", "unified_transformer_blocks"): v 109 | for k, v in state_dict.items() 110 | } 111 | 112 | # Split qkv proj and rename: 113 | # *transformer_block.attn.qkv.{weigth/bias} -> transformer_block.attn.{q/k/v}_proj.{weigth/bias} 114 | # *transformer_block.attn.proj.{weigth/bias} -> transformer_block.attn.o_proj.{weight/bias} 115 | keys_to_pop = [] 116 | state_dict_update = {} 117 | for k in state_dict: 118 | if "attn.qkv" in k: 119 | keys_to_pop.append(k) 120 | 
for name, weight in zip(["q", "k", "v"], mx.split(state_dict[k], 3)): 121 | state_dict_update[k.replace("attn.qkv", f"attn.{name}_proj")] = ( 122 | weight if "weight" in k else weight 123 | ) 124 | 125 | [state_dict.pop(k) for k in keys_to_pop] 126 | state_dict.update(state_dict_update) 127 | 128 | state_dict = { 129 | k.replace("txt_attn", "text_transformer_block.attn"): v 130 | for k, v in state_dict.items() 131 | } 132 | state_dict = { 133 | k.replace("img_attn", "image_transformer_block.attn"): v 134 | for k, v in state_dict.items() 135 | } 136 | 137 | state_dict = { 138 | k.replace("txt_mlp.0", "text_transformer_block.mlp.fc1"): v 139 | for k, v in state_dict.items() 140 | } 141 | state_dict = { 142 | k.replace("txt_mlp.2", "text_transformer_block.mlp.fc2"): v 143 | for k, v in state_dict.items() 144 | } 145 | state_dict = { 146 | k.replace("img_mlp.0", "image_transformer_block.mlp.fc1"): v 147 | for k, v in state_dict.items() 148 | } 149 | state_dict = { 150 | k.replace("img_mlp.2", "image_transformer_block.mlp.fc2"): v 151 | for k, v in state_dict.items() 152 | } 153 | 154 | state_dict = { 155 | k.replace("img_mod.lin", "image_transformer_block.adaLN_modulation.layers.1"): v 156 | for k, v in state_dict.items() 157 | } 158 | state_dict = { 159 | k.replace("txt_mod.lin", "text_transformer_block.adaLN_modulation.layers.1"): v 160 | for k, v in state_dict.items() 161 | } 162 | 163 | state_dict = {k.replace(".proj", ".o_proj"): v for k, v in state_dict.items()} 164 | 165 | state_dict = { 166 | k.replace(".attn.norm.key_norm.scale", ".qk_norm.k_norm.weight"): v 167 | for k, v in state_dict.items() 168 | } 169 | state_dict = { 170 | k.replace(".attn.norm.query_norm.scale", ".qk_norm.q_norm.weight"): v 171 | for k, v in state_dict.items() 172 | } 173 | 174 | state_dict = { 175 | k.replace(".modulation.lin", ".transformer_block.adaLN_modulation.layers.1"): v 176 | for k, v in state_dict.items() 177 | } 178 | state_dict = { 179 | k.replace(".norm.key_norm.scale", ".transformer_block.qk_norm.k_norm.weight"): v 180 | for k, v in state_dict.items() 181 | } 182 | state_dict = { 183 | k.replace( 184 | ".norm.query_norm.scale", ".transformer_block.qk_norm.q_norm.weight" 185 | ): v 186 | for k, v in state_dict.items() 187 | } 188 | 189 | # Split qkv proj and mlp in unified transformer block and rename: 190 | keys_to_pop = [] 191 | state_dict_update = {} 192 | for k in state_dict: 193 | if ".linear1" in k: 194 | keys_to_pop.append(k) 195 | for name, weight in zip( 196 | ["attn.q", "attn.k", "attn.v", "mlp.fc1"], 197 | mx.split( 198 | state_dict[k], 199 | [ 200 | hidden_size, 201 | 2 * hidden_size, 202 | 3 * hidden_size, 203 | (3 + mlp_ratio) * hidden_size, 204 | ], 205 | ), 206 | ): 207 | if name == "mlp.fc1": 208 | state_dict_update[ 209 | k.replace(".linear1", f".transformer_block.{name}") 210 | ] = (weight if "weight" in k else weight) 211 | else: 212 | state_dict_update[ 213 | k.replace(".linear1", f".transformer_block.{name}_proj") 214 | ] = (weight if "weight" in k else weight) 215 | 216 | [state_dict.pop(k) for k in keys_to_pop] 217 | state_dict.update(state_dict_update) 218 | 219 | # Split o_proj and mlp in unified transformer block and rename: 220 | keys_to_pop = [] 221 | state_dict_update = {} 222 | for k in state_dict: 223 | if ".linear2" in k: 224 | keys_to_pop.append(k) 225 | if "bias" in k: 226 | state_dict_update[ 227 | k.replace(".linear2", ".transformer_block.attn.o_proj") 228 | ] = state_dict[k] 229 | state_dict_update[ 230 | k.replace(".linear2", ".transformer_block.mlp.fc2") 231 
| ] = state_dict[k] 232 | else: 233 | for name, weight in zip( 234 | ["attn.o", "mlp.fc2"], 235 | mx.split( 236 | state_dict[k], 237 | [hidden_size, (1 + mlp_ratio) * hidden_size], 238 | axis=1, 239 | ), 240 | ): 241 | if name == "mlp.fc2": 242 | state_dict_update[ 243 | k.replace(".linear2", f".transformer_block.{name}") 244 | ] = (weight if "weight" in k else weight) 245 | else: 246 | state_dict_update[ 247 | k.replace(".linear2", f".transformer_block.{name}_proj") 248 | ] = (weight if "weight" in k else weight) 249 | 250 | [state_dict.pop(k) for k in keys_to_pop] 251 | state_dict.update(state_dict_update) 252 | 253 | state_dict = { 254 | k.replace("img_in.", "x_embedder.proj."): v for k, v in state_dict.items() 255 | } 256 | state_dict = { 257 | k.replace("txt_in.", "context_embedder."): v for k, v in state_dict.items() 258 | } 259 | state_dict = { 260 | k.replace("time_in.", "t_embedder."): v for k, v in state_dict.items() 261 | } 262 | state_dict = { 263 | k.replace("vector_in.", "y_embedder."): v for k, v in state_dict.items() 264 | } 265 | state_dict = { 266 | k.replace(".in_layer.", ".mlp.layers.0."): v for k, v in state_dict.items() 267 | } 268 | state_dict = { 269 | k.replace(".out_layer.", ".mlp.layers.2."): v for k, v in state_dict.items() 270 | } 271 | 272 | state_dict = { 273 | k.replace( 274 | "final_layer.adaLN_modulation.1", "final_layer.adaLN_modulation.layers.1" 275 | ): v 276 | for k, v in state_dict.items() 277 | } 278 | 279 | state_dict["x_embedder.proj.weight"] = mx.expand_dims( 280 | mx.expand_dims(state_dict["x_embedder.proj.weight"], axis=1), axis=1 281 | ) 282 | 283 | return state_dict 284 | 285 | 286 | def mmdit_state_dict_adjustments(state_dict, prefix=""): 287 | # Remove prefix 288 | state_dict = {k.lstrip(prefix): v for k, v in state_dict.items()} 289 | 290 | state_dict = { 291 | k.replace("y_embedder.mlp", "y_embedder.mlp.layers"): v 292 | for k, v in state_dict.items() 293 | } 294 | state_dict = { 295 | k.replace("t_embedder.mlp", "t_embedder.mlp.layers"): v 296 | for k, v in state_dict.items() 297 | } 298 | state_dict = { 299 | k.replace("adaLN_modulation", "adaLN_modulation.layers"): v 300 | for k, v in state_dict.items() 301 | } 302 | state_dict = { 303 | k.replace("al_layer", "final_layer"): v for k, v in state_dict.items() 304 | } 305 | 306 | # Rename joint_blocks -> multimodal_transformer_blocks 307 | state_dict = { 308 | k.replace("joint_blocks", "multimodal_transformer_blocks"): v 309 | for k, v in state_dict.items() 310 | } 311 | 312 | # Remap context_block -> text_block 313 | state_dict = { 314 | k.replace("context_block", "text_transformer_block"): v 315 | for k, v in state_dict.items() 316 | } 317 | 318 | # Remap x_block -> image_block 319 | state_dict = { 320 | k.replace("x_block", "image_transformer_block"): v 321 | for k, v in state_dict.items() 322 | } 323 | 324 | # Split qkv proj and rename: 325 | # *transformer_block.attn.qkv.{weigth/bias} -> transformer_block.attn.{q/k/v}_proj.{weigth/bias} 326 | # *transformer_block.attn.proj.{weigth/bias} -> transformer_block.attn.o_proj.{weight/bias} 327 | keys_to_pop = [] 328 | state_dict_update = {} 329 | for k in state_dict: 330 | if "attn.qkv" in k: 331 | keys_to_pop.append(k) 332 | for name, weight in zip(["q", "k", "v"], mx.split(state_dict[k], 3)): 333 | state_dict_update[k.replace("attn.qkv", f"attn.{name}_proj")] = ( 334 | weight if "weight" in k else weight 335 | ) 336 | 337 | [state_dict.pop(k) for k in keys_to_pop] 338 | state_dict.update(state_dict_update) 339 | 340 | state_dict = { 341 | 
k.replace("attn.proj", "attn.o_proj"): ( 342 | v if "attn.proj" in k and "weight" in k else v 343 | ) 344 | for k, v in state_dict.items() 345 | } 346 | 347 | # Filter out VAE Decoder related tensors 348 | state_dict = {k: v for k, v in state_dict.items() if "decoder." not in k} 349 | 350 | # Filter out k_proj.bias related tensors 351 | state_dict = {k: v for k, v in state_dict.items() if "k_proj.bias" not in k} 352 | 353 | # Filter out teacher_model related tensors 354 | state_dict = {k: v for k, v in state_dict.items() if "teacher_model." not in k} 355 | 356 | # Remap pos_embed buffer -> nn.Embedding 357 | state_dict = { 358 | k.replace("pos_embed", "x_pos_embedder.pos_embed.weight"): ( 359 | v[0] if "pos_embed" in k else v 360 | ) 361 | for k, v in state_dict.items() 362 | } 363 | 364 | # Transpose x_embedder.proj.weight 365 | state_dict["x_embedder.proj.weight"] = state_dict[ 366 | "x_embedder.proj.weight" 367 | ].transpose(0, 2, 3, 1) 368 | 369 | return state_dict 370 | 371 | 372 | def vae_decoder_state_dict_adjustments(state_dict, prefix="decoder."): 373 | # Keep only the keys that have the prefix 374 | state_dict = {k: v for k, v in state_dict.items() if prefix in k} 375 | state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()} 376 | 377 | # Filter out MMDIT related tensors 378 | state_dict = {k: v for k, v in state_dict.items() if "diffusion_model." not in k} 379 | 380 | state_dict = {k.replace("up", "up_blocks"): v for k, v in state_dict.items()} 381 | state_dict = {k.replace("mid", "mid_blocks"): v for k, v in state_dict.items()} 382 | 383 | state_dict = { 384 | k.replace("mid_blocks.block_1", "mid_blocks.0"): v 385 | for k, v in state_dict.items() 386 | } 387 | state_dict = { 388 | k.replace("mid_blocks.block_2", "mid_blocks.2"): v 389 | for k, v in state_dict.items() 390 | } 391 | state_dict = { 392 | k.replace("mid_blocks.attn_1", "mid_blocks.1"): v for k, v in state_dict.items() 393 | } 394 | 395 | state_dict = {k.replace(".norm.", ".group_norm."): v for k, v in state_dict.items()} 396 | 397 | state_dict = {k.replace(".q", ".query_proj"): v for k, v in state_dict.items()} 398 | state_dict = {k.replace(".k", ".key_proj"): v for k, v in state_dict.items()} 399 | state_dict = {k.replace(".v", ".value_proj"): v for k, v in state_dict.items()} 400 | state_dict = {k.replace(".proj_out", ".out_proj"): v for k, v in state_dict.items()} 401 | 402 | state_dict = {k.replace(".block.", ".resnets."): v for k, v in state_dict.items()} 403 | state_dict = { 404 | k.replace(".nin_shortcut.", ".conv_shortcut."): v for k, v in state_dict.items() 405 | } 406 | state_dict = { 407 | k.replace(".up_blockssample.conv.", ".upsample."): v 408 | for k, v in state_dict.items() 409 | } 410 | 411 | state_dict = { 412 | k.replace("norm_out", "conv_norm_out"): v for k, v in state_dict.items() 413 | } 414 | 415 | # reshape weights 416 | 417 | state_dict = { 418 | k: v.transpose(0, 2, 3, 1) if "upsample" in k and "weight" in k else v 419 | for k, v in state_dict.items() 420 | } 421 | state_dict = { 422 | k: ( 423 | v.transpose(0, 2, 3, 1) 424 | if "resnets" in k and "conv" in k and "weight" in k 425 | else v 426 | ) 427 | for k, v in state_dict.items() 428 | } 429 | state_dict = { 430 | k: ( 431 | v.transpose(0, 2, 3, 1) 432 | if "mid_blocks" in k and "conv" in k and "weight" in k 433 | else v 434 | ) 435 | for k, v in state_dict.items() 436 | } 437 | state_dict = { 438 | k: v[:, 0, 0, :] if "conv_shortcut.weight" in k else v 439 | for k, v in state_dict.items() 440 | } 441 | state_dict = { 442 
| k: v[:, :, 0, 0] if "proj.weight" in k else v for k, v in state_dict.items() 443 | } 444 | state_dict["conv_in.weight"] = state_dict["conv_in.weight"].transpose(0, 2, 3, 1) 445 | state_dict["conv_out.weight"] = state_dict["conv_out.weight"].transpose(0, 2, 3, 1) 446 | 447 | return state_dict 448 | 449 | 450 | def vae_encoder_state_dict_adjustments(state_dict, prefix="encoder."): 451 | # Keep only the keys that have the prefix 452 | state_dict = {k: v for k, v in state_dict.items() if prefix in k} 453 | state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()} 454 | 455 | # Filter out MMDIT related tensors 456 | state_dict = {k: v for k, v in state_dict.items() if "diffusion_model." not in k} 457 | 458 | state_dict = {k.replace("down.", "down_blocks."): v for k, v in state_dict.items()} 459 | state_dict = { 460 | k.replace(".downsample.conv.", ".downsample."): v for k, v in state_dict.items() 461 | } 462 | state_dict = {k.replace(".block.", ".resnets."): v for k, v in state_dict.items()} 463 | state_dict = { 464 | k.replace(".nin_shortcut.", ".conv_shortcut."): v for k, v in state_dict.items() 465 | } 466 | 467 | state_dict = {k.replace(".q", ".query_proj"): v for k, v in state_dict.items()} 468 | state_dict = {k.replace(".k", ".key_proj"): v for k, v in state_dict.items()} 469 | state_dict = {k.replace(".v", ".value_proj"): v for k, v in state_dict.items()} 470 | state_dict = {k.replace(".proj_out", ".out_proj"): v for k, v in state_dict.items()} 471 | 472 | state_dict = {k.replace("mid", "mid_blocks"): v for k, v in state_dict.items()} 473 | 474 | state_dict = { 475 | k.replace("mid_blocks.block_1", "mid_blocks.0"): v 476 | for k, v in state_dict.items() 477 | } 478 | state_dict = { 479 | k.replace("mid_blocks.block_2", "mid_blocks.2"): v 480 | for k, v in state_dict.items() 481 | } 482 | state_dict = { 483 | k.replace("mid_blocks.attn_1", "mid_blocks.1"): v for k, v in state_dict.items() 484 | } 485 | 486 | state_dict = {k.replace(".norm.", ".group_norm."): v for k, v in state_dict.items()} 487 | state_dict = { 488 | k.replace("norm_out", "conv_norm_out"): v for k, v in state_dict.items() 489 | } 490 | 491 | # reshape weights 492 | 493 | state_dict = { 494 | k: v.transpose(0, 2, 3, 1) if "downsample" in k and "weight" in k else v 495 | for k, v in state_dict.items() 496 | } 497 | state_dict = { 498 | k: ( 499 | v.transpose(0, 2, 3, 1) 500 | if "resnets" in k and "conv" in k and "weight" in k 501 | else v 502 | ) 503 | for k, v in state_dict.items() 504 | } 505 | state_dict = { 506 | k: ( 507 | v.transpose(0, 2, 3, 1) 508 | if "mid_blocks" in k and "conv" in k and "weight" in k 509 | else v 510 | ) 511 | for k, v in state_dict.items() 512 | } 513 | state_dict = { 514 | k: v[:, 0, 0, :] if "conv_shortcut.weight" in k else v 515 | for k, v in state_dict.items() 516 | } 517 | state_dict = { 518 | k: v[:, :, 0, 0] if "proj.weight" in k else v for k, v in state_dict.items() 519 | } 520 | state_dict["conv_in.weight"] = state_dict["conv_in.weight"].transpose(0, 2, 3, 1) 521 | state_dict["conv_out.weight"] = state_dict["conv_out.weight"].transpose(0, 2, 3, 1) 522 | 523 | return state_dict 524 | 525 | 526 | def t5_encoder_state_dict_adjustments(state_dict, prefix=""): 527 | state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()} 528 | 529 | for i in range(2): 530 | state_dict = { 531 | k.replace(f"layer.{i}.layer_norm", f"ln{i+1}"): v 532 | for k, v in state_dict.items() 533 | } 534 | for i in range(2): 535 | state_dict = {k.replace(f"layer.{i}.", ""): v for k, v in 
state_dict.items()} 536 | state_dict = {k.replace("block", "layers"): v for k, v in state_dict.items()} 537 | state_dict = { 538 | k.replace("SelfAttention.q", "attention.query_proj"): v 539 | for k, v in state_dict.items() 540 | } 541 | state_dict = { 542 | k.replace("SelfAttention.k", "attention.key_proj"): v 543 | for k, v in state_dict.items() 544 | } 545 | state_dict = { 546 | k.replace("SelfAttention.v", "attention.value_proj"): v 547 | for k, v in state_dict.items() 548 | } 549 | state_dict = { 550 | k.replace("SelfAttention.o", "attention.out_proj"): v 551 | for k, v in state_dict.items() 552 | } 553 | state_dict = { 554 | k.replace("DenseReluDense", "dense"): v for k, v in state_dict.items() 555 | } 556 | 557 | state_dict["encoder.relative_attention_bias.embeddings.weight"] = state_dict[ 558 | "encoder.layers.0.SelfAttention.relative_attention_bias.weight" 559 | ] 560 | del state_dict["encoder.layers.0.SelfAttention.relative_attention_bias.weight"] 561 | 562 | state_dict["wte.weight"] = state_dict["encoder.embed_tokens.weight"] 563 | del state_dict["encoder.embed_tokens.weight"] 564 | del state_dict["shared.weight"] 565 | 566 | state_dict["encoder.ln.weight"] = state_dict["encoder.final_layer_norm.weight"] 567 | del state_dict["encoder.final_layer_norm.weight"] 568 | 569 | return state_dict 570 | 571 | 572 | def map_clip_text_encoder_weights(key, value): 573 | # Remove prefixes 574 | if key.startswith("text_model."): 575 | key = key[11:] 576 | if key.startswith("embeddings."): 577 | key = key[11:] 578 | if key.startswith("encoder."): 579 | key = key[8:] 580 | 581 | # Map attention layers 582 | if "self_attn." in key: 583 | key = key.replace("self_attn.", "attention.") 584 | if "q_proj." in key: 585 | key = key.replace("q_proj.", "query_proj.") 586 | if "k_proj." in key: 587 | key = key.replace("k_proj.", "key_proj.") 588 | if "v_proj." 
in key: 589 | key = key.replace("v_proj.", "value_proj.") 590 | 591 | # Map ffn layers 592 | if "mlp.fc1" in key: 593 | key = key.replace("mlp.fc1", "linear1") 594 | if "mlp.fc2" in key: 595 | key = key.replace("mlp.fc2", "linear2") 596 | 597 | return [(key, value)] 598 | 599 | 600 | def map_vae_weights(key, value): 601 | # Map up/downsampling 602 | if "downsamplers" in key: 603 | key = key.replace("downsamplers.0.conv", "downsample") 604 | if "upsamplers" in key: 605 | key = key.replace("upsamplers.0.conv", "upsample") 606 | 607 | # Map attention layers 608 | if "to_k" in key: 609 | key = key.replace("to_k", "key_proj") 610 | if "to_out.0" in key: 611 | key = key.replace("to_out.0", "out_proj") 612 | if "to_q" in key: 613 | key = key.replace("to_q", "query_proj") 614 | if "to_v" in key: 615 | key = key.replace("to_v", "value_proj") 616 | 617 | # Map the mid block 618 | if "mid_block.resnets.0" in key: 619 | key = key.replace("mid_block.resnets.0", "mid_blocks.0") 620 | if "mid_block.attentions.0" in key: 621 | key = key.replace("mid_block.attentions.0", "mid_blocks.1") 622 | if "mid_block.resnets.1" in key: 623 | key = key.replace("mid_block.resnets.1", "mid_blocks.2") 624 | 625 | # Map the quant/post_quant layers 626 | if "quant_conv" in key: 627 | key = key.replace("quant_conv", "quant_proj") 628 | value = value.squeeze() 629 | 630 | # Map the conv_shortcut to linear 631 | if "conv_shortcut.weight" in key: 632 | value = value.squeeze() 633 | 634 | if len(value.shape) == 4: 635 | value = value.transpose(0, 2, 3, 1) 636 | value = value.reshape(-1).reshape(value.shape) 637 | 638 | return [(key, value)] 639 | 640 | 641 | """ Code obtained from 642 | https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/model_io.py 643 | """ 644 | 645 | 646 | def _flatten(params): 647 | return [(k, v) for p in params for (k, v) in p] 648 | 649 | 650 | """ Code obtained from 651 | https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/model_io.py 652 | """ 653 | 654 | 655 | def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False): 656 | dtype = _FLOAT16 if float16 else mx.float32 657 | weights = mx.load(weight_file) 658 | weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()]) 659 | model.update(tree_unflatten(weights)) 660 | 661 | 662 | def _check_key(key: str, part: str): 663 | if key not in _MODELS: 664 | raise ValueError( 665 | f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}" 666 | ) 667 | 668 | 669 | def load_mmdit( 670 | key: str = _DEFAULT_MMDIT, 671 | float16: bool = False, 672 | model_key: str = "mmdit_2b", 673 | low_memory_mode: bool = True, 674 | only_modulation_dict: bool = False, 675 | ): 676 | """Load the MM-DiT model from the checkpoint file.""" 677 | """only_modulation_dict: Only returns the modulation dictionary""" 678 | dtype = _FLOAT16 if float16 else mx.float32 679 | config = SD3_2b 680 | config.low_memory_mode = low_memory_mode 681 | model = MMDiT(config) 682 | 683 | mmdit_weights = _MMDIT[key][model_key] 684 | mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights) 685 | hf_hub_download(key, "config.json") 686 | weights = mx.load(mmdit_weights_ckpt) 687 | weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.") 688 | weights = {k: v.astype(dtype) for k, v in weights.items()} 689 | if only_modulation_dict: 690 | weights = {k: v for k, v in weights.items() if "adaLN" in k} 691 | return tree_flatten(weights) 
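# The flat {dotted_key: array} weight dict is converted into the nested parameter tree
# that nn.Module.update expects via the tree_flatten/tree_unflatten round trip used just
# below. A minimal sketch of the same round trip on a toy module (illustrative only):
#
#   from mlx import nn
#   from mlx.utils import tree_flatten, tree_unflatten
#
#   layer = nn.Linear(4, 4)
#   flat = dict(tree_flatten(layer.parameters()))     # {"weight": ..., "bias": ...}
#   layer.update(tree_unflatten(list(flat.items())))  # back to the nested tree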
692 | model.update(tree_unflatten(tree_flatten(weights))) 693 | 694 | return model 695 | 696 | 697 | def load_flux( 698 | key: str = "argmaxinc/mlx-FLUX.1-schnell", 699 | float16: bool = False, 700 | model_key: str = "argmaxinc/mlx-FLUX.1-schnell", 701 | low_memory_mode: bool = True, 702 | only_modulation_dict: bool = False, 703 | ): 704 | """Load the MM-DiT Flux model from the checkpoint file.""" 705 | dtype = _FLOAT16 if float16 else mx.float32 706 | config = FLUX_SCHNELL 707 | config.low_memory_mode = low_memory_mode 708 | model = MMDiT(config) 709 | 710 | flux_weights = _MMDIT[key][model_key] 711 | flux_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, flux_weights) 712 | hf_hub_download(key, "config.json") 713 | weights = mx.load(flux_weights_ckpt) 714 | 715 | if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]: 716 | weights = flux_state_dict_adjustments( 717 | weights, 718 | prefix="", 719 | hidden_size=config.hidden_size, 720 | mlp_ratio=config.mlp_ratio, 721 | ) 722 | elif ( 723 | model_key == "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized" 724 | ): # 4-bit ckpt already adjusted 725 | nn.quantize(model) 726 | 727 | weights = { 728 | k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items() 729 | } 730 | if only_modulation_dict: 731 | weights = {k: v for k, v in weights.items() if "adaLN" in k} 732 | return tree_flatten(weights) 733 | model.update(tree_unflatten(tree_flatten(weights))) 734 | 735 | return model 736 | 737 | 738 | def load_text_encoder( 739 | key: str = _DEFAULT_MODEL, 740 | float16: bool = False, 741 | model_key: str = "text_encoder", 742 | config_key: Optional[str] = None, 743 | ): 744 | """Load the stable diffusion text encoder from Hugging Face Hub.""" 745 | _check_key(key, "load_text_encoder") 746 | 747 | config_key = config_key or (model_key + "_config") 748 | 749 | # Download the config and create the model 750 | text_encoder_config = _MODELS[key][config_key] 751 | with open(hf_hub_download(key, text_encoder_config)) as f: 752 | config = json.load(f) 753 | 754 | with_projection = "WithProjection" in config["architectures"][0] 755 | 756 | model = CLIPTextModel( 757 | CLIPTextModelConfig( 758 | num_layers=config["num_hidden_layers"], 759 | model_dims=config["hidden_size"], 760 | num_heads=config["num_attention_heads"], 761 | max_length=config["max_position_embeddings"], 762 | vocab_size=config["vocab_size"], 763 | projection_dim=config["projection_dim"] if with_projection else None, 764 | hidden_act=config.get("hidden_act", "quick_gelu"), 765 | ) 766 | ) 767 | 768 | # Download the weights and map them into the model 769 | text_encoder_weights = _MODELS[key][model_key] 770 | weight_file = hf_hub_download(key, text_encoder_weights) 771 | _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16) 772 | 773 | return model 774 | 775 | 776 | def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False): 777 | """Load the stable diffusion autoencoder from Hugging Face Hub.""" 778 | _check_key(key, "load_autoencoder") 779 | 780 | # Download the config and create the model 781 | vae_config = _MODELS[key]["vae_config"] 782 | with open(hf_hub_download(key, vae_config)) as f: 783 | config = json.load(f) 784 | 785 | config["latent_channels"] = 16 786 | 787 | model = Autoencoder( 788 | AutoencoderConfig( 789 | in_channels=config["in_channels"], 790 | out_channels=config["out_channels"], 791 | latent_channels_out=2 * config["latent_channels"], 792 | latent_channels_in=config["latent_channels"], 
793 | block_out_channels=config["block_out_channels"], 794 | layers_per_block=config["layers_per_block"], 795 | norm_num_groups=config["norm_num_groups"], 796 | scaling_factor=config.get("scaling_factor", 0.18215), 797 | ) 798 | ) 799 | 800 | # Download the weights and map them into the model 801 | vae_weights = _MODELS[key]["vae"] 802 | weight_file = hf_hub_download(key, vae_weights) 803 | _load_safetensor_weights(map_vae_weights, model, weight_file, float16) 804 | 805 | return model 806 | 807 | 808 | def load_vae_decoder( 809 | key: str = _DEFAULT_MMDIT, 810 | float16: bool = False, 811 | model_key: str = "vae", 812 | ): 813 | """Load the SD3 VAE Decoder model from the checkpoint file.""" 814 | config = VAEDecoderConfig() 815 | model = VAEDecoder( 816 | in_channels=config.in_channels, 817 | out_channels=config.out_channels, 818 | block_out_channels=config.block_out_channels, 819 | layers_per_block=config.layers_per_block, 820 | resnet_groups=config.resnet_groups, 821 | ) 822 | 823 | dtype = _FLOAT16 if float16 else mx.float32 824 | vae_weights = _MMDIT[key][model_key] 825 | vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) 826 | weights = mx.load(vae_weights_ckpt) 827 | weights = vae_decoder_state_dict_adjustments( 828 | weights, prefix=_PREFIX[key]["vae_decoder"] 829 | ) 830 | weights = {k: v.astype(dtype) for k, v in weights.items()} 831 | model.update(tree_unflatten(tree_flatten(weights))) 832 | 833 | return model 834 | 835 | 836 | def load_vae_encoder( 837 | key: str = _DEFAULT_MMDIT, 838 | float16: bool = False, 839 | model_key: str = "vae", 840 | ): 841 | """Load the SD3 VAE Encoder model from the checkpoint file.""" 842 | config = VAEEncoderConfig() 843 | model = VAEEncoder( 844 | in_channels=config.in_channels, 845 | out_channels=config.out_channels, 846 | block_out_channels=config.block_out_channels, 847 | layers_per_block=config.layers_per_block, 848 | resnet_groups=config.resnet_groups, 849 | ) 850 | 851 | dtype = _FLOAT16 if float16 else mx.float32 852 | vae_weights = _MMDIT[key][model_key] 853 | vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) 854 | weights = mx.load(vae_weights_ckpt) 855 | weights = vae_encoder_state_dict_adjustments( 856 | weights, prefix=_PREFIX[key]["vae_encoder"] 857 | ) 858 | weights = {k: v.astype(dtype) for k, v in weights.items()} 859 | model.update(tree_unflatten(tree_flatten(weights))) 860 | 861 | return model 862 | 863 | 864 | def load_t5_encoder( 865 | key: str = _DEFAULT_MODEL, 866 | float16: bool = False, 867 | model_key: str = "t5", 868 | low_memory_mode: bool = True, 869 | ): 870 | config = T5Config.from_pretrained("google/t5-v1_1-xxl") 871 | model = SD3T5Encoder(config, low_memory_mode=low_memory_mode) 872 | 873 | dtype = _FLOAT16 if float16 else mx.float32 874 | t5_weights = _MODELS[key][model_key] 875 | weights = mx.load(hf_hub_download(key, t5_weights)) 876 | weights = t5_encoder_state_dict_adjustments(weights, prefix="") 877 | weights = {k: v.astype(dtype) for k, v in weights.items()} 878 | model.update(tree_unflatten(tree_flatten(weights))) 879 | 880 | return model 881 | 882 | 883 | def load_tokenizer( 884 | key: str = _DEFAULT_MODEL, 885 | vocab_key: str = "tokenizer_vocab", 886 | merges_key: str = "tokenizer_merges", 887 | pad_with_eos: bool = False, 888 | ): 889 | _check_key(key, "load_tokenizer") 890 | 891 | vocab_file = hf_hub_download(key, _MODELS[key][vocab_key]) 892 | with open(vocab_file, encoding="utf-8") as f: 893 | vocab = json.load(f) 894 | 895 | merges_file = hf_hub_download(key, 
_MODELS[key][merges_key]) 896 | with open(merges_file, encoding="utf-8") as f: 897 | bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] 898 | bpe_merges = [tuple(m.split()) for m in bpe_merges] 899 | bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) 900 | 901 | return Tokenizer(bpe_ranks, vocab, pad_with_eos) 902 | 903 | 904 | def load_t5_tokenizer(max_context_length: int = 256): 905 | config = T5Config.from_pretrained("google/t5-v1_1-xxl") 906 | return T5Tokenizer(config, max_context_length) 907 | -------------------------------------------------------------------------------- /diffusionkit/mlx/sampler.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 4 | # 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | 10 | class ModelSamplingDiscreteFlow(nn.Module): 11 | """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" 12 | 13 | def __init__(self, shift=1.0): 14 | super().__init__() 15 | self.shift = shift 16 | timesteps = 1000 17 | ts = self.sigma(mx.arange(1, timesteps + 1, 1)) 18 | self.sigmas = ts 19 | 20 | @property 21 | def sigma_min(self): 22 | return self.sigmas[0] 23 | 24 | @property 25 | def sigma_max(self): 26 | return self.sigmas[-1] 27 | 28 | def timestep(self, sigma): 29 | return sigma * 1000 30 | 31 | def sigma(self, timestep: mx.array): 32 | timestep = timestep / 1000.0 33 | if self.shift == 1.0: 34 | return timestep 35 | return self.shift * timestep / (1 + (self.shift - 1) * timestep) 36 | 37 | def calculate_denoised(self, sigma, model_output, model_input): 38 | sigma = sigma.reshape(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) 39 | return model_input - model_output * sigma 40 | 41 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 42 | return sigma * noise + (1.0 - sigma) * latent_image 43 | 44 | 45 | class FluxSampler(nn.Module): 46 | """Helper for sampler scheduling (ie timestep/sigma calculations) for Flux models""" 47 | 48 | def __init__(self, shift=1.0): 49 | super().__init__() 50 | self.shift = shift 51 | timesteps = 1000 52 | ts = self.sigma(mx.arange(0, timesteps + 1, 1)) 53 | self.sigmas = ts 54 | 55 | @property 56 | def sigma_min(self): 57 | return self.sigmas[0] 58 | 59 | @property 60 | def sigma_max(self): 61 | return self.sigmas[-1] 62 | 63 | def timestep(self, sigma): 64 | return sigma * 1000 65 | 66 | def sigma(self, timestep: mx.array): 67 | timestep = timestep / 1000.0 68 | if self.shift == 1.0: 69 | return timestep 70 | return self.shift * timestep / (1 + (self.shift - 1) * timestep) 71 | 72 | def calculate_denoised(self, sigma, model_output, model_input): 73 | sigma = sigma.reshape(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) 74 | return model_input - model_output * sigma 75 | 76 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 77 | return sigma * noise + (1.0 - sigma) * latent_image 78 | -------------------------------------------------------------------------------- /diffusionkit/mlx/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/diffusionkit/mlx/scripts/__init__.py -------------------------------------------------------------------------------- /diffusionkit/mlx/scripts/generate_images.py: 
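Note: the script in this file is the command-line entry point for image generation. Assuming the package is installed so the module is importable, a typical invocation would look roughly like:

    python -m diffusionkit.mlx.scripts.generate_images --prompt "Photo of a cat" --model-version argmaxinc/mlx-FLUX.1-schnell --steps 4 --height 512 --width 512 --output-path out.png

Every flag shown is defined in the argparse setup below; the low step count is only an illustrative choice for the schnell checkpoint.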
-------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | # For licensing see accompanying LICENSE.md file. 4 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 5 | # 6 | 7 | import argparse 8 | 9 | from argmaxtools.utils import get_logger 10 | from diffusionkit.mlx import MMDIT_CKPT, DiffusionPipeline, FluxPipeline 11 | 12 | logger = get_logger(__name__) 13 | 14 | # Defaults 15 | HEIGHT = { 16 | "argmaxinc/mlx-stable-diffusion-3-medium": 512, 17 | "sd3-8b-unreleased": 1024, 18 | "argmaxinc/mlx-FLUX.1-schnell": 512, 19 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512, 20 | "argmaxinc/mlx-FLUX.1-dev": 512, 21 | } 22 | WIDTH = { 23 | "argmaxinc/mlx-stable-diffusion-3-medium": 512, 24 | "sd3-8b-unreleased": 1024, 25 | "argmaxinc/mlx-FLUX.1-schnell": 512, 26 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512, 27 | "argmaxinc/mlx-FLUX.1-dev": 512, 28 | } 29 | SHIFT = { 30 | "argmaxinc/mlx-stable-diffusion-3-medium": 3.0, 31 | "sd3-8b-unreleased": 3.0, 32 | "argmaxinc/mlx-FLUX.1-schnell": 1.0, 33 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0, 34 | "argmaxinc/mlx-FLUX.1-dev": 1.0, 35 | } 36 | 37 | 38 | def cli(): 39 | parser = argparse.ArgumentParser( 40 | description="Generate images from a text (and an optional image) prompt using Stable Diffusion" 41 | ) 42 | parser.add_argument("--prompt", required=True, help="Text prompt") 43 | parser.add_argument( 44 | "--image-path", type=str, help="Path to the image prompt", default=None 45 | ) 46 | parser.add_argument( 47 | "--model-version", 48 | choices=tuple(MMDIT_CKPT.keys()), 49 | default="argmaxinc/mlx-FLUX.1-schnell", 50 | help="Diffusion model version, e.g. FLUX-1.schnell, stable-diffusion-3-medium", 51 | ) 52 | parser.add_argument( 53 | "--steps", type=int, default=50, help="Number of diffusion steps." 54 | ) 55 | parser.add_argument( 56 | "--cfg", 57 | type=float, 58 | default=5.0, 59 | help="Classifier-free guidance weight", 60 | ) 61 | parser.add_argument("--negative_prompt", default="", help="Negative text prompt") 62 | parser.add_argument( 63 | "--preload-models", 64 | action="store_true", 65 | help="Preload the models in memory. Default version lazy loads the models.", 66 | ) 67 | parser.add_argument( 68 | "--output-path", "-o", default="out.png", help="Path to save the output image." 69 | ) 70 | parser.add_argument( 71 | "--seed", type=int, help="Seed for the random number generator." 72 | ) 73 | parser.add_argument( 74 | "--verbose", "-v", action="store_true", help="Print detailed information." 75 | ) 76 | parser.add_argument( 77 | "--shift", 78 | type=float, 79 | help="Shift for diffusion sampling", 80 | ) 81 | parser.add_argument( 82 | "--t5", 83 | action="store_true", 84 | help="Engages T5 for stronger text embeddings (uses significantly more memory). ", 85 | ) 86 | parser.add_argument("--height", type=int, help="Height of the output image") 87 | parser.add_argument("--width", type=int, help="Width of the output image") 88 | parser.add_argument( 89 | "--no-low-memory-mode", 90 | action="store_false", 91 | dest="low_memory_mode", 92 | help="Disable low memory mode: No models offloading", 93 | ) 94 | parser.add_argument( 95 | "--benchmark-mode", 96 | action="store_true", 97 | help="Run the script in benchmark mode (no memory cleanup).", 98 | ) 99 | parser.add_argument( 100 | "--denoise", 101 | type=float, 102 | default=0.0, 103 | help="Denoising factor when an input image is provided. 
(between 0.0 and 1.0)", 104 | ) 105 | parser.add_argument( 106 | "--local-ckpt", 107 | default=None, 108 | type=str, 109 | help="Path to the local mmdit checkpoint.", 110 | ) 111 | args = parser.parse_args() 112 | 113 | args.w16 = True 114 | args.a16 = True 115 | 116 | if "FLUX" in args.model_version and args.cfg > 0.0: 117 | logger.warning(f"Disabling CFG for {args.model_version} model.") 118 | args.cfg = 0.0 119 | 120 | if args.benchmark_mode: 121 | if args.low_memory_mode: 122 | logger.warning("Benchmark mode is enabled, disabling low memory mode.") 123 | args.low_memory_mode = False 124 | 125 | if args.denoise < 0.0 or args.denoise > 1.0: 126 | raise ValueError("Denoising factor must be between 0.0 and 1.0") 127 | 128 | shift = args.shift or SHIFT[args.model_version] 129 | pipeline_class = FluxPipeline if "FLUX" in args.model_version else DiffusionPipeline 130 | 131 | # Load the models 132 | sd = pipeline_class( 133 | w16=args.w16, 134 | shift=shift, 135 | use_t5=args.t5, 136 | model_version=args.model_version, 137 | low_memory_mode=args.low_memory_mode, 138 | a16=args.a16, 139 | local_ckpt=args.local_ckpt, 140 | ) 141 | 142 | # Ensure that models are read in memory if needed 143 | if args.preload_models: 144 | sd.ensure_models_are_loaded() 145 | 146 | height = args.height or HEIGHT[args.model_version] 147 | width = args.width or WIDTH[args.model_version] 148 | assert height % 16 == 0, f"Height must be divisible by 16 ({height}/16={height/16})" 149 | assert width % 16 == 0, f"Width must be divisible by 16 ({width}/16={width/16})" 150 | logger.info(f"Output image resolution will be {height}x{width}") 151 | 152 | if args.benchmark_mode: 153 | args.low_memory_mode = False 154 | sd.ensure_models_are_loaded() 155 | logger.info( 156 | "Running in benchmark mode. Warming up the models. (generated latents will be discarded)" 157 | ) 158 | image = sd.generate_image( 159 | args.prompt, 160 | cfg_weight=args.cfg, 161 | num_steps=1, 162 | seed=args.seed, 163 | negative_text=args.negative_prompt, 164 | latent_size=(height // 8, width // 8), 165 | verbose=False, 166 | ) 167 | logger.info("Benchmark mode: Warming up the models done.") 168 | 169 | # Generate the latent vectors using diffusion 170 | image, _ = sd.generate_image( 171 | args.prompt, 172 | cfg_weight=args.cfg, 173 | num_steps=args.steps, 174 | seed=args.seed, 175 | negative_text=args.negative_prompt, 176 | latent_size=(height // 8, width // 8), 177 | image_path=args.image_path, 178 | denoise=args.denoise, 179 | ) 180 | 181 | # Save them to disc 182 | image.save(args.output_path) 183 | logger.info(f"Saved the image to {args.output_path}") 184 | 185 | 186 | if __name__ == "__main__": 187 | cli() 188 | -------------------------------------------------------------------------------- /diffusionkit/mlx/t5.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/t5 2 | 3 | # 4 | # For licensing see accompanying LICENSE.md file. 5 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 
6 | # 7 | 8 | from typing import List, Optional, Tuple 9 | 10 | import mlx.core as mx 11 | import mlx.nn as nn 12 | import numpy as np 13 | from argmaxtools.utils import get_logger 14 | from transformers import T5Config 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def _relative_position_bucket( 20 | relative_position, bidirectional=True, num_buckets=32, max_distance=128 21 | ): 22 | """ 23 | Adapted from HF Tensorflow: 24 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py 25 | 26 | Translate relative position to a bucket number for relative attention. The relative position is defined as 27 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to 28 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for 29 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative 30 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 31 | This should allow for more graceful generalization to longer sequences than the model has been trained on 32 | 33 | Args: 34 | relative_position: an int32 Tensor 35 | bidirectional: a boolean - whether the attention is bidirectional 36 | num_buckets: an integer 37 | max_distance: an integer 38 | 39 | Returns: 40 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) 41 | """ 42 | relative_buckets = 0 43 | if bidirectional: 44 | num_buckets //= 2 45 | relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets 46 | relative_position = mx.abs(relative_position) 47 | else: 48 | relative_position = -mx.minimum( 49 | relative_position, mx.zeros_like(relative_position) 50 | ) 51 | # now relative_position is in the range [0, inf) 52 | 53 | # half of the buckets are for exact increments in positions 54 | max_exact = num_buckets // 2 55 | is_small = relative_position < max_exact 56 | 57 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 58 | scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) 59 | relative_position_if_large = max_exact + ( 60 | mx.log(relative_position.astype(mx.float32) / max_exact) * scale 61 | ).astype(mx.int16) 62 | relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) 63 | relative_buckets += mx.where( 64 | is_small, relative_position, relative_position_if_large 65 | ) 66 | return relative_buckets 67 | 68 | 69 | class RelativePositionBias(nn.Module): 70 | def __init__(self, config: T5Config, bidirectional: bool): 71 | self.bidirectional = bidirectional 72 | self.num_buckets = config.relative_attention_num_buckets 73 | self.max_distance = config.relative_attention_max_distance 74 | self.n_heads = config.num_heads 75 | self.embeddings = nn.Embedding( 76 | config.relative_attention_num_buckets, config.num_heads 77 | ) 78 | 79 | def __call__(self, query_length: int, key_length: int, offset: int = 0): 80 | """Compute binned relative position bias""" 81 | context_position = mx.arange(offset, query_length)[:, None] 82 | memory_position = mx.arange(key_length)[None, :] 83 | 84 | # shape (query_length, key_length) 85 | relative_position = memory_position - context_position 86 | relative_position_bucket = _relative_position_bucket( 87 | relative_position, 88 | bidirectional=self.bidirectional, 89 | num_buckets=self.num_buckets, 90 | 
max_distance=self.max_distance, 91 | ) 92 | 93 | # shape (query_length, key_length, num_heads) 94 | values = self.embeddings(relative_position_bucket) 95 | 96 | # shape (num_heads, query_length, key_length) 97 | return values.transpose(2, 0, 1) 98 | 99 | 100 | class MultiHeadAttention(nn.Module): 101 | def __init__(self, config: T5Config): 102 | super().__init__() 103 | inner_dim = config.d_kv * config.num_heads 104 | self.num_heads = config.num_heads 105 | self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) 106 | self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) 107 | self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) 108 | self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) 109 | 110 | def __call__( 111 | self, 112 | queries: mx.array, 113 | keys: mx.array, 114 | values: mx.array, 115 | mask: Optional[mx.array], 116 | cache: Optional[Tuple[mx.array, mx.array]] = None, 117 | ) -> [mx.array, Tuple[mx.array, mx.array]]: 118 | queries = self.query_proj(queries) 119 | keys = self.key_proj(keys) 120 | values = self.value_proj(values) 121 | 122 | num_heads = self.num_heads 123 | B, L, _ = queries.shape 124 | _, S, _ = keys.shape 125 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 126 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) 127 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 128 | 129 | if cache is not None: 130 | key_cache, value_cache = cache 131 | keys = mx.concatenate([key_cache, keys], axis=3) 132 | values = mx.concatenate([value_cache, values], axis=2) 133 | 134 | # Dimensions are [batch x num heads x sequence x hidden dim] 135 | scores = queries @ keys 136 | if mask is not None: 137 | scores = scores + mask.astype(scores.dtype) 138 | 139 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 140 | values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 141 | return self.out_proj(values_hat), (keys, values) 142 | 143 | 144 | class RMSNorm(nn.Module): 145 | def __init__(self, dims: int, eps: float = 1e-5): 146 | super().__init__() 147 | self.weight = mx.ones((dims,)) 148 | self.eps = eps 149 | 150 | def _norm(self, x): 151 | import math 152 | 153 | # return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) 154 | return x * mx.rsqrt( 155 | (x * (1.0 / math.sqrt(self.weight.shape[0]))) 156 | .square() 157 | .sum(-1, keepdims=True) 158 | + self.eps 159 | ) 160 | 161 | def __call__(self, x): 162 | t = x.dtype 163 | output = self._norm(x).astype(t) 164 | return self.weight * output 165 | 166 | 167 | class DenseActivation(nn.Module): 168 | def __init__(self, config: T5Config): 169 | super().__init__() 170 | mlp_dims = config.d_ff or config.d_model * 4 171 | self.gated = config.feed_forward_proj.startswith("gated") 172 | if self.gated: 173 | self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) 174 | self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) 175 | else: 176 | self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) 177 | self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) 178 | activation = config.feed_forward_proj.removeprefix("gated-") 179 | if activation == "relu": 180 | self.act = nn.relu 181 | elif activation == "gelu": 182 | self.act = nn.gelu 183 | elif activation == "silu": 184 | self.act = nn.silu 185 | else: 186 | raise ValueError(f"Unknown activation: {activation}") 187 | 188 | def __call__(self, x): 189 | if self.gated: 190 | hidden_act = self.act(self.wi_0(x)) 191 | hidden_linear 
= self.wi_1(x) 192 | x = hidden_act * hidden_linear 193 | else: 194 | x = self.act(self.wi(x)) 195 | return self.wo(x) 196 | 197 | 198 | class TransformerEncoderLayer(nn.Module): 199 | def __init__(self, config: T5Config): 200 | super().__init__() 201 | self.attention = MultiHeadAttention(config) 202 | self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 203 | self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 204 | self.dense = DenseActivation(config) 205 | 206 | def __call__(self, x, mask): 207 | y = self.ln1(x) 208 | y = y.astype(self.ln1.weight.dtype) 209 | y, _ = self.attention(y, y, y, mask=mask) 210 | y = y.astype(mx.float32) 211 | x = x + y 212 | 213 | y = self.ln2(x) 214 | y = self.dense(y) 215 | return x + y 216 | 217 | 218 | class TransformerEncoder(nn.Module): 219 | def __init__(self, config: T5Config, low_memory_mode=True): 220 | super().__init__() 221 | self.low_memory_mode = low_memory_mode 222 | self.layers = [ 223 | TransformerEncoderLayer(config) for i in range(config.num_layers) 224 | ] 225 | self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 226 | self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) 227 | 228 | def __call__(self, x: mx.array): 229 | t = x.dtype 230 | pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) 231 | 232 | if self.low_memory_mode: 233 | mx.metal.set_memory_limit(4 * 1024**3) 234 | 235 | for layer in self.layers: 236 | x = layer(x, mask=pos_bias) 237 | 238 | if self.low_memory_mode: 239 | self.layers.clear() 240 | 241 | mx.eval(x) 242 | mx.metal.set_memory_limit(mx.metal.device_info()["memory_size"]) 243 | return self.ln(x).astype(t) 244 | 245 | 246 | class TransformerDecoderLayer(nn.Module): 247 | def __init__(self, config: T5Config): 248 | super().__init__() 249 | self.self_attention = MultiHeadAttention(config) 250 | self.cross_attention = MultiHeadAttention(config) 251 | self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 252 | self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 253 | self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 254 | self.dense = DenseActivation(config) 255 | 256 | def __call__( 257 | self, 258 | x: mx.array, 259 | memory: mx.array, 260 | mask: mx.array, 261 | memory_mask: mx.array, 262 | cache: Optional[List[Tuple[mx.array, mx.array]]] = None, 263 | ): 264 | y = self.ln1(x) 265 | y, cache = self.self_attention(y, y, y, mask, cache) 266 | x = x + y 267 | 268 | y = self.ln2(x) 269 | y, _ = self.cross_attention(y, memory, memory, memory_mask) 270 | x = x + y 271 | 272 | y = self.ln3(x) 273 | y = self.dense(y) 274 | x = x + y 275 | 276 | return x, cache 277 | 278 | 279 | class TransformerDecoder(nn.Module): 280 | def __init__(self, config: T5Config): 281 | super().__init__() 282 | n_layers = getattr(config, "num_decoder_layers", config.num_layers) 283 | self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] 284 | self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) 285 | self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) 286 | 287 | def __call__(self, x, memory, mask, memory_mask, cache=None): 288 | if cache is not None: 289 | offset = cache[0][0].shape[3] 290 | else: 291 | offset = 0 292 | cache = [None] * len(self.layers) 293 | 294 | T = offset + x.shape[1] 295 | pos_bias = self.relative_attention_bias(T, T, offset=offset) 296 | if mask is not None: 297 | mask += pos_bias 298 | else: 299 | mask = pos_bias 300 | 301 | for e, layer in 
enumerate(self.layers): 302 | x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e]) 303 | x = self.ln(x) 304 | 305 | return x, cache 306 | 307 | 308 | class OutputHead(nn.Module): 309 | def __init__(self, config: T5Config): 310 | self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) 311 | 312 | def __call__(self, inputs): 313 | return self.linear(inputs) 314 | 315 | 316 | class SD3T5Encoder(nn.Module): 317 | def __init__(self, config: T5Config, low_memory_mode=True): 318 | self.wte = nn.Embedding(config.vocab_size, config.d_model) 319 | self.encoder = TransformerEncoder(config, low_memory_mode=low_memory_mode) 320 | self.model_dim = config.d_model 321 | 322 | def __call__(self, inputs: mx.array): 323 | out = self.wte(inputs) 324 | out = self.encoder(out) 325 | return out 326 | -------------------------------------------------------------------------------- /diffusionkit/mlx/tokenizer.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | from typing import List 4 | 5 | import mlx.core as mx 6 | import numpy as np 7 | import regex 8 | from argmaxtools.utils import get_logger 9 | from transformers import AutoTokenizer, T5Config 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | class Tokenizer: 15 | """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" 16 | 17 | def __init__(self, bpe_ranks, vocab, pad_with_eos=False): 18 | self.bpe_ranks = bpe_ranks 19 | self.vocab = vocab 20 | self.pat = regex.compile( 21 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 22 | regex.IGNORECASE, 23 | ) 24 | 25 | self.pad_to_max_length = True 26 | self.max_length = 77 27 | 28 | self._cache = {self.bos: self.bos, self.eos: self.eos} 29 | self.pad_with_eos = pad_with_eos 30 | 31 | @property 32 | def bos(self): 33 | return "<|startoftext|>" 34 | 35 | @property 36 | def bos_token(self): 37 | return self.vocab[self.bos] 38 | 39 | @property 40 | def eos(self): 41 | return "<|endoftext|>" 42 | 43 | @property 44 | def eos_token(self): 45 | return self.vocab[self.eos] 46 | 47 | def bpe(self, text): 48 | if text in self._cache: 49 | return self._cache[text] 50 | 51 | unigrams = list(text[:-1]) + [text[-1] + ""] 52 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 53 | 54 | if not unique_bigrams: 55 | return unigrams 56 | 57 | # In every iteration try to merge the two most likely bigrams. If none 58 | # was merged we are done. 
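        # For example, for the word "lower" the unigrams start as
        # ['l', 'o', 'w', 'e', 'r']; if ('l', 'o') happens to be the
        # lowest-ranked pair in bpe_ranks, the first pass produces
        # ['lo', 'w', 'e', 'r'], and merging repeats until the best remaining
        # bigram is no longer present in bpe_ranks.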
59 | # 60 | # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py 61 | while unique_bigrams: 62 | bigram = min( 63 | unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) 64 | ) 65 | if bigram not in self.bpe_ranks: 66 | break 67 | 68 | new_unigrams = [] 69 | skip = False 70 | for a, b in zip(unigrams, unigrams[1:]): 71 | if skip: 72 | skip = False 73 | continue 74 | 75 | if (a, b) == bigram: 76 | new_unigrams.append(a + b) 77 | skip = True 78 | 79 | else: 80 | new_unigrams.append(a) 81 | 82 | if not skip: 83 | new_unigrams.append(b) 84 | 85 | unigrams = new_unigrams 86 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 87 | 88 | self._cache[text] = unigrams 89 | 90 | return unigrams 91 | 92 | def tokenize(self, text, prepend_bos=True, append_eos=True): 93 | if isinstance(text, list): 94 | return [self.tokenize(t, prepend_bos, append_eos) for t in text] 95 | 96 | # Lower case cleanup and split according to self.pat. Hugging Face does 97 | # a much more thorough job here but this should suffice for 95% of 98 | # cases. 99 | clean_text = regex.sub(r"\s+", " ", text.lower()) 100 | tokens = regex.findall(self.pat, clean_text) 101 | 102 | # Split the tokens according to the byte-pair merge file 103 | bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] 104 | 105 | # Map to token ids and return 106 | tokens = [self.vocab[t] for t in bpe_tokens] 107 | 108 | # Truncate 109 | max_length = self.max_length - int(prepend_bos) - int(append_eos) 110 | if len(tokens) > max_length: 111 | tokens = tokens[:max_length] 112 | logger.warning( 113 | f"Length of tokens exceeds {self.max_length}. Truncating to {self.max_length}." 114 | ) 115 | if prepend_bos: 116 | tokens = [self.bos_token] + tokens 117 | if append_eos: 118 | tokens.append(self.eos_token) 119 | 120 | return tokens 121 | 122 | 123 | class T5Tokenizer: 124 | def __init__(self, config: T5Config, max_context_length: int): 125 | self.max_length = max_context_length 126 | self._decoder_start_id = config.decoder_start_token_id 127 | self._tokenizer = AutoTokenizer.from_pretrained( 128 | "google/t5-v1_1-xxl", 129 | legacy=False, 130 | model_max_length=self.max_length, 131 | ) 132 | 133 | self.pad_to_max_length = True 134 | self.pad_with_eos = False 135 | 136 | @property 137 | def eos_id(self) -> int: 138 | return self._tokenizer.eos_token_id 139 | 140 | @property 141 | def decoder_start_id(self) -> int: 142 | return self._decoder_start_id 143 | 144 | def encode(self, s: str) -> mx.array: 145 | return mx.array( 146 | self._tokenizer( 147 | s, 148 | return_tensors="np", 149 | return_attention_mask=False, 150 | max_length=self.max_length, 151 | truncation=True, 152 | )["input_ids"] 153 | ) 154 | 155 | def decode(self, t: List[int], with_sep: bool = True) -> str: 156 | tokens = self._tokenizer.convert_ids_to_tokens(t) 157 | return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) 158 | 159 | def tokenize(self, s: str) -> np.array: 160 | return [t.item() for t in self.encode(s)[0]] 161 | -------------------------------------------------------------------------------- /diffusionkit/mlx/vae.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion 2 | 3 | # 4 | # For licensing see accompanying LICENSE.md file. 5 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 
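#
# Module overview: MLX port of the Stable Diffusion autoencoder. Activations
# are kept in (B, H, W, C) layout; Encoder/Decoder make up the generic
# Autoencoder, while VAEEncoder/VAEDecoder are the SD3 variants loaded by
# model_io.load_vae_encoder/load_vae_decoder, with mx.eval() forcing
# evaluation after each up-block of the decoder.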
6 | # 7 | 8 | import math 9 | from typing import List, Optional 10 | 11 | import mlx.core as mx 12 | import mlx.nn as nn 13 | from argmaxtools.utils import get_logger 14 | 15 | from .config import AutoencoderConfig 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | def upsample_nearest(x, scale: int = 2): 21 | B, H, W, C = x.shape 22 | x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C)) 23 | x = x.reshape(B, H * scale, W * scale, C) 24 | 25 | return x 26 | 27 | 28 | class Attention(nn.Module): 29 | """A single head unmasked attention for use with the VAE.""" 30 | 31 | def __init__(self, dims: int, norm_groups: int = 32): 32 | super().__init__() 33 | 34 | self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True) 35 | self.query_proj = nn.Linear(dims, dims) 36 | self.key_proj = nn.Linear(dims, dims) 37 | self.value_proj = nn.Linear(dims, dims) 38 | self.out_proj = nn.Linear(dims, dims) 39 | 40 | def __call__(self, x): 41 | B, H, W, C = x.shape 42 | 43 | y = self.group_norm(x) 44 | 45 | queries = self.query_proj(y).reshape(B, H * W, C) 46 | keys = self.key_proj(y).reshape(B, H * W, C) 47 | values = self.value_proj(y).reshape(B, H * W, C) 48 | 49 | scale = 1 / math.sqrt(queries.shape[-1]) 50 | scores = (queries * scale) @ keys.transpose(0, 2, 1) 51 | attn = mx.softmax(scores, axis=-1) 52 | y = (attn @ values).reshape(B, H, W, C) 53 | 54 | y = self.out_proj(y) 55 | x = x + y 56 | 57 | return x 58 | 59 | 60 | class ResnetBlock2D(nn.Module): 61 | def __init__( 62 | self, 63 | in_channels: int, 64 | out_channels: Optional[int] = None, 65 | groups: int = 32, 66 | temb_channels: Optional[int] = None, 67 | ): 68 | super().__init__() 69 | 70 | out_channels = out_channels or in_channels 71 | 72 | self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True) 73 | self.conv1 = nn.Conv2d( 74 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 75 | ) 76 | if temb_channels is not None: 77 | self.time_emb_proj = nn.Linear(temb_channels, out_channels) 78 | self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True) 79 | self.conv2 = nn.Conv2d( 80 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 81 | ) 82 | 83 | if in_channels != out_channels: 84 | self.conv_shortcut = nn.Linear(in_channels, out_channels) 85 | 86 | def __call__(self, x, temb=None): 87 | if temb is not None: 88 | temb = self.time_emb_proj(nn.silu(temb)) 89 | 90 | y = self.norm1(x) 91 | y = nn.silu(y) 92 | y = self.conv1(y) 93 | if temb is not None: 94 | y = y + temb[:, None, None, :] 95 | y = self.norm2(y) 96 | y = nn.silu(y) 97 | y = self.conv2(y) 98 | 99 | x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x)) 100 | 101 | return x 102 | 103 | 104 | class EncoderDecoderBlock2D(nn.Module): 105 | def __init__( 106 | self, 107 | in_channels: int, 108 | out_channels: int, 109 | num_layers: int = 1, 110 | resnet_groups: int = 32, 111 | add_downsample=True, 112 | add_upsample=True, 113 | ): 114 | super().__init__() 115 | 116 | # Add the resnet blocks 117 | self.resnets = [ 118 | ResnetBlock2D( 119 | in_channels=in_channels if i == 0 else out_channels, 120 | out_channels=out_channels, 121 | groups=resnet_groups, 122 | ) 123 | for i in range(num_layers) 124 | ] 125 | 126 | # Add an optional downsampling layer 127 | if add_downsample: 128 | self.downsample = nn.Conv2d( 129 | out_channels, out_channels, kernel_size=3, stride=2, padding=0 130 | ) 131 | 132 | # or upsampling layer 133 | if add_upsample: 134 | self.upsample = nn.Conv2d( 135 | 
out_channels, out_channels, kernel_size=3, stride=1, padding=1 136 | ) 137 | 138 | def __call__(self, x): 139 | for resnet in self.resnets: 140 | x = resnet(x) 141 | 142 | if "downsample" in self: 143 | x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) 144 | x = self.downsample(x) 145 | 146 | if "upsample" in self: 147 | x = self.upsample(upsample_nearest(x)) 148 | 149 | return x 150 | 151 | 152 | class Encoder(nn.Module): 153 | """Implements the encoder side of the Autoencoder.""" 154 | 155 | def __init__( 156 | self, 157 | in_channels: int, 158 | out_channels: int, 159 | block_out_channels: List[int] = [64], 160 | layers_per_block: int = 2, 161 | resnet_groups: int = 32, 162 | ): 163 | super().__init__() 164 | 165 | self.conv_in = nn.Conv2d( 166 | in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | channels = [block_out_channels[0]] + list(block_out_channels) 170 | self.down_blocks = [ 171 | EncoderDecoderBlock2D( 172 | in_channels, 173 | out_channels, 174 | num_layers=layers_per_block, 175 | resnet_groups=resnet_groups, 176 | add_downsample=i < len(block_out_channels) - 1, 177 | add_upsample=False, 178 | ) 179 | for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])) 180 | ] 181 | 182 | self.mid_blocks = [ 183 | ResnetBlock2D( 184 | in_channels=block_out_channels[-1], 185 | out_channels=block_out_channels[-1], 186 | groups=resnet_groups, 187 | ), 188 | Attention(block_out_channels[-1], resnet_groups), 189 | ResnetBlock2D( 190 | in_channels=block_out_channels[-1], 191 | out_channels=block_out_channels[-1], 192 | groups=resnet_groups, 193 | ), 194 | ] 195 | 196 | self.conv_norm_out = nn.GroupNorm( 197 | resnet_groups, block_out_channels[-1], pytorch_compatible=True 198 | ) 199 | self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1) 200 | 201 | def __call__(self, x): 202 | x = self.conv_in(x) 203 | 204 | for l in self.down_blocks: 205 | x = l(x) 206 | 207 | x = self.mid_blocks[0](x) 208 | x = self.mid_blocks[1](x) 209 | x = self.mid_blocks[2](x) 210 | 211 | x = self.conv_norm_out(x) 212 | x = nn.silu(x) 213 | x = self.conv_out(x) 214 | 215 | return x 216 | 217 | 218 | class Decoder(nn.Module): 219 | """Implements the decoder side of the Autoencoder.""" 220 | 221 | def __init__( 222 | self, 223 | in_channels: int, 224 | out_channels: int, 225 | block_out_channels: List[int] = [64], 226 | layers_per_block: int = 2, 227 | resnet_groups: int = 32, 228 | ): 229 | super().__init__() 230 | 231 | self.conv_in = nn.Conv2d( 232 | in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 233 | ) 234 | 235 | self.mid_blocks = [ 236 | ResnetBlock2D( 237 | in_channels=block_out_channels[-1], 238 | out_channels=block_out_channels[-1], 239 | groups=resnet_groups, 240 | ), 241 | Attention(block_out_channels[-1], resnet_groups), 242 | ResnetBlock2D( 243 | in_channels=block_out_channels[-1], 244 | out_channels=block_out_channels[-1], 245 | groups=resnet_groups, 246 | ), 247 | ] 248 | 249 | channels = list(reversed(block_out_channels)) 250 | channels = [channels[0]] + channels 251 | self.up_blocks = [ 252 | EncoderDecoderBlock2D( 253 | in_channels, 254 | out_channels, 255 | num_layers=layers_per_block, 256 | resnet_groups=resnet_groups, 257 | add_downsample=False, 258 | add_upsample=i < len(block_out_channels) - 1, 259 | ) 260 | for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])) 261 | ] 262 | 263 | self.conv_norm_out = nn.GroupNorm( 264 | resnet_groups, block_out_channels[0], 
pytorch_compatible=True 265 | ) 266 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 267 | 268 | def __call__(self, x): 269 | x = self.conv_in(x) 270 | 271 | x = self.mid_blocks[0](x) 272 | x = self.mid_blocks[1](x) 273 | x = self.mid_blocks[2](x) 274 | 275 | for l in self.up_blocks: 276 | x = l(x) 277 | 278 | x = self.conv_norm_out(x) 279 | x = nn.silu(x) 280 | x = self.conv_out(x) 281 | 282 | return x 283 | 284 | 285 | class Autoencoder(nn.Module): 286 | """The autoencoder that allows us to perform diffusion in the latent space.""" 287 | 288 | def __init__(self, config: AutoencoderConfig): 289 | super().__init__() 290 | 291 | self.latent_channels = config.latent_channels_in 292 | self.scaling_factor = config.scaling_factor 293 | self.encoder = Encoder( 294 | config.in_channels, 295 | config.latent_channels_out, 296 | config.block_out_channels, 297 | config.layers_per_block, 298 | resnet_groups=config.norm_num_groups, 299 | ) 300 | self.decoder = Decoder( 301 | config.latent_channels_in, 302 | config.out_channels, 303 | config.block_out_channels, 304 | config.layers_per_block + 1, 305 | resnet_groups=config.norm_num_groups, 306 | ) 307 | 308 | self.quant_proj = nn.Linear( 309 | config.latent_channels_out, config.latent_channels_out 310 | ) 311 | self.post_quant_proj = nn.Linear( 312 | config.latent_channels_in, config.latent_channels_in 313 | ) 314 | 315 | def decode(self, z): 316 | z = z / self.scaling_factor 317 | return self.decoder(self.post_quant_proj(z)) 318 | 319 | def encode(self, x): 320 | x = self.encoder(x) 321 | x = self.quant_proj(x) 322 | mean, logvar = x.split(2, axis=-1) 323 | mean = mean * self.scaling_factor 324 | logvar = logvar + 2 * math.log(self.scaling_factor) 325 | 326 | return mean, logvar 327 | 328 | def __call__(self, x, key=None): 329 | mean, logvar = self.encode(x) 330 | z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean 331 | x_hat = self.decode(z) 332 | 333 | return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar) 334 | 335 | 336 | class VAEDecoder(nn.Module): 337 | """Implements the decoder side of the Autoencoder for SD3""" 338 | 339 | def __init__( 340 | self, 341 | in_channels: int = 16, 342 | out_channels: int = 3, 343 | block_out_channels: List[int] = [128, 256, 512, 512], 344 | layers_per_block: int = 3, 345 | resnet_groups: int = 32, 346 | ): 347 | super().__init__() 348 | 349 | self.conv_in = nn.Conv2d( 350 | in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 351 | ) 352 | 353 | self.mid_blocks = [ 354 | ResnetBlock2D( 355 | in_channels=block_out_channels[-1], 356 | out_channels=block_out_channels[-1], 357 | groups=resnet_groups, 358 | ), 359 | Attention(block_out_channels[-1], resnet_groups), 360 | ResnetBlock2D( 361 | in_channels=block_out_channels[-1], 362 | out_channels=block_out_channels[-1], 363 | groups=resnet_groups, 364 | ), 365 | ] 366 | 367 | channels = list(reversed(block_out_channels)) 368 | channels = [channels[0]] + channels 369 | self.up_blocks = [] 370 | for i, (in_c, out_c) in enumerate(zip(channels, channels[1:])): 371 | up = EncoderDecoderBlock2D( 372 | in_c, 373 | out_c, 374 | num_layers=layers_per_block, 375 | resnet_groups=resnet_groups, 376 | add_downsample=False, 377 | add_upsample=i < len(block_out_channels) - 1, 378 | ) 379 | self.up_blocks.insert(0, up) 380 | 381 | self.conv_norm_out = nn.GroupNorm( 382 | resnet_groups, block_out_channels[0], pytorch_compatible=True 383 | ) 384 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, 
padding=1) 385 | 386 | def __call__(self, x): 387 | x = self.conv_in(x) 388 | 389 | x = self.mid_blocks[0](x) 390 | x = self.mid_blocks[1](x) 391 | x = self.mid_blocks[2](x) 392 | 393 | for l in reversed(self.up_blocks): 394 | x = l(x) 395 | mx.eval(x) 396 | 397 | x = self.conv_norm_out(x) 398 | x = nn.silu(x) 399 | x = self.conv_out(x) 400 | 401 | return x 402 | 403 | 404 | class VAEEncoder(nn.Module): 405 | """Implements the encoder side of the Autoencoder.""" 406 | 407 | def __init__( 408 | self, 409 | in_channels: int = 3, 410 | out_channels: int = 32, 411 | block_out_channels: List[int] = [128, 256, 512, 512], 412 | layers_per_block: int = 2, 413 | resnet_groups: int = 32, 414 | ): 415 | super().__init__() 416 | 417 | self.conv_in = nn.Conv2d( 418 | in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 419 | ) 420 | 421 | channels = [block_out_channels[0]] + list(block_out_channels) 422 | self.down_blocks = [ 423 | EncoderDecoderBlock2D( 424 | in_channels, 425 | out_channels, 426 | num_layers=layers_per_block, 427 | resnet_groups=resnet_groups, 428 | add_downsample=i < len(block_out_channels) - 1, 429 | add_upsample=False, 430 | ) 431 | for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])) 432 | ] 433 | 434 | self.mid_blocks = [ 435 | ResnetBlock2D( 436 | in_channels=block_out_channels[-1], 437 | out_channels=block_out_channels[-1], 438 | groups=resnet_groups, 439 | ), 440 | Attention(block_out_channels[-1], resnet_groups), 441 | ResnetBlock2D( 442 | in_channels=block_out_channels[-1], 443 | out_channels=block_out_channels[-1], 444 | groups=resnet_groups, 445 | ), 446 | ] 447 | 448 | self.conv_norm_out = nn.GroupNorm( 449 | resnet_groups, block_out_channels[-1], pytorch_compatible=True 450 | ) 451 | self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1) 452 | 453 | def __call__(self, x): 454 | x = self.conv_in(x) 455 | 456 | for l in self.down_blocks: 457 | x = l(x) 458 | 459 | x = self.mid_blocks[0](x) 460 | x = self.mid_blocks[1](x) 461 | x = self.mid_blocks[2](x) 462 | 463 | x = self.conv_norm_out(x) 464 | x = nn.silu(x) 465 | x = self.conv_out(x) 466 | 467 | return x 468 | -------------------------------------------------------------------------------- /diffusionkit/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/diffusionkit/tests/__init__.py -------------------------------------------------------------------------------- /diffusionkit/tests/mlx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/diffusionkit/tests/mlx/__init__.py -------------------------------------------------------------------------------- /diffusionkit/tests/mlx/test_diffusion_pipeline.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 
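#
# Test overview: test_sd3_pipeline_correctness regenerates the reference images
# from the argmaxinc/sd-test-images dataset with DiffusionPipeline and asserts
# a PSNR of at least TEST_PSNR_THRESHOLD (20 dB); test_memory_usage runs a
# short generation and saves the returned memory log under TEST_CACHE_DIR.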
4 | # 5 | 6 | import json 7 | import os 8 | import unittest 9 | 10 | from argmaxtools.utils import get_logger 11 | from diffusionkit.mlx import MMDIT_CKPT, DiffusionPipeline 12 | from diffusionkit.utils import image_psnr 13 | from huggingface_hub import hf_hub_download 14 | from PIL import Image 15 | 16 | logger = get_logger(__name__) 17 | 18 | W16 = True 19 | A16 = True 20 | TEST_PSNR_THRESHOLD = 20 21 | TEST_MIN_SPEEDUP = 0.95 22 | SD3_TEST_IMAGES_REPO = "argmaxinc/sd-test-images" 23 | TEST_CACHE_DIR = ".cache" 24 | CACHE_SUBFOLDER = None 25 | 26 | LOW_MEMORY_MODE = True 27 | SAVE_IMAGES = True 28 | MODEL_VERSION = "argmaxinc/mlx-stable-diffusion-3-medium" 29 | USE_T5 = False 30 | SKIP_CORRECTNESS = False 31 | 32 | 33 | class TestSD3Pipeline(unittest.TestCase): 34 | @classmethod 35 | def setUpClass(cls): 36 | cls.sd_test_images_metadata = hf_hub_download( 37 | SD3_TEST_IMAGES_REPO, "metadata.json", repo_type="dataset" 38 | ) 39 | 40 | @classmethod 41 | def tearDownClass(cls): 42 | del cls.sd_test_images_metadata 43 | cls.sd_test_images_metadata = None 44 | 45 | super().tearDownClass() 46 | 47 | def test_sd3_pipeline_correctness(self): 48 | with open(self.sd_test_images_metadata, "r") as f: 49 | metadata = json.load(f) 50 | 51 | # Group metadata by model size 52 | model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []} 53 | for data in metadata: 54 | model_examples[data["model_version"]].append(data) 55 | 56 | for model_version, examples in model_examples.items(): 57 | sd3 = DiffusionPipeline( 58 | model_version=model_version, 59 | w16=W16, 60 | low_memory_mode=LOW_MEMORY_MODE, 61 | a16=A16, 62 | ) 63 | if not LOW_MEMORY_MODE: 64 | sd3.ensure_models_are_loaded() 65 | for example in examples: 66 | image_path = example["image"] 67 | sd3.use_t5 = example["use_t5"] 68 | 69 | f = hf_hub_download( 70 | SD3_TEST_IMAGES_REPO, image_path, repo_type="dataset" 71 | ) 72 | image = Image.open(f) 73 | 74 | generated_image, _ = sd3.generate_image( 75 | text=example["prompt"], 76 | num_steps=example["steps"], 77 | cfg_weight=example["cfg"], 78 | negative_text=example["neg_prompt"], 79 | latent_size=(example["height"] // 8, example["width"] // 8), 80 | seed=example["seed"], 81 | ) 82 | 83 | if SAVE_IMAGES: 84 | img_cache_dir = os.path.join(TEST_CACHE_DIR, "img") 85 | out_path = os.path.join(img_cache_dir, image_path) 86 | if not os.path.exists(img_cache_dir): 87 | os.makedirs(img_cache_dir, exist_ok=True) 88 | generated_image.save(out_path) 89 | logger.info(f"Saved the image to {out_path}") 90 | 91 | psnr = image_psnr(image, generated_image) 92 | logger.info(f"Image: {image_path} | PSNR: {psnr} dB") 93 | self.assertGreaterEqual(psnr, TEST_PSNR_THRESHOLD) 94 | if LOW_MEMORY_MODE: 95 | del sd3 96 | sd3 = DiffusionPipeline( 97 | model_version=model_version, 98 | w16=W16, 99 | low_memory_mode=LOW_MEMORY_MODE, 100 | a16=A16, 101 | ) 102 | del sd3 103 | 104 | def test_memory_usage(self): 105 | with open(self.sd_test_images_metadata, "r") as f: 106 | metadata = json.load(f) 107 | 108 | # Group metadata by model size 109 | model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []} 110 | for data in metadata: 111 | model_examples[data["model_version"]].append(data) 112 | 113 | sd3 = DiffusionPipeline( 114 | model_version=MODEL_VERSION, 115 | w16=W16, 116 | low_memory_mode=LOW_MEMORY_MODE, 117 | a16=A16, 118 | ) 119 | if not LOW_MEMORY_MODE: 120 | sd3.ensure_models_are_loaded() 121 | 122 | log = None 123 | for example in model_examples[MODEL_VERSION]: 124 | sd3.use_t5 = USE_T5 125 | logger.info( 126 
| f"Testing memory usage... USE_T5 = {USE_T5} | MODEL_VERSION = {MODEL_VERSION}" 127 | ) 128 | _, log = sd3.generate_image( 129 | text=example["prompt"], 130 | num_steps=3, 131 | cfg_weight=example["cfg"], 132 | negative_text=example["neg_prompt"], 133 | latent_size=(example["height"] // 8, example["width"] // 8), 134 | seed=example["seed"], 135 | ) 136 | break 137 | 138 | out_folder = os.path.join(TEST_CACHE_DIR, CACHE_SUBFOLDER) 139 | out_path = os.path.join(out_folder, f"{MODEL_VERSION}_log.json") 140 | if not os.path.exists(out_folder): 141 | os.makedirs(out_folder, exist_ok=True) 142 | with open(out_path, "w") as f: 143 | json.dump(log, f, indent=2) 144 | logger.info(f"Saved the memory log to {out_path}") 145 | self.assertIsNotNone(log) 146 | 147 | 148 | def main(args): 149 | global LOW_MEMORY_MODE, SAVE_IMAGES, SKIP_CORRECTNESS, MODEL_VERSION, W16, A16, CACHE_SUBFOLDER, USE_T5 150 | 151 | LOW_MEMORY_MODE = args.low_memory_mode 152 | SAVE_IMAGES = args.save_images 153 | SKIP_CORRECTNESS = args.skip_correctness 154 | MODEL_VERSION = args.model_version 155 | W16 = args.w16 156 | A16 = args.a16 157 | CACHE_SUBFOLDER = args.subfolder 158 | USE_T5 = args.use_t5 159 | 160 | suite = unittest.TestSuite() 161 | if not SKIP_CORRECTNESS: 162 | suite.addTest(TestSD3Pipeline("test_sd3_pipeline_correctness")) 163 | 164 | suite.addTest(TestSD3Pipeline("test_memory_usage")) 165 | runner = unittest.TextTestRunner() 166 | runner.run(suite) 167 | 168 | 169 | if __name__ == "__main__": 170 | import argparse 171 | 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument( 174 | "--no-low-memory-mode", 175 | action="store_false", 176 | dest="low_memory_mode", 177 | help="Disable low memory mode: models remains loaded in memory after forward pass.", 178 | ) 179 | parser.add_argument( 180 | "--save-images", 181 | action="store_true", 182 | help="Saves generated images to .cache/img/ folder.", 183 | ) 184 | parser.add_argument( 185 | "--skip-correctness", action="store_true", help="Skip the correctness test." 186 | ) 187 | parser.add_argument( 188 | "--model-size", 189 | type=str, 190 | default="argmaxinc/mlx-stable-diffusion-3-medium", 191 | choices=tuple(MMDIT_CKPT.keys()), 192 | help="model version to test", 193 | ) 194 | parser.add_argument( 195 | "--w16", action="store_true", help="Loads the models in float16." 196 | ) 197 | parser.add_argument( 198 | "--a16", action="store_true", help="Use float16 for the model activations." 199 | ) 200 | parser.add_argument( 201 | "--subfolder", 202 | default="default", 203 | type=str, 204 | help="If specified, this string will be appended to the cache directory name.", 205 | ) 206 | parser.add_argument( 207 | "--use-t5", action="store_true", help="Use T5 model for text generation." 208 | ) 209 | args = parser.parse_args() 210 | 211 | main(args) 212 | -------------------------------------------------------------------------------- /diffusionkit/tests/torch2coreml/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_mmdit import convert_mmdit_to_mlpackage 2 | from .test_vae import convert_vae_to_mlpackage 3 | -------------------------------------------------------------------------------- /diffusionkit/tests/torch2coreml/test_mmdit.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import unittest 8 | from typing import Dict 9 | 10 | import coremltools as ct 11 | import torch 12 | from argmaxtools import test_utils as argmaxtools_test_utils 13 | from argmaxtools.utils import get_fastest_device, get_logger 14 | from diffusionkit.torch import mmdit 15 | from diffusionkit.torch.model_io import _load_mmdit_weights 16 | from huggingface_hub import hf_hub_download 17 | 18 | torch.set_grad_enabled(False) 19 | logger = get_logger(__name__) 20 | 21 | TEST_SD3_CKPT_PATH = os.getenv("TEST_SD3_CKPT_PATH", None) or None 22 | TEST_CKPT_FILE_NAME = os.getenv("TEST_CKPT_FILE_NAME", None) or None 23 | TEST_SD3_HF_REPO = os.getenv("TEST_SD3_HF_REPO", None) or None 24 | TEST_CACHE_DIR = os.getenv("TEST_CACHE_DIR", None) or "/tmp" 25 | TEST_DEV = os.getenv("TEST_DEV", None) or get_fastest_device() 26 | TEST_TORCH_DTYPE = torch.float32 27 | TEST_PSNR_THR = 35 28 | TEST_LATENT_SIZE = 64 # 64 latent -> 512 image, 128 latent -> 1024 image 29 | TEST_LATENT_HEIGHT = TEST_LATENT_SIZE 30 | TEST_LATENT_WIDTH = TEST_LATENT_SIZE 31 | 32 | TEST_MODELS = { 33 | "2b": mmdit.SD3_2b, 34 | "8b": mmdit.SD3_8b, 35 | } 36 | 37 | 38 | def setup_test_config( 39 | min_speedup_vs_cpu=3.0, 40 | compute_precision=ct.precision.FLOAT32, 41 | compute_unit=ct.ComputeUnit.CPU_AND_GPU, 42 | compression_min_speedup=0.2, 43 | default_nbits=None, 44 | skip_speed_tests=True, 45 | compile_coreml=False, 46 | ): 47 | argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = min_speedup_vs_cpu 48 | argmaxtools_test_utils.TEST_COREML_PRECISION = compute_precision 49 | argmaxtools_test_utils.TEST_COMPUTE_UNIT = compute_unit 50 | argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = compression_min_speedup 51 | argmaxtools_test_utils.TEST_DEFAULT_NBITS = default_nbits 52 | argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = skip_speed_tests 53 | argmaxtools_test_utils.TEST_COMPILE_COREML = compile_coreml 54 | 55 | 56 | class TestSD3MMDiT(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase): 57 | """Unit tests for stable_duffusion_3.mmdit.MMDiT module""" 58 | 59 | model_version = "2b" 60 | 61 | @classmethod 62 | def setUpClass(cls): 63 | global TEST_SD3_CKPT_PATH 64 | cls.model_name = "MultiModalDiffusionTransformer" 65 | cls.test_output_names = ["denoiser_output"] 66 | cls.test_cache_dir = TEST_CACHE_DIR 67 | 68 | # Base test model 69 | logger.info("Initializing SD3 model") 70 | cls.test_torch_model = ( 71 | mmdit.MMDiT(TEST_MODELS[cls.model_version]) 72 | .to(TEST_DEV) 73 | .to(TEST_TORCH_DTYPE) 74 | .eval() 75 | ) 76 | logger.info("Initialized.") 77 | TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download( 78 | TEST_SD3_HF_REPO, "sd3_medium.safetensors" 79 | ) 80 | if TEST_SD3_CKPT_PATH is not None: 81 | logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") 82 | _load_mmdit_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH) 83 | logger.info("Loaded.") 84 | else: 85 | logger.info( 86 | "No TEST_SD3_CKPT_PATH (--sd3-ckpt-path) provided, exporting random weights" 87 | ) 88 | 89 | # Sample inputs 90 | # TODO(atiorh): CLI configurable model version 91 | cls.test_torch_inputs = get_test_inputs(TEST_MODELS[cls.model_version]) 92 | 93 | super().setUpClass() 94 | 95 | @classmethod 96 | def tearDownClass(cls): 97 | cls.test_torch_model = None 98 | cls.test_torch_inputs = None 99 | super().tearDownClass() 100 | 101 | 102 | def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]: 103 | """Generate random inputs for the SD3 MMDiT model""" 104 | batch_size = 2 # classifier-free 
guidance 105 | assert TEST_LATENT_HEIGHT <= cfg.max_latent_resolution 106 | assert TEST_LATENT_WIDTH <= cfg.max_latent_resolution 107 | 108 | latent_image_embeddings_dims = ( 109 | batch_size, 110 | cfg.vae_latent_dim, 111 | TEST_LATENT_HEIGHT, 112 | TEST_LATENT_WIDTH, 113 | ) 114 | pooled_text_embeddings_dims = (batch_size, cfg.pooled_text_embed_dim, 1, 1) 115 | token_level_text_embeddings_dims = ( 116 | batch_size, 117 | cfg.token_level_text_embed_dim, 118 | 1, 119 | cfg.text_seq_len, 120 | ) 121 | timestep_dims = (2,) 122 | 123 | torch_test_inputs = { 124 | "latent_image_embeddings": torch.randn(*latent_image_embeddings_dims), 125 | "token_level_text_embeddings": torch.randn(*token_level_text_embeddings_dims), 126 | "pooled_text_embeddings": torch.randn(pooled_text_embeddings_dims), 127 | "timestep": torch.randn(*timestep_dims), 128 | } 129 | 130 | return { 131 | k: v.to(TEST_DEV).to(TEST_TORCH_DTYPE) for k, v in torch_test_inputs.items() 132 | } 133 | 134 | 135 | def convert_mmdit_to_mlpackage( 136 | model_version: str, 137 | latent_h: int, 138 | latent_w: int, 139 | output_dir: str = None, 140 | **test_config_kwargs, 141 | ) -> str: 142 | """Converts a MMDiT model to a CoreML package. 143 | 144 | Returns: 145 | `str`: path to the converted model. 146 | """ 147 | global TEST_SD3_CKPT_PATH, TEST_SD3_HF_REPO, TEST_LATENT_WIDTH, TEST_LATENT_HEIGHT, TEST_CACHE_DIR 148 | 149 | # Convert to CoreML 150 | TEST_SD3_HF_REPO = model_version 151 | TEST_LATENT_HEIGHT = latent_h or TEST_LATENT_SIZE 152 | TEST_LATENT_WIDTH = latent_w or TEST_LATENT_SIZE 153 | 154 | setup_test_config(compile_coreml=False, **test_config_kwargs) 155 | 156 | with argmaxtools_test_utils._get_test_cache_dir( 157 | persistent_cache_dir=output_dir 158 | ) as TEST_CACHE_DIR: 159 | suite = unittest.TestSuite() 160 | suite.addTest(TestSD3MMDiT("test_torch2coreml_correctness_and_speedup")) 161 | 162 | if os.getenv("DEBUG", False): 163 | suite.debug() 164 | else: 165 | runner = unittest.TextTestRunner() 166 | runner.run(suite) 167 | 168 | return os.path.join(TEST_CACHE_DIR, f"{TestSD3MMDiT.model_name}.mlpackage") 169 | 170 | 171 | if __name__ == "__main__": 172 | import argparse 173 | 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("--sd3-ckpt-path", default=TEST_SD3_CKPT_PATH, type=str) 176 | parser.add_argument("--ckpt-file-name", default="sd3_medium.safetensors", type=str) 177 | parser.add_argument( 178 | "--model-version", 179 | required=True, 180 | default="2b", 181 | choices=TEST_MODELS.keys(), 182 | type=str, 183 | ) 184 | parser.add_argument("-o", default=TEST_CACHE_DIR, type=str) 185 | parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) 186 | args = parser.parse_args() 187 | 188 | TEST_SD3_CKPT_PATH = ( 189 | args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None 190 | ) 191 | TEST_SD3_HF_REPO = args.sd3_ckpt_path 192 | TEST_LATENT_SIZE = args.latent_size 193 | TEST_CKPT_FILE_NAME = args.ckpt_file_name 194 | 195 | setup_test_config() 196 | 197 | with argmaxtools_test_utils._get_test_cache_dir(args.o) as TEST_CACHE_DIR: 198 | suite = unittest.TestSuite() 199 | suite.addTest(TestSD3MMDiT("test_torch2coreml_correctness_and_speedup")) 200 | 201 | if os.getenv("DEBUG", False): 202 | suite.debug() 203 | else: 204 | runner = unittest.TextTestRunner() 205 | runner.run(suite) 206 | -------------------------------------------------------------------------------- /diffusionkit/tests/torch2coreml/test_vae.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2024 Argmax, Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import unittest 8 | from typing import Dict 9 | 10 | import coremltools as ct 11 | import torch 12 | from argmaxtools import test_utils as argmaxtools_test_utils 13 | from argmaxtools.utils import get_fastest_device, get_logger 14 | from diffusionkit.torch import vae 15 | from diffusionkit.torch.model_io import _load_vae_decoder_weights 16 | from huggingface_hub import hf_hub_download 17 | 18 | torch.set_grad_enabled(False) 19 | logger = get_logger(__name__) 20 | 21 | TEST_SD3_CKPT_PATH = os.getenv("TEST_SD3_CKPT_PATH", None) or None 22 | TEST_SD3_HF_REPO = os.getenv("TEST_SD3_HF_REPO", None) or None 23 | TEST_CACHE_DIR = os.getenv("TEST_CACHE_DIR", None) or "/tmp" 24 | TEST_DEV = os.getenv("TEST_DEV", None) or get_fastest_device() 25 | TEST_TORCH_DTYPE = torch.float32 26 | TEST_PSNR_THR = 35 27 | TEST_LATENT_SIZE = 64 # 64 latent -> 512 image, 128 latent -> 1024 image 28 | TEST_LATENT_HEIGHT = TEST_LATENT_SIZE 29 | TEST_LATENT_WIDTH = TEST_LATENT_SIZE 30 | 31 | SD3_8b = vae.VAEDecoderConfig(resolution=1024) 32 | SD3_2b = vae.VAEDecoderConfig(resolution=512) 33 | 34 | 35 | def setup_test_config( 36 | min_speedup_vs_cpu=3.0, 37 | compute_precision=ct.precision.FLOAT16, 38 | compute_unit=ct.ComputeUnit.CPU_AND_GPU, 39 | compression_min_speedup=0.5, 40 | default_nbits=None, 41 | skip_speed_tests=True, 42 | compile_coreml=False, 43 | ): 44 | argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = min_speedup_vs_cpu 45 | argmaxtools_test_utils.TEST_COREML_PRECISION = compute_precision 46 | argmaxtools_test_utils.TEST_COMPUTE_UNIT = compute_unit 47 | argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = compression_min_speedup 48 | argmaxtools_test_utils.TEST_DEFAULT_NBITS = default_nbits 49 | argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = skip_speed_tests 50 | argmaxtools_test_utils.TEST_COMPILE_COREML = compile_coreml 51 | 52 | 53 | class TestSD3VAEDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase): 54 | """Unit tests for stable_duffusion_3.vae.VAEDecoder module""" 55 | 56 | @classmethod 57 | def setUpClass(cls): 58 | global TEST_SD3_CKPT_PATH 59 | cls.model_name = "VAEDecoder" 60 | cls.test_output_names = ["image"] 61 | cls.test_cache_dir = TEST_CACHE_DIR 62 | 63 | # Base test model 64 | logger.info("Initializing SD3 VAEDecoder model") 65 | cls.test_torch_model = ( 66 | vae.VAEDecoder(SD3_2b).to(TEST_DEV).to(TEST_TORCH_DTYPE).eval() 67 | ) 68 | logger.info("Initialized.") 69 | 70 | TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download( 71 | TEST_SD3_HF_REPO, "sd3_medium.safetensors" 72 | ) 73 | if TEST_SD3_CKPT_PATH is not None: 74 | logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") 75 | _load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH) 76 | logger.info("Loaded.") 77 | else: 78 | logger.info( 79 | "No TEST_SD3_CKPT_PATH (--sd3-ckpt-path) provided, exporting random weights" 80 | ) 81 | 82 | # Sample inputs 83 | # TODO(atiorh): CLI configurable model version 84 | cls.test_torch_inputs = get_test_inputs(SD3_2b) 85 | 86 | super().setUpClass() 87 | 88 | @classmethod 89 | def tearDownClass(cls): 90 | cls.test_torch_model = None 91 | cls.test_torch_inputs = None 92 | super().tearDownClass() 93 | 94 | 95 | def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]: 96 | """Generate random inputs for the 
SD3 MMDiT model""" 97 | config_expected_latent_resolution = ( 98 | config.resolution // 2 ** len(config.channel_multipliers) - 1 99 | ) 100 | if TEST_LATENT_SIZE != config_expected_latent_resolution: 101 | logger.warning( 102 | f"TEST_LATENT_SIZE ({TEST_LATENT_SIZE}) does not match the implied " 103 | f"latent resolution ({config_expected_latent_resolution}) from the model config " 104 | ) 105 | 106 | z_dims = (1, config.in_channels, TEST_LATENT_HEIGHT, TEST_LATENT_WIDTH) 107 | return {"z": torch.randn(*z_dims).to(TEST_DEV).to(TEST_TORCH_DTYPE)} 108 | 109 | 110 | def convert_vae_to_mlpackage( 111 | model_version: str, 112 | latent_h: int, 113 | latent_w: int, 114 | output_dir: str = None, 115 | **test_config_kwargs, 116 | ) -> str: 117 | """Converts a VAE decoder model to a CoreML package. 118 | 119 | Returns: 120 | `str`: path to the converted model. 121 | """ 122 | global TEST_SD3_CKPT_PATH, TEST_SD3_HF_REPO, TEST_LATENT_WIDTH, TEST_LATENT_HEIGHT, TEST_CACHE_DIR 123 | 124 | # Convert to CoreML 125 | TEST_SD3_HF_REPO = model_version 126 | TEST_LATENT_HEIGHT = latent_h or TEST_LATENT_SIZE 127 | TEST_LATENT_WIDTH = latent_w or TEST_LATENT_SIZE 128 | 129 | setup_test_config(compile_coreml=False, **test_config_kwargs) 130 | 131 | with argmaxtools_test_utils._get_test_cache_dir( 132 | persistent_cache_dir=output_dir 133 | ) as TEST_CACHE_DIR: 134 | suite = unittest.TestSuite() 135 | suite.addTest(TestSD3VAEDecoder("test_torch2coreml_correctness_and_speedup")) 136 | 137 | if os.getenv("DEBUG", False): 138 | suite.debug() 139 | else: 140 | runner = unittest.TextTestRunner() 141 | runner.run(suite) 142 | 143 | return os.path.join(TEST_CACHE_DIR, f"{TestSD3VAEDecoder.model_name}.mlpackage") 144 | 145 | 146 | if __name__ == "__main__": 147 | import argparse 148 | 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--sd3-ckpt-path", default=TEST_SD3_CKPT_PATH, type=str) 151 | parser.add_argument("-o", default=TEST_CACHE_DIR, type=str) 152 | parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) 153 | args = parser.parse_args() 154 | 155 | TEST_SD3_CKPT_PATH = ( 156 | args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None 157 | ) 158 | TEST_SD3_HF_REPO = args.sd3_ckpt_path 159 | TEST_LATENT_SIZE = args.latent_size 160 | 161 | setup_test_config() 162 | 163 | with argmaxtools_test_utils._get_test_cache_dir(args.o) as TEST_CACHE_DIR: 164 | suite = unittest.TestSuite() 165 | suite.addTest(TestSD3VAEDecoder("test_torch2coreml_correctness_and_speedup")) 166 | 167 | if os.getenv("DEBUG", False): 168 | suite.debug() 169 | else: 170 | runner = unittest.TextTestRunner() 171 | runner.run(suite) 172 | -------------------------------------------------------------------------------- /diffusionkit/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from argmaxtools.utils import get_logger 5 | from PIL import Image 6 | from safetensors import safe_open 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def _load_weights(module: nn.Module, path: str) -> None: 12 | """Load weights from a checkpoint file (safetensors or pt)""" 13 | total_params_in_module = sum(p.numel() for p in module.parameters()) 14 | logger.info( 15 | f"Loading state_dict into nn.Module with {len([n for n,p in module.named_parameters()])} " 16 | f"parameter tensors totaling {total_params_in_module} " 17 | f"parameters from {path}" 18 | ) 19 | 20 | if path.endswith(".pt"): 21 | state_dict = 
torch.load(path, map_location="cpu") 22 | module.load_state_dict(state_dict) 23 | elif path.endswith(".safetensors"): 24 | state_dict = {} 25 | 26 | with safe_open(path, framework="pt", device="cpu") as f: 27 | for key in f.keys(): 28 | state_dict[key] = f.get_tensor(key) 29 | else: 30 | raise ValueError(f"Unsupported file format: {path}") 31 | 32 | total_params_in_state_dict = sum(np.prod(v.shape) for v in state_dict.values()) 33 | logger.info( 34 | f"Loaded state dict with {len(state_dict)} tensors totaling " 35 | f"{total_params_in_state_dict} parameters" 36 | ) 37 | 38 | if total_params_in_module != total_params_in_state_dict: 39 | raise ValueError( 40 | f"Total number of parameters in state_dict ({total_params_in_state_dict}) " 41 | f"does not match the number of parameters in the module ({total_params_in_module})" 42 | ) 43 | 44 | module.load_state_dict(state_dict) 45 | 46 | 47 | def bytes2gigabytes(n: int) -> float: 48 | """Convert bytes to gigabytes""" 49 | return n / 1024**3 50 | 51 | 52 | def image_psnr(reference: Image.Image, proxy: Image.Image) -> float: 53 | """Peak-Signal-to-Noise-Ratio in dB between a reference 54 | and a proxy PIL.Image 55 | """ 56 | reference = np.asarray(reference) 57 | proxy = np.asarray(proxy) 58 | 59 | assert ( 60 | reference.squeeze().shape == proxy.squeeze().shape 61 | ), f"{reference.shape} is incompatible with {proxy.shape}!" 62 | reference = reference.flatten() 63 | proxy = proxy.flatten() 64 | 65 | peak_signal = np.abs(reference).max() 66 | mse = np.sqrt(np.mean((reference - proxy) ** 2)) 67 | return 20 * np.log10((peak_signal + 1e-5) / (mse + 1e-10)) 68 | 69 | 70 | def compute_psnr(reference: np.ndarray, proxy: np.ndarray) -> float: 71 | """Peak-Signal-to-Noise-Ratio in dB between a reference 72 | and a proxy np.ndarray 73 | """ 74 | assert ( 75 | reference.squeeze().shape == proxy.squeeze().shape 76 | ), f"{reference.shape} is incompatible with {proxy.shape}!" 
77 | reference = reference.flatten() 78 | proxy = proxy.flatten() 79 | 80 | peak_signal = np.abs(reference).max() 81 | mse = np.sqrt(np.mean((reference - proxy) ** 2)) 82 | return 20 * np.log10((peak_signal + 1e-5) / (mse + 1e-10)) 83 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-mlx" 3 | description = "Faster workflows for ComfyUI users on Mac with Apple silicon" 4 | version = "1.0.4" 5 | license = {file = "LICENSE"} 6 | 7 | [project.urls] 8 | Repository = "https://github.com/thoddnn/ComfyUI-MLX" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "thoddnn" 13 | DisplayName = "ComfyUI-MLX" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argmaxtools>=0.1.13 2 | torch 3 | safetensors 4 | mlx 5 | jaxtyping 6 | transformers 7 | pillow 8 | sentencepiece -------------------------------------------------------------------------------- /workflows/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/ComfyUI-MLX/215691b282f3d1eddb2e7029c2c399567cd0be9b/workflows/.DS_Store -------------------------------------------------------------------------------- /workflows/basic_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 10, 3 | "last_link_id": 11, 4 | "nodes": [ 5 | { 6 | "id": 5, 7 | "type": "MLXDecoder", 8 | "pos": [ 9 | 1624.08203125, 10 | 244.5 11 | ], 12 | "size": { 13 | "0": 229.20001220703125, 14 | "1": 46 15 | }, 16 | "flags": {}, 17 | "order": 4, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "latent_image", 22 | "type": "LATENT", 23 | "link": 9 24 | }, 25 | { 26 | "name": "mlx_vae", 27 | "type": "mlx_vae", 28 | "link": 4 29 | } 30 | ], 31 | "outputs": [ 32 | { 33 | "name": "IMAGE", 34 | "type": "IMAGE", 35 | "links": [ 36 | 7 37 | ], 38 | "shape": 3, 39 | "slot_index": 0 40 | } 41 | ], 42 | "properties": { 43 | "Node name for S&R": "MLXDecoder" 44 | } 45 | }, 46 | { 47 | "id": 7, 48 | "type": "SaveImage", 49 | "pos": [ 50 | 1938, 51 | 205 52 | ], 53 | "size": { 54 | "0": 457.3671875, 55 | "1": 406.68359375 56 | }, 57 | "flags": {}, 58 | "order": 5, 59 | "mode": 0, 60 | "inputs": [ 61 | { 62 | "name": "images", 63 | "type": "IMAGE", 64 | "link": 7 65 | } 66 | ], 67 | "properties": {}, 68 | "widgets_values": [ 69 | "ComfyUI" 70 | ] 71 | }, 72 | { 73 | "id": 2, 74 | "type": "MLXLoadFlux", 75 | "pos": [ 76 | 227, 77 | 426 78 | ], 79 | "size": { 80 | "0": 315, 81 | "1": 98 82 | }, 83 | "flags": {}, 84 | "order": 0, 85 | "mode": 0, 86 | "outputs": [ 87 | { 88 | "name": "mlx_model", 89 | "type": "mlx_model", 90 | "links": [ 91 | 8 92 | ], 93 | "shape": 3, 94 | "slot_index": 0 95 | }, 96 | { 97 | "name": "mlx_vae", 98 | "type": "mlx_vae", 99 | "links": [ 100 | 4 101 | ], 102 | "shape": 3, 103 | "slot_index": 1 104 | }, 105 | { 106 | "name": "mlx_conditioning", 107 | "type": "mlx_conditioning", 108 | "links": [ 109 | 1 110 | ], 111 | "shape": 3, 112 | "slot_index": 2 113 | } 114 | ], 115 | "properties": { 116 | "Node name for S&R": "MLXLoadFlux" 117 | }, 118 | "widgets_values": [ 119 | "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized" 120 | ] 121 | }, 122 | { 123 | "id": 3, 124 | 
"type": "MLXClipTextEncoder", 125 | "pos": [ 126 | 627, 127 | 157 128 | ], 129 | "size": { 130 | "0": 418.1999816894531, 131 | "1": 200 132 | }, 133 | "flags": {}, 134 | "order": 2, 135 | "mode": 0, 136 | "inputs": [ 137 | { 138 | "name": "mlx_conditioning", 139 | "type": "mlx_conditioning", 140 | "link": 1 141 | } 142 | ], 143 | "outputs": [ 144 | { 145 | "name": "mlx_conditioning", 146 | "type": "mlx_conditioning", 147 | "links": [ 148 | 10 149 | ], 150 | "shape": 3, 151 | "slot_index": 0 152 | } 153 | ], 154 | "properties": { 155 | "Node name for S&R": "MLXClipTextEncoder" 156 | }, 157 | "widgets_values": [ 158 | "photo of a cat" 159 | ] 160 | }, 161 | { 162 | "id": 9, 163 | "type": "MLXSampler", 164 | "pos": [ 165 | 1142, 166 | 418 167 | ], 168 | "size": { 169 | "0": 405.5999755859375, 170 | "1": 194 171 | }, 172 | "flags": {}, 173 | "order": 3, 174 | "mode": 0, 175 | "inputs": [ 176 | { 177 | "name": "mlx_model", 178 | "type": "mlx_model", 179 | "link": 8 180 | }, 181 | { 182 | "name": "mlx_positive_conditioning", 183 | "type": "mlx_conditioning", 184 | "link": 10 185 | }, 186 | { 187 | "name": "latent_image", 188 | "type": "LATENT", 189 | "link": 11, 190 | "slot_index": 2 191 | } 192 | ], 193 | "outputs": [ 194 | { 195 | "name": "LATENT", 196 | "type": "LATENT", 197 | "links": [ 198 | 9 199 | ], 200 | "shape": 3, 201 | "slot_index": 0 202 | } 203 | ], 204 | "properties": { 205 | "Node name for S&R": "MLXSampler" 206 | }, 207 | "widgets_values": [ 208 | 3501603197, 209 | "randomize", 210 | 4, 211 | 0, 212 | 1 213 | ] 214 | }, 215 | { 216 | "id": 10, 217 | "type": "EmptyLatentImage", 218 | "pos": [ 219 | 709, 220 | 733 221 | ], 222 | "size": { 223 | "0": 315, 224 | "1": 106 225 | }, 226 | "flags": {}, 227 | "order": 1, 228 | "mode": 0, 229 | "outputs": [ 230 | { 231 | "name": "LATENT", 232 | "type": "LATENT", 233 | "links": [ 234 | 11 235 | ], 236 | "shape": 3 237 | } 238 | ], 239 | "properties": { 240 | "Node name for S&R": "EmptyLatentImage" 241 | }, 242 | "widgets_values": [ 243 | 512, 244 | 512, 245 | 1 246 | ] 247 | } 248 | ], 249 | "links": [ 250 | [ 251 | 1, 252 | 2, 253 | 2, 254 | 3, 255 | 0, 256 | "mlx_conditioning" 257 | ], 258 | [ 259 | 4, 260 | 2, 261 | 1, 262 | 5, 263 | 1, 264 | "mlx_vae" 265 | ], 266 | [ 267 | 7, 268 | 5, 269 | 0, 270 | 7, 271 | 0, 272 | "IMAGE" 273 | ], 274 | [ 275 | 8, 276 | 2, 277 | 0, 278 | 9, 279 | 0, 280 | "mlx_model" 281 | ], 282 | [ 283 | 9, 284 | 9, 285 | 0, 286 | 5, 287 | 0, 288 | "LATENT" 289 | ], 290 | [ 291 | 10, 292 | 3, 293 | 0, 294 | 9, 295 | 1, 296 | "mlx_conditioning" 297 | ], 298 | [ 299 | 11, 300 | 10, 301 | 0, 302 | 9, 303 | 2, 304 | "LATENT" 305 | ] 306 | ], 307 | "groups": [], 308 | "config": {}, 309 | "extra": { 310 | "ds": { 311 | "scale": 0.683013455365071, 312 | "offset": [ 313 | -46.40798165245038, 314 | 206.20957133931108 315 | ] 316 | } 317 | }, 318 | "version": 0.4 319 | } --------------------------------------------------------------------------------