├── config_example.yaml
├── README.md
├── inference.py
└── .gitignore

/config_example.yaml:
--------------------------------------------------------------------------------
job: extension
config:
  name: "YOUR_LORA_NAME"
  process:
    - type: 'sd_trainer'
      training_folder: "/root/ai-toolkit/modal_output"
      device: cuda:0
      trigger_word: "atelierai_sks_768"
      network:
        type: "lokr"
        linear: 16
        linear_alpha: 16
        network_kwargs:
          only_if_contains:
            - "transformer.single_transformer_blocks.9."
            - "transformer.single_transformer_blocks.25."
            - "transformer.transformer_blocks.5."
            - "transformer.transformer_blocks.15."
      save:
        dtype: float16
        save_every: 10000
        max_step_saves_to_keep: 4
        push_to_hub: true
        hf_private: true
        hf_repo_id: "YOUR_USERNAME/YOUR_MODEL_NAME"
      datasets:
        - folder_path: "/root/ai-toolkit/YOUR_DATASET"
          caption_ext: "txt"
          caption_dropout_rate: 0.0
          shuffle_tokens: false
          cache_latents_to_disk: false
          resolution: [768, 1024]
      train:
        batch_size: 1
        steps: 1000
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false
        gradient_checkpointing: true
        noise_scheduler: "flowmatch"
        optimizer: "adamw8bit"
        lr: 1e-3
        skip_first_sample: true
        disable_sampling: true
        ema_config:
          use_ema: true
          ema_decay: 0.99
        dtype: bf16
      model:
        name_or_path: "black-forest-labs/FLUX.1-dev"
        is_flux: true
        quantize: false
        low_vram: false
      sample:
        sampler: "flowmatch"
        sample_every: 1000
        width: 1024
        height: 1024
        prompts:
          - "cowboy wearing a denim jacket, atelierai_sks_768"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 3.5
        sample_steps: 28
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Modal FLUX LoRA toolkit

Have you ever wondered why [Fal AI](https://fal.ai)'s LoRA trainer is so fast? I was recently working on [Atelier AI](https://atelierai.me), an equivalent of PhotoAI for Iranian/Persian-speaking users, and I needed to make the training procedure fast. At first I wanted to use Fal AI, but I realized that $2 per training run is too much for me. On the other hand, I could use [Modal](https://modal.com) to run the training myself.

So I did a lot of research to find these tips for making training faster:

- Using a powerful GPU helps, but it's not everything.
- We only need to apply the LoRA to a few layers to fit our concept (if you look at the [config_example.yaml](./config_example.yaml) file, you will notice I've targeted 4; see the excerpt below).
- If we disable sampling, we can make training even faster.
- On powerful GPUs such as _A100 80GB_ or _H100_ we don't need to worry about `low_vram` and can disable it (and therefore make the process faster).
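
All of these tips map to concrete settings in [config_example.yaml](./config_example.yaml). Here is an excerpt with just the speed-related knobs (comments mine):

```
network:
  type: "lokr"
  network_kwargs:
    only_if_contains:       # train only these four blocks, not the whole transformer
      - "transformer.single_transformer_blocks.9."
      - "transformer.single_transformer_blocks.25."
      - "transformer.transformer_blocks.5."
      - "transformer.transformer_blocks.15."
train:
  skip_first_sample: true   # no sample image before training starts
  disable_sampling: true    # no sample images during training either
model:
  quantize: false           # an 80GB card doesn't need quantization
  low_vram: false           # or low-VRAM offloading
```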

## What do you need to use this project?

- An account on [Modal](https://modal.com) with sufficient funds.
- [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)
- A [HuggingFace](https://huggingface.co) token with `read` permission (make it `write` if you want the trainer to push your LoRA to the Hub, as the example config does).
- Your HF token set up as a secret on Modal (see the command below).
- Enough time and courage for your AI weekend project!
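
If you haven't set the secret up yet, one Modal CLI command should do it. This is a minimal sketch: `huggingface-secret` is the secret name `inference.py` looks up, and `hf_...` stands for your real token:

```
modal secret create huggingface-secret HF_TOKEN=hf_...
```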

## How to use training configuration

Please check the AI Toolkit documentation to find out how to use its Modal trainer script; I've only provided the YAML file for configuring your model. Also, do not forget to update the config file with your own values (see the placeholders below).
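
Concretely, these are the placeholder values in [config_example.yaml](./config_example.yaml) to replace; the rest can stay as-is for a first run:

```
name: "YOUR_LORA_NAME"                        # what your trained weights will be called
trigger_word: "atelierai_sks_768"             # swap in your own trigger token
hf_repo_id: "YOUR_USERNAME/YOUR_MODEL_NAME"   # where the LoRA gets pushed on the Hub
folder_path: "/root/ai-toolkit/YOUR_DATASET"  # your images plus their .txt captions
```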

## How to use inference

Just run the script like this:

```
modal run inference.py \
  --prompt "a cat" \
  --width 1024 \
  --height 1024 \
  --lora "HF_USERNAME/MODEL_NAME" \
  --filename "my_amazing_image"
```

## Notes

- Since [Atelier AI](https://atelierai.me) is part of my bigger startup [Mann-E](https://mann-e.com), the base model in `inference.py` is set to `mann-e/mann-e_flux`. You can easily change it to your desired FLUX-based checkpoint.
- The training configuration automatically uploads the trained LoRA (in this case, a LoKR) to HuggingFace. You can use the same weights on Fal AI, Replicate, a self-hosted Forge, or anywhere else you can use LoRAs (see the sketch below).
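
For example, here is a minimal sketch of loading the pushed weights with plain `diffusers`; `HF_USERNAME/MODEL_NAME` stands for the `hf_repo_id` you set in the config, and the prompt reuses the example trigger word:

```
import torch
from diffusers import DiffusionPipeline

# Any FLUX-based checkpoint works as the base model.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights("HF_USERNAME/MODEL_NAME")  # the hf_repo_id from the config

image = pipe(
    "cowboy wearing a denim jacket, atelierai_sks_768",  # keep the trigger word in the prompt
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("sample.jpg")
```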
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import time
from io import BytesIO

import modal

# Build a CUDA development image with Python 3.11 on top of NVIDIA's base image.
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])

flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    .pip_install(
        "invisible_watermark==0.2.0",
        "transformers==4.44.0",
        "huggingface_hub[hf_transfer]==0.26.2",
        "accelerate==0.33.0",
        "safetensors==0.4.4",
        "sentencepiece==0.2.0",
        "git+https://github.com/huggingface/diffusers.git",
        "numpy<2",
        "protobuf",
        "peft",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache"})
)

flux_image = flux_image.env(
    {
        "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
        "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
    }
)

app = modal.App("example-flux-lora", image=flux_image)

with flux_image.imports():
    import torch
    from diffusers import DiffusionPipeline, AutoencoderTiny


MINUTES = 60  # one minute, in seconds
NUM_INFERENCE_STEPS = 16


@app.cls(
    gpu="H100",
    container_idle_timeout=3 * MINUTES,
    timeout=60 * MINUTES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    volumes={
        # Persist model weights and compilation caches across container restarts.
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name("inductor-cache", create_if_missing=True),
    },
)
class Model:
    # 0/1 flag controlling pipeline compilation; converted to bool in enter().
    compile: int = modal.parameter(default=0)

    @modal.enter()
    def enter(self):
        from huggingface_hub import snapshot_download

        # Pre-fetch the base checkpoint into the shared cache volume.
        snapshot_download("mann-e/mann-e_flux")

        # The tiny autoencoder (taef1) stands in for the full VAE to speed up decoding.
        taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16)
        pipe = DiffusionPipeline.from_pretrained(
            "mann-e/mann-e_flux", torch_dtype=torch.bfloat16, vae=taef1
        )

        pipe.to("cuda")
        self.pipe = optimize(pipe, compile=bool(self.compile))

    @modal.method()
    def inference(self, prompt: str, width: int, height: int, lora: str) -> bytes:
        print("🎨 generating image...")

        self.pipe.load_lora_weights(lora)
        self.pipe.fuse_lora(lora_scale=1.0)

        out = self.pipe(
            prompt,
            output_type="pil",
            num_inference_steps=NUM_INFERENCE_STEPS,
            width=width,
            height=height,
            guidance_scale=3.5,
        ).images[0]

        # Unfuse and unload so repeated calls on a warm container don't stack LoRAs.
        self.pipe.unfuse_lora()
        self.pipe.unload_lora_weights()

        byte_stream = BytesIO()
        out.save(byte_stream, format="JPEG")
        return byte_stream.getvalue()


@app.local_entrypoint()
def main(
    prompt: str,
    width: int,
    height: int,
    lora: str,
    filename: str,
    twice: bool = False,
    compile: bool = False,
):
    t0 = time.time()
    image_bytes = Model(compile=compile).inference.remote(prompt, width, height, lora)
    print(f"🎨 first inference latency: {time.time() - t0:.2f} seconds")

    if twice:
        t0 = time.time()
        image_bytes = Model(compile=compile).inference.remote(prompt, width, height, lora)
        print(f"🎨 second inference latency: {time.time() - t0:.2f} seconds")

    with open(f"{filename}.jpg", "wb") as f:
        f.write(image_bytes)


def optimize(pipe, compile=True):
    # Hook for pipeline optimizations (e.g. torch.compile). Currently a no-op:
    # the pipeline is returned unchanged and the `compile` flag is ignored.
    return pipe
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Created by https://www.toptal.com/developers/gitignore/api/python,flask
# Edit at https://www.toptal.com/developers/gitignore?templates=python,flask

### Flask ###
instance/*
!instance/.gitignore
.webassets-cache
.env

### Flask.Python Stack ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python,flask
--------------------------------------------------------------------------------