├── config_example.yaml
├── README.md
├── inference.py
└── .gitignore
/config_example.yaml:
--------------------------------------------------------------------------------
job: extension
config:
  name: "YOUR_LORA_NAME"
  process:
    - type: 'sd_trainer'
      training_folder: "/root/ai-toolkit/modal_output"
      device: cuda:0
      trigger_word: "atelierai_sks_768"  # replace with your own trigger word
      network:
        type: "lokr"
        linear: 16
        linear_alpha: 16
        network_kwargs:
          # The main speed trick: train only these four transformer blocks.
          only_if_contains:
            - "transformer.single_transformer_blocks.9."
            - "transformer.single_transformer_blocks.25."
            - "transformer.transformer_blocks.5."
            - "transformer.transformer_blocks.15."
      save:
        dtype: float16
        save_every: 10000  # larger than `steps`, so no intermediate checkpoints
        max_step_saves_to_keep: 4
        push_to_hub: true
        hf_private: true
        hf_repo_id: "YOUR_USERNAME/YOUR_MODEL_NAME"
      datasets:
        - folder_path: "/root/ai-toolkit/YOUR_DATASET"
          caption_ext: "txt"
          caption_dropout_rate: 0.0
          shuffle_tokens: false
          cache_latents_to_disk: false
          resolution: [768, 1024]
      train:
        batch_size: 1
        steps: 1000
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false
        gradient_checkpointing: true
        noise_scheduler: "flowmatch"
        optimizer: "adamw8bit"
        lr: 1e-3
        skip_first_sample: true
        disable_sampling: true  # skip sample generation during training for speed
        ema_config:
          use_ema: true
          ema_decay: 0.99
        dtype: bf16
      model:
        name_or_path: "black-forest-labs/FLUX.1-dev"
        is_flux: true
        quantize: false
        low_vram: false  # safe to disable on A100/H100
      sample:
        # Not used while disable_sampling is true; kept for reference.
        sampler: "flowmatch"
        sample_every: 1000
        width: 1024
        height: 1024
        prompts:
          - "cowboy wearing a denim jacket, atelierai_sks_768"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 3.5
        sample_steps: 28
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Modal FLUX LoRA toolkit

Have you ever wondered why [Fal AI](https://fal.ai)'s LoRA trainer is so fast? I was recently working on [Atelier AI](https://atelierai.me), a PhotoAI equivalent for Iranian/Persian-speaking users, and I needed to make the training procedure fast. At first I wanted to use Fal AI, but I realized that $2 per training run is too much for me. On the other hand, I could use [Modal](https://modal.com) to run the training myself.

So I did a lot of research and found these tips for making training faster:

- Using a powerful GPU helps, but it's not everything.
- We can restrict the LoRA to a handful of transformer blocks and still fit our concept (if you look at the [config_example.yaml](./config_example.yaml) file, you will notice I've targeted four); see the excerpt below this list.
- Disabling sampling makes training even faster.
- On powerful GPUs such as _A100 80GB_ or _H100_ we don't need to worry about `low_vram` and can disable it (making the process faster still).
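
For reference, these are the speed-relevant parts of [config_example.yaml](./config_example.yaml):

```yaml
network:
  type: "lokr"
  network_kwargs:
    only_if_contains:   # train only these four transformer blocks
      - "transformer.single_transformer_blocks.9."
      - "transformer.single_transformer_blocks.25."
      - "transformer.transformer_blocks.5."
      - "transformer.transformer_blocks.15."
train:
  disable_sampling: true   # no periodic sample images
model:
  low_vram: false          # safe to disable on A100/H100
```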

## What do you need to use this project?

- An account on [Modal](https://modal.com) with sufficient funds.
- [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)
- A [HuggingFace](https://huggingface.co) token with `read` permission (pushing the trained LoRA to the Hub needs `write`).
- Your HF token set up as a secret on Modal (see the note below).
- Enough time and courage for your AI weekend project!
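
The inference script expects this token in a Modal secret named `huggingface-secret`. Assuming you have the Modal CLI set up, creating it looks roughly like this (`HF_TOKEN` is the environment variable `huggingface_hub` reads):

```bash
modal secret create huggingface-secret HF_TOKEN=hf_your_token_here
```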

## How to use the training configuration

Please check the AI Toolkit documentation to find out how to use its Modal trainer script; this repo only provides the YAML file for configuring your model. Also, do not forget to replace the placeholder values in the config file with your own.
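
At the time of writing, the AI Toolkit repo ships a `run_modal.py` entrypoint, so the invocation looks roughly like this (the config path is an example; double-check the AI Toolkit docs for the current flags):

```bash
modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/config_example.yaml
```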

## How to use inference

Just run the code like this:

```bash
modal run inference.py \
  --prompt "a cat" \
  --width 1024 \
  --height 1024 \
  --lora "HF_USERNAME/MODEL_NAME" \
  --filename "my_amazing_image"
```
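
The entrypoint also accepts `--twice` (runs the generation a second time so you can compare cold-start and warm latency) and `--compile` (passed through to the `optimize()` hook in `inference.py`, which is currently a no-op).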

## Notes

- Since [Atelier AI](https://atelierai.me) is part of my bigger startup [Mann-E](https://mann-e.com), the model in `inference.py` is set to `mann-e/mann-e_flux`. You can easily change it to your desired FLUX-based checkpoint.
- The training configuration automatically uploads the trained LoRA (in this case, a LoKr) to HuggingFace. You can use the same weights on Fal AI, Replicate, a self-hosted Forge, or anywhere else you can apply LoRAs.
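
For example, here is a minimal sketch of loading the pushed weights into a plain `diffusers` pipeline (the repo id and trigger word below are the placeholders from `config_example.yaml`; swap in your own):

```python
import torch
from diffusers import DiffusionPipeline

# Load the base FLUX checkpoint, then apply the LoRA pushed by the trainer.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("YOUR_USERNAME/YOUR_MODEL_NAME")

image = pipe(
    "cowboy wearing a denim jacket, atelierai_sks_768",  # prompt ends with the trigger word
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("sample.jpg")
```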
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import time
from io import BytesIO
import modal

cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])

flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    .pip_install(
        "invisible_watermark==0.2.0",
        "transformers==4.44.0",
        "huggingface_hub[hf_transfer]==0.26.2",
        "accelerate==0.33.0",
        "safetensors==0.4.4",
        "sentencepiece==0.2.0",
        "git+https://github.com/huggingface/diffusers.git",
        "numpy<2",
        "protobuf",
        "peft",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache"})
)

flux_image = flux_image.env(
    {
        "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
        "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
    }
)

app = modal.App("example-flux-lora", image=flux_image)

# These imports only run inside the container, where the image's
# dependencies are installed.
with flux_image.imports():
    import torch
    from diffusers import DiffusionPipeline, AutoencoderTiny


MINUTES = 60
NUM_INFERENCE_STEPS = 16


@app.cls(
    gpu="H100",
    container_idle_timeout=3 * MINUTES,
    timeout=60 * MINUTES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    volumes={
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name("inductor-cache", create_if_missing=True),
    },
)
class Model:
    compile: int = modal.parameter(default=0)

    @modal.enter()
    def enter(self):
        from huggingface_hub import snapshot_download

        # Warm the HF cache with the base checkpoint.
        snapshot_download("mann-e/mann-e_flux")

        # Tiny autoencoder (TAEF1): a faster, slightly lossier FLUX VAE.
        taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16)
        pipe = DiffusionPipeline.from_pretrained(
            "mann-e/mann-e_flux", torch_dtype=torch.bfloat16, vae=taef1
        )

        pipe.to("cuda")
        self.pipe = optimize(pipe, compile=bool(self.compile))

    @modal.method()
    def inference(self, prompt: str, width: int, height: int, lora: str) -> bytes:
        print("🎨 generating image...")

        # Load and fuse the requested LoRA for this call.
        self.pipe.load_lora_weights(lora)
        self.pipe.fuse_lora(lora_scale=1.0)

        out = self.pipe(
            prompt,
            output_type="pil",
            num_inference_steps=NUM_INFERENCE_STEPS,
            width=width,
            height=height,
            guidance_scale=3.5,
        ).images[0]

        # Undo the fusion so repeated calls on a warm container don't stack
        # LoRA weights on top of each other.
        self.pipe.unfuse_lora()
        self.pipe.unload_lora_weights()

        byte_stream = BytesIO()
        out.save(byte_stream, format="JPEG")
        return byte_stream.getvalue()


@app.local_entrypoint()
def main(
    prompt: str,
    width: int,
    height: int,
    lora: str,
    filename: str,
    twice: bool = False,
    compile: bool = False,
):
    t0 = time.time()
    image_bytes = Model(compile=compile).inference.remote(prompt, width, height, lora)
    print(f"🎨 first inference latency: {time.time() - t0:.2f} seconds")

    if twice:
        t0 = time.time()
        image_bytes = Model(compile=compile).inference.remote(prompt, width, height, lora)
        print(f"🎨 second inference latency: {time.time() - t0:.2f} seconds")

    with open(f"{filename}.jpg", "wb") as f:
        f.write(image_bytes)


def optimize(pipe, compile=True):
    # No-op hook: a placeholder for torch.compile / memory-format tweaks.
    return pipe

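
# A fuller optimize() could look like this sketch (hypothetical and untested;
# it is not part of the original script). It assumes a recent torch/diffusers
# and mirrors common FLUX inference recipes. Caveat: torch.compile interacts
# badly with per-request LoRA fusing, since changed weights trigger
# recompilation; compile only if you serve a single, already-fused LoRA.
def optimize_with_compile(pipe, compile=True):
    # Channels-last memory format tends to help the convolution-heavy VAE.
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    if compile:
        # Compile the two hot components; the first call pays the compile cost.
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune", fullgraph=True
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode, mode="max-autotune", fullgraph=True
        )
    return pipe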
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Created by https://www.toptal.com/developers/gitignore/api/python,flask
# Edit at https://www.toptal.com/developers/gitignore?templates=python,flask

### Flask ###
instance/*
!instance/.gitignore
.webassets-cache
.env

### Flask.Python Stack ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python,flask
--------------------------------------------------------------------------------