├── .gitignore ├── Dockerfile.sd15-lora ├── README.md ├── config.yml ├── config.yml.sample_sd15 ├── config_sdxl.yml ├── mistral_lora.py ├── requirements.txt ├── sd15_lora.py └── sdxl_lora.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw* -------------------------------------------------------------------------------- /Dockerfile.sd15-lora: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.2-base-ubuntu20.04 2 | 3 | # Install lora and pre-cache the Stable Diffusion 1.5 model to avoid re-downloading 4 | # it for every inference. 5 | # 6 | # NB: diffusers downgrade is because of https://github.com/cloneofsimo/lora/issues/231 7 | RUN apt-get update -y && apt-get install -y python3 python3-pip git unzip && \ 8 | pip install git+https://github.com/cloneofsimo/lora.git@v0.1.7 && \ 9 | pip install diffusers==0.14 && \ 10 | pip install accelerate==0.20.3 && \ 11 | python3 -c "from diffusers import StableDiffusionPipeline; model_id = 'runwayml/stable-diffusion-v1-5'; StableDiffusionPipeline.from_pretrained(model_id)" 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌸 Fine-tuning Stable Diffusion using [LoRA](https://github.com/cloneofsimo/lora) with Dagger 2 | 3 | This example shows how to create brand assets for merchandise from some screenshots of a brand's website: 4 | 5 | ![dagger-stable-diffusion-lora](https://github.com/lukemarsden/dagger-stable-diffusion-lora/assets/264658/071001b9-6873-45b4-8e46-c6f924ef8b33) 6 | 7 | It's implemented as a Stable Diffusion LoRA pipeline using [Dagger](https://dagger.io): pipelines as (Python) code. 8 | 9 | * Go to [lambdalabs.com](https://lambdalabs.com), or any other GPU provider of your choice (the instructions below were tested on Lambda) 10 | * Get an instance (e.g. A100 or A10). Minimum GPU memory is 16GB; tested on 24GB 11 | * Hit up Jupyter (or SSH in, but Jupyter makes viewing the outputs easier 😊) 12 | 13 | ## 🐋 New terminal, add user to docker group 14 | 15 | ``` 16 | sudo adduser ubuntu docker 17 | ``` 18 | ``` 19 | sudo su - ubuntu 20 | ``` 21 | 22 | ## 🐍 Install newer Python 23 | 24 | ``` 25 | sudo add-apt-repository ppa:deadsnakes/ppa 26 | ``` 27 | Press enter to install the PPA. 28 | 29 | ``` 30 | sudo apt install -y socat python3.10-venv 31 | ``` 32 | ``` 33 | python3.10 -m venv venv 34 | ``` 35 | ``` 36 | . venv/bin/activate 37 | ``` 38 | ``` 39 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 40 | ``` 41 | 42 | ## 🚀 Install Dagger CLI 43 | 44 | ``` 45 | ( cd /usr/local ; curl -L https://dl.dagger.io/dagger/install.sh | sudo sh ) 46 | ``` 47 | 48 | ## ⚙️ Check out repo and configure it 49 | 50 | ``` 51 | git clone https://github.com/lukemarsden/dagger-stable-diffusion-lora 52 | ``` 53 | ``` 54 | cd dagger-stable-diffusion-lora 55 | ``` 56 | 57 | ``` 58 | pip install -r requirements.txt 59 | ``` 60 | ``` 61 | cp config.yml.sample_sd15 config.yml 62 | ``` 63 | 64 | Now open `dagger-stable-diffusion-lora/config.yml` in the Jupyter editor and change anything you like. 65 | 66 | ## 🚂 Train some LoRAs! 67 | 68 | Back in the first terminal, run (use `sdxl_lora.py` and `config_sdxl.yml` for the SDXL variant): 69 | ``` 70 | export _EXPERIMENTAL_DAGGER_INTERACTIVE_TUI=1 71 | dagger run python sd15_lora.py 72 | ``` 73 | 74 | Now go and have lunch while you pull docker images & train some LoRAs 😊 75 | 76 | ...
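While that runs, here is roughly the layout the pipeline produces under `output/` in the repo directory — a sketch based on the paths used in `sd15_lora.py`; the exact filenames depend on the brands, prompt keys and `num_images` in your `config.yml`:

```
output/
├── downloads/    # the <brand>.zip files fetched from url_prefix
├── assets/       # the unzipped training images
├── loras/
│   └── <brand>/final_lora.safetensors    # the trained LoRA weights
└── inference/
    └── <brand>/<prompt-key>-<seed>.jpg   # the generated images
```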
77 | 78 | If you're curious to watch the progress, in another terminal tab, do `sudo docker ps` and `sudo docker logs -f <container_id>` for some of the running jobs. 79 | 80 | Welcome back, check out `output/inference` to see the results! 81 | 82 | ## 🏃 Observe the Dagger cache making things faster 83 | 84 | Now uncomment some of the prompts and/or brands in `config.yml` and re-run; note how the Dagger cache saves you from having to redo any work that it has already done! 85 | 86 | ## 💸 Remember to shut down your GPU 87 | 88 | If you like keeping your money! 89 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | # the names of the brands (short, human-readable); each corresponds to a 2 | # zip file the script will download 3 | brands: 4 | - coke 5 | # - dagger 6 | # - docker 7 | # - kubernetes 8 | # - nike 9 | # - vision-pro 10 | 11 | # the prompts for each brand; adjust these to generate different images (<s1><s2> are the learned placeholder tokens) 12 | prompts: 13 | mug: "coffee mug with logo on it, in the style of <s1><s2>" 14 | # mug2: "coffee mug with brand logo on it, in the style of <s1><s2>" 15 | # mug3: "coffee mug with brand logo on it, in the style of <s1><s2>, 50mm portrait photography, hard rim lighting photography, merchandise" 16 | # tshirt: "woman torso wearing tshirt with logo, 50mm portrait photography, hard rim lighting photography, merchandise" 17 | 18 | # how many images to generate for each prompt for each brand 19 | num_images: 10 20 | 21 | # the script expects zip files to download from "{url_prefix}{brand}.zip", e.g. https://storage.googleapis.com/dagger-assets/coke.zip 22 | # the zip file must just contain the images to fine-tune the model on; filenames don't matter 23 | # **url_prefix must have a trailing slash** 24 | url_prefix: https://storage.googleapis.com/dagger-assets/ 25 | 26 | # you shouldn't need to change this (and note that doing so will break the 27 | # caching, see Dockerfile) 28 | model_name: runwayml/stable-diffusion-v1-5 29 | 30 | # container image built from the Dockerfile in this repo 31 | container_image: quay.io/lukemarsden/lora:v0.0.2 -------------------------------------------------------------------------------- /config.yml.sample_sd15: -------------------------------------------------------------------------------- 1 | # the names of the brands (short, human-readable); each corresponds to a 2 | # zip file the script will download 3 | brands: 4 | - coke 5 | # - dagger 6 | # - docker 7 | # - kubernetes 8 | # - nike 9 | # - vision-pro 10 | 11 | # the prompts for each brand; adjust these to generate different images (<s1><s2> are the learned placeholder tokens) 12 | prompts: 13 | mug: "coffee mug with logo on it, in the style of <s1><s2>" 14 | # mug2: "coffee mug with brand logo on it, in the style of <s1><s2>" 15 | # mug3: "coffee mug with brand logo on it, in the style of <s1><s2>, 50mm portrait photography, hard rim lighting photography, merchandise" 16 | # tshirt: "woman torso wearing tshirt with logo, 50mm portrait photography, hard rim lighting photography, merchandise" 17 | 18 | # how many images to generate for each prompt for each brand 19 | num_images: 10 20 | 21 | # the script expects zip files to download from "{url_prefix}{brand}.zip", e.g.
https://storage.googleapis.com/dagger-assets/coke.zip 22 | # the zip file must just contain the images to fine-tune the model on; filenames don't matter 23 | # **url_prefix must have a trailing slash** 24 | url_prefix: https://storage.googleapis.com/dagger-assets/ 25 | 26 | # you shouldn't need to change this (and note that doing so will break the 27 | # caching, see Dockerfile) 28 | model_name: runwayml/stable-diffusion-v1-5 29 | 30 | # container image built from the Dockerfile in this repo 31 | container_image: quay.io/lukemarsden/lora:v0.0.2 -------------------------------------------------------------------------------- /config_sdxl.yml: -------------------------------------------------------------------------------- 1 | # the names of the brands (short, human-readable); each corresponds to a 2 | # zip file the script will download 3 | brands: 4 | - for-sale-signs 5 | # - coke 6 | # - docker 7 | # - kubernetes 8 | # - nike 9 | # - vision-pro 10 | 11 | # the prompts for each brand; adjust these to generate different images 12 | # note that with SDXL you should reference words and phrases you used in the captions for best effect 13 | prompts: 14 | img1: "cj hole for sale sign in front of a posh house with a tesla in winter with snow" 15 | img2: "cj hole sold sign in front of a council house with a vw beetle in spring with daffodils" 16 | img3: "cj hole for sale sign in front of a detached house in summer with a bbq in the front garden" 17 | img4: "cj hole for sale sign in front of a posh house, with spider webs and halloween decorations" 18 | img5: "cj hole for sale sign in front of a detached house with christmas decorations" 19 | 20 | # how many images to generate for each prompt for each brand 21 | num_images: 1 22 | 23 | # how much weight to give the fine-tuned LoRA at inference time 24 | finetune_weighting: 0.8 25 | 26 | # the script expects zip files to download from "{url_prefix}sdxl_{brand}.zip", e.g. https://storage.googleapis.com/dagger-assets/sdxl_coke.zip 27 | # the zip file must just contain the images to fine-tune the model on and, for each image file foo.jpg, a foo.txt file containing a descriptive caption of the image 28 | # you can then reuse the language from the captions when prompting the trained model at inference time 29 | # **url_prefix must have a trailing slash** 30 | url_prefix: https://storage.googleapis.com/dagger-assets/ -------------------------------------------------------------------------------- /mistral_lora.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import anyio 4 | import dagger 5 | import os 6 | import time 7 | import subprocess 8 | import urllib.request 9 | import zipfile 10 | import textwrap 11 | import yaml 12 | 13 | IMAGE = "quay.io/lukemarsden/axolotl:v0.0.1" 14 | PROMPT = "If I put up a hammock hung between opposite sides of a round lake, go to sleep in the hammock and fall out, where will I land?" 15 | 16 | async def main(): 17 | 18 | print("Spawning docker socket forwarder...") 19 | p = subprocess.Popen(["socat", "TCP-LISTEN:12345,reuseaddr,fork,bind=172.17.0.1", "UNIX-CONNECT:/var/run/docker.sock"]) 20 | time.sleep(1) 21 | print("Done!") 22 | 23 | config = dagger.Config(log_output=sys.stdout) 24 | 25 | async with dagger.Connection(config) as client: 26 | try: 27 | python = ( 28 | client 29 | .container() 30 | .from_("docker:latest") # TODO: use '@sha256:...'
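# Dagger runs the official docker CLI image here and points it (via -H below) at the
# host's Docker daemon, which socat exposes on 172.17.0.1:12345. The actual axolotl
# inference therefore runs as a sibling container on the host, where it can use
# --gpus all; the Dagger-managed container only drives the docker CLI.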
31 | # break cache 32 | # .with_env_variable("BREAK_CACHE", str(time.time())) 33 | .with_entrypoint("/usr/local/bin/docker") 34 | .with_exec(["-H", "tcp://172.17.0.1:12345", 35 | "run", "-i", "--rm", "--gpus", "all", 36 | IMAGE, 37 | "bash", "-c", f'echo "{PROMPT}" | python -u -m axolotl.cli.inference examples/mistral/qlora-instruct.yml', 38 | ]) 39 | ) 40 | # execute 41 | err = await python.stderr() 42 | out = await python.stdout() 43 | # print the prompt and the model's answer 44 | print(f"Question: {PROMPT}\n\nAnswer: {out}") 45 | except Exception as e: 46 | import pdb; pdb.set_trace() 47 | print(f"error: {e}") 48 | 49 | p.terminate() 50 | 51 | anyio.run(main) 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dagger-io==0.6.2 2 | PyYAML==6.0 3 | -------------------------------------------------------------------------------- /sd15_lora.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import anyio 4 | import dagger 5 | import os 6 | import time 7 | import subprocess 8 | import urllib.request 9 | import zipfile 10 | import textwrap 11 | import yaml 12 | 13 | # Load from config.yml 14 | config = yaml.load(open("config.yml", "r"), Loader=yaml.FullLoader) 15 | 16 | MODEL_NAME = config.get("model_name", "runwayml/stable-diffusion-v1-5") 17 | IMAGE = config.get("container_image", "quay.io/lukemarsden/lora:v0.0.2") 18 | ASSETS = config.get("brands", [ 19 | "coke", 20 | "dagger", 21 | "docker", 22 | "kubernetes", 23 | "nike", 24 | "vision-pro", 25 | ]) 26 | PROMPTS = config.get("prompts", { 27 | "mug": "photograph of a coffee mug with logo on it, in the style of <s1><s2>", 28 | "mug2": "coffee mug with brand logo on it, in the style of <s1><s2>", 29 | "mug3": "coffee mug with brand logo on it, in the style of <s1><s2>, 50mm portrait photography, hard rim lighting photography, merchandise", 30 | "tshirt": "woman torso wearing tshirt with logo, 50mm portrait photography, hard rim lighting photography, merchandise", 31 | }) 32 | NUM_IMAGES = config.get("num_images", 10) 33 | URL_PREFIX = config.get("url_prefix", "https://storage.googleapis.com/dagger-assets/") 34 | COEFF = config.get("finetune_weighting", 0.5) 35 | 36 | async def main(): 37 | 38 | print("Spawning docker socket forwarder...") 39 | p = subprocess.Popen(["socat", "TCP-LISTEN:12345,reuseaddr,fork,bind=172.17.0.1", "UNIX-CONNECT:/var/run/docker.sock"]) 40 | time.sleep(1) 41 | print("Done!") 42 | 43 | config = dagger.Config(log_output=sys.stdout) 44 | 45 | # create output directory on the host 46 | output_dir = os.path.join(os.getcwd(), "output") 47 | 48 | print("=============================") 49 | print(f"OUTPUT DIRECTORY: {output_dir}") 50 | print("=============================") 51 | os.makedirs(os.path.join(output_dir, "assets"), exist_ok=True) 52 | os.makedirs(os.path.join(output_dir, "downloads"), exist_ok=True) 53 | os.makedirs(os.path.join(output_dir, "loras"), exist_ok=True) 54 | os.makedirs(os.path.join(output_dir, "inference"), exist_ok=True) 55 | 56 | for brand in ASSETS: 57 | # http download, e.g. storage.googleapis.com/dagger-assets/dagger.zip 58 | urllib.request.urlretrieve( 59 | URL_PREFIX + brand + ".zip", 60 | os.path.join(output_dir, "downloads", f"{brand}.zip"), 61 | ) 62 | # unzip with the zipfile module 63 | with zipfile.ZipFile(os.path.join(output_dir, "downloads", f"{brand}.zip"), 'r') as zip_ref: 64 | zip_ref.extractall(os.path.join(output_dir, "assets")) 65 | 66 | # train the loras
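# For each brand, the training step below mounts output/assets/<brand> as /input and
# output/loras/<brand> as /output, then runs cloneofsimo's lora_pti inside the GPU
# container: textual inversion for the placeholder tokens plus LoRA fine-tuning of the
# UNet and text encoder. The resulting /output/final_lora.safetensors is what the
# inference step further down patches into the pipeline.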
67 | for brand in ASSETS: 68 | # initialize Dagger client - no parallelism here 69 | async with dagger.Connection(config) as client: 70 | # fine-tune the LoRA 71 | try: 72 | python = ( 73 | client 74 | .container() 75 | .from_("docker:latest") # TODO: use '@sha256:...' 76 | # break cache 77 | # .with_env_variable("BREAK_CACHE", str(time.time())) 78 | .with_entrypoint("/usr/local/bin/docker") 79 | .with_exec(["-H", "tcp://172.17.0.1:12345", 80 | "run", "-i", "--rm", "--gpus", "all", 81 | "-v", os.path.join(output_dir, "assets", brand)+":/input", 82 | "-v", os.path.join(output_dir, "loras", brand)+":/output", 83 | IMAGE, 84 | 'lora_pti', 85 | f'--pretrained_model_name_or_path={MODEL_NAME}', 86 | '--instance_data_dir=/input', '--output_dir=/output', 87 | '--train_text_encoder', '--resolution=512', 88 | '--train_batch_size=1', 89 | '--gradient_accumulation_steps=4', '--scale_lr', 90 | '--learning_rate_unet=1e-4', 91 | '--learning_rate_text=1e-5', '--learning_rate_ti=5e-4', 92 | '--color_jitter', '--lr_scheduler="linear"', 93 | '--lr_warmup_steps=0', 94 | '--placeholder_tokens="<s1>|<s2>"', 95 | '--use_template="style"', '--save_steps=100', 96 | '--max_train_steps_ti=1000', 97 | '--max_train_steps_tuning=1000', 98 | '--perform_inversion=True', '--clip_ti_decay', 99 | '--weight_decay_ti=0.000', '--weight_decay_lora=0.001', 100 | '--continue_inversion', '--continue_inversion_lr=1e-4', 101 | '--device="cuda:0"', '--lora_rank=1' 102 | ]) 103 | ) 104 | # execute 105 | err = await python.stderr() 106 | out = await python.stdout() 107 | # print the output 108 | print(f"Hello from Dagger, fine tune LoRA on {brand}: {out}{err}") 109 | except Exception as e: 110 | import pdb; pdb.set_trace() 111 | print(f"error: {e}") 112 | 113 | async with dagger.Connection(config) as client: 114 | for brand in ASSETS: 115 | for key, prompt in PROMPTS.items(): 116 | for seed in range(NUM_IMAGES): 117 | # inference!
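# The inference step reuses the same image: it mounts the trained LoRA directory as
# /input and output/inference/<brand> as /output, then runs a small python -c script
# that loads Stable Diffusion 1.5 at fp16, patches the pipeline with
# /input/final_lora.safetensors (patch_pipe), scales the LoRA influence with
# tune_lora_scale(COEFF), and writes one image per prompt/seed combination.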
118 | python = ( 119 | client 120 | .container() 121 | .from_("docker:latest") 122 | .with_entrypoint("/usr/local/bin/docker") 123 | .with_exec(["-H", "tcp://172.17.0.1:12345", 124 | "run", "-i", "--rm", "--gpus", "all", 125 | "-v", os.path.join(output_dir, "loras", brand)+":/input", 126 | "-v", os.path.join(output_dir, "inference", brand)+":/output", 127 | IMAGE, 128 | 'python3', 129 | '-c', 130 | # dedent 131 | textwrap.dedent(f""" 132 | from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler 133 | import torch 134 | from lora_diffusion import tune_lora_scale, patch_pipe 135 | 136 | model_id = "{MODEL_NAME}" 137 | 138 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to( 139 | "cuda" 140 | ) 141 | pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) 142 | 143 | prompt = "{prompt}" 144 | seed = {seed} 145 | torch.manual_seed(seed) 146 | 147 | patch_pipe( 148 | pipe, 149 | "/input/final_lora.safetensors", 150 | patch_text=True, 151 | patch_ti=True, 152 | patch_unet=True, 153 | ) 154 | 155 | coeff = {COEFF} 156 | tune_lora_scale(pipe.unet, coeff) 157 | tune_lora_scale(pipe.text_encoder, coeff) 158 | 159 | image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0] 160 | image.save(f"/output/{key}-{{seed}}.jpg") 161 | image 162 | """) 163 | ]) 164 | ) 165 | # execute 166 | err = await python.stderr() 167 | out = await python.stdout() 168 | # print stderr 169 | print(f"Hello from Dagger, inference {brand}, prompt: {prompt} and {out}{err}") 170 | 171 | p.terminate() 172 | 173 | anyio.run(main) 174 | -------------------------------------------------------------------------------- /sdxl_lora.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import anyio 4 | import dagger 5 | import os 6 | import time 7 | import subprocess 8 | import urllib.request 9 | import zipfile 10 | import textwrap 11 | import yaml 12 | 13 | # Load from config.yml 14 | config = yaml.load(open("config_sdxl.yml", "r"), Loader=yaml.FullLoader) 15 | 16 | IMAGE = config.get("container_image", "quay.io/lukemarsden/sd-scripts:v0.0.3") 17 | ASSETS = config.get("brands", [ 18 | # "coke", 19 | "dagger", 20 | # "docker", 21 | # "kubernetes", 22 | # "nike", 23 | # "vision-pro", 24 | ]) 25 | PROMPTS = config.get("prompts", { 26 | "mug": "coffee mug with dagger logo on it", 27 | "mug2": "coffee mug with astronauts on mars on it holding a map", 28 | "mug3": "coffee mug with dagger logo on it, 50mm portrait photography, hard rim lighting photography, merchandise", 29 | "tshirt": "woman torso wearing dagger logo tshirt, 50mm portrait photography, hard rim lighting photography, merchandise", 30 | }) 31 | NUM_IMAGES = config.get("num_images", 10) 32 | URL_PREFIX = config.get("url_prefix", "https://storage.googleapis.com/dagger-assets/") 33 | COEFF = config.get("finetune_weighting", 0.8) 34 | 35 | async def main(): 36 | 37 | print("Spawning docker socket forwarder...") 38 | p = subprocess.Popen(["socat", "TCP-LISTEN:12345,reuseaddr,fork,bind=172.17.0.1", "UNIX-CONNECT:/var/run/docker.sock"]) 39 | time.sleep(1) 40 | print("Done!") 41 | 42 | config = dagger.Config(log_output=sys.stdout) 43 | 44 | # create output directory on the host 45 | output_dir = os.path.join(os.getcwd(), "output") 46 | 47 | print("=============================") 48 | print(f"OUTPUT DIRECTORY: {output_dir}") 49 | print("=============================") 50 | os.makedirs(os.path.join(output_dir, "assets"), 
exist_ok=True) 51 | os.makedirs(os.path.join(output_dir, "downloads"), exist_ok=True) 52 | os.makedirs(os.path.join(output_dir, "loras"), exist_ok=True) 53 | os.makedirs(os.path.join(output_dir, "inference"), exist_ok=True) 54 | 55 | for brand in ASSETS: 56 | # http download, e.g. storage.googleapis.com/dagger-assets/sdxl_dagger.zip (the sdxl_-prefixed zips also contain .txt caption files) 57 | urllib.request.urlretrieve( 58 | URL_PREFIX + "sdxl_" + brand + ".zip", 59 | os.path.join(output_dir, "downloads", f"{brand}.zip"), 60 | ) 61 | # unzip with the zipfile module 62 | with zipfile.ZipFile(os.path.join(output_dir, "downloads", f"{brand}.zip"), 'r') as zip_ref: 63 | zip_ref.extractall(os.path.join(output_dir, "assets")) 64 | 65 | open(os.path.join(output_dir, "config.toml"), "w").write("""[general] 66 | enable_bucket = true # Whether to use Aspect Ratio Bucketing 67 | 68 | [[datasets]] 69 | resolution = 1024 # Training resolution 70 | batch_size = 4 # Batch size 71 | 72 | [[datasets.subsets]] 73 | image_dir = '/input' # Specify the folder containing the training images 74 | caption_extension = '.txt' # Caption file extension; the sdxl_ zips use .txt captions 75 | num_repeats = 10 # Number of repetitions for training images 76 | """) 77 | 78 | # train the loras 79 | for brand in ASSETS: 80 | # initialize Dagger client - no parallelism here 81 | async with dagger.Connection(config) as client: 82 | # fine-tune the LoRA 83 | try: 84 | args = ["-H", "tcp://172.17.0.1:12345", 85 | "run", "-i", 86 | "--rm", "--gpus", "all", 87 | "-v", os.path.join(output_dir, "config.toml")+":/config.toml", 88 | "-v", os.path.join(output_dir, "assets", brand)+":/input", 89 | "-v", os.path.join(output_dir, "loras", brand)+":/output", 90 | IMAGE, 91 | 92 | "accelerate", "launch", "--num_cpu_threads_per_process", "1", "sdxl_train_network.py", 93 | "--pretrained_model_name_or_path=./sdxl/sd_xl_base_1.0.safetensors", 94 | "--dataset_config=/config.toml", 95 | "--output_dir=/output", 96 | "--output_name=lora", 97 | "--save_model_as=safetensors", 98 | "--prior_loss_weight=1.0", 99 | "--max_train_steps=400", 100 | "--vae=madebyollin/sdxl-vae-fp16-fix", 101 | "--learning_rate=1e-4", 102 | "--optimizer_type=AdamW8bit", 103 | "--xformers", 104 | "--mixed_precision=fp16", 105 | "--cache_latents", 106 | "--gradient_checkpointing", 107 | "--save_every_n_epochs=1", 108 | "--network_module=networks.lora", 109 | 110 | ] 111 | print("RUNNING:", " ".join(args)) 112 | python = ( 113 | client 114 | .container() 115 | .from_("docker:latest") # TODO: use '@sha256:...' 116 | # break cache 117 | .with_env_variable("BREAK_CACHE", brand) 118 | # .with_entrypoint("/usr/local/bin/docker") 119 | .with_entrypoint("/bin/sh") 120 | .with_exec(["-c", "docker " + " ".join(args)]) 121 | # .with_exec(args) 122 | ) 123 | # execute 124 | err = await python.stderr() 125 | out = await python.stdout() 126 | # print the output 127 | print(f"Hello from Dagger, fine tune LoRA on {brand}: {out}{err}") 128 | except Exception as e: 129 | import pdb; pdb.set_trace() 130 | print(f"error: {e}") 131 | 132 | async with dagger.Connection(config) as client: 133 | for brand in ASSETS: 134 | for key, prompt in PROMPTS.items(): 135 | for seed in range(NUM_IMAGES): 136 | # inference!
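# SDXL inference: mounts the trained LoRA as /input and the per-brand inference
# directory as /output, then runs sd-scripts' sdxl_minimal_inference.py against the
# SDXL base checkpoint, merging /input/lora.safetensors at weight COEFF (the ";{COEFF}"
# suffix on --lora_weights) and writing the generated images to /output.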
137 | python = ( 138 | client 139 | .container() 140 | .from_("docker:latest") 141 | # .with_env_variable("BREAK_CACHE", str(time.time())) 142 | .with_entrypoint("/usr/local/bin/docker") 143 | .with_exec(["-H", "tcp://172.17.0.1:12345", 144 | "run", 145 | "-i", "--rm", "--gpus", "all", 146 | "-v", os.path.join(output_dir, "loras", brand)+":/input", 147 | "-v", os.path.join(output_dir, "inference", brand)+":/output", 148 | IMAGE, 149 | 150 | "accelerate", "launch", "--num_cpu_threads_per_process", "1", "sdxl_minimal_inference.py", 151 | "--ckpt_path=sdxl/sd_xl_base_1.0.safetensors", 152 | f'--lora_weights=/input/lora.safetensors;{COEFF}', 153 | f'--prompt={prompt}', 154 | "--output_dir=/output", 155 | ]) 156 | ) 157 | # execute 158 | err = await python.stderr() 159 | out = await python.stdout() 160 | # print stderr 161 | print(f"Hello from Dagger, inference {brand}, prompt: {prompt} and {out}{err}") 162 | 163 | p.terminate() 164 | 165 | anyio.run(main) 166 | --------------------------------------------------------------------------------