├── test_lora.png ├── example ├── github.jpg ├── person │ ├── 1.jpg │ ├── 1_mask.png │ ├── 00008_00.jpg │ ├── 00055_00.jpg │ ├── 00057_00.jpg │ ├── 00064_00.jpg │ ├── 00067_00.jpg │ ├── 00069_00.jpg │ ├── 00008_00_mask.png │ ├── 00055_00_mask.png │ ├── 00057_00_mask.png │ ├── 00064_00_mask.png │ ├── 00067_00_mask.png │ └── 00069_00_mask.png ├── result │ ├── 1.png │ ├── 2.png │ └── 3.png ├── garment │ ├── 00034_00.jpg │ ├── 00035_00.jpg │ ├── 00055_00.jpg │ ├── 00057_00.jpg │ ├── 00064_00.jpg │ ├── 00067_00.jpg │ ├── 00069_00.jpg │ ├── 00396_00.jpg │ └── 04564_00.jpg └── tryoff_result │ ├── restored_garment1.png │ ├── restored_garment2.png │ ├── restored_garment3.png │ ├── restored_garment4.png │ ├── restored_garment5.png │ └── restored_garment6.png ├── .gitattributes ├── tryon.sh ├── tryoff.sh ├── requirements.txt ├── .github └── workflows │ └── main.yml ├── accelerate_config.yaml ├── .gitignore ├── LICENSE ├── cog.yaml ├── train_flux_inpaint.sh ├── script └── fid_eval.py ├── .gradio └── certificate.pem ├── tryoff_inference.py ├── tryon_inference.py ├── TrainingNotes.md ├── predict.py ├── tryon_inference_lora.py ├── image_datasets └── cp_dataset.py ├── app_lora.py ├── app_no_lora.py ├── app.py ├── README.md ├── src └── flux │ └── train_utils.py ├── paser_helper.py ├── LICENSE-MODEL └── train_flux_inpaint.py /test_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/test_lora.png -------------------------------------------------------------------------------- /example/github.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/github.jpg -------------------------------------------------------------------------------- /example/person/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/1.jpg -------------------------------------------------------------------------------- /example/result/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/result/1.png -------------------------------------------------------------------------------- /example/result/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/result/2.png -------------------------------------------------------------------------------- /example/result/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/result/3.png -------------------------------------------------------------------------------- /example/person/1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/1_mask.png -------------------------------------------------------------------------------- /example/garment/00034_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00034_00.jpg -------------------------------------------------------------------------------- /example/garment/00035_00.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00035_00.jpg -------------------------------------------------------------------------------- /example/garment/00055_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00055_00.jpg -------------------------------------------------------------------------------- /example/garment/00057_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00057_00.jpg -------------------------------------------------------------------------------- /example/garment/00064_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00064_00.jpg -------------------------------------------------------------------------------- /example/garment/00067_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00067_00.jpg -------------------------------------------------------------------------------- /example/garment/00069_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00069_00.jpg -------------------------------------------------------------------------------- /example/garment/00396_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/00396_00.jpg -------------------------------------------------------------------------------- /example/garment/04564_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/garment/04564_00.jpg -------------------------------------------------------------------------------- /example/person/00008_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00008_00.jpg -------------------------------------------------------------------------------- /example/person/00055_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00055_00.jpg -------------------------------------------------------------------------------- /example/person/00057_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00057_00.jpg -------------------------------------------------------------------------------- /example/person/00064_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00064_00.jpg -------------------------------------------------------------------------------- /example/person/00067_00.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00067_00.jpg -------------------------------------------------------------------------------- /example/person/00069_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00069_00.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | example/github.mp4 filter=lfs diff=lfs merge=lfs -text 2 | *.mp4 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /example/person/00008_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00008_00_mask.png -------------------------------------------------------------------------------- /example/person/00055_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00055_00_mask.png -------------------------------------------------------------------------------- /example/person/00057_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00057_00_mask.png -------------------------------------------------------------------------------- /example/person/00064_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00064_00_mask.png -------------------------------------------------------------------------------- /example/person/00067_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00067_00_mask.png -------------------------------------------------------------------------------- /example/person/00069_00_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/person/00069_00_mask.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment1.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment2.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment3.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment4.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment5.png -------------------------------------------------------------------------------- /example/tryoff_result/restored_garment6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nftblackmagic/catvton-flux/HEAD/example/tryoff_result/restored_garment6.png -------------------------------------------------------------------------------- /tryon.sh: -------------------------------------------------------------------------------- 1 | python tryon_inference.py \ 2 | --image ./example/person/00008_00.jpg \ 3 | --mask ./example/person/00008_00_mask.png \ 4 | --garment ./example/garment/00034_00.jpg \ 5 | --seed 42 \ 6 | --output_tryon test.png \ 7 | --steps 30 -------------------------------------------------------------------------------- /tryoff.sh: -------------------------------------------------------------------------------- 1 | python tryoff_inference.py \ 2 | --image ./example/person/00069_00.jpg \ 3 | --mask ./example/person/00069_00_mask.png \ 4 | --seed 41 \ 5 | --output_tryon test_original.png \ 6 | --output_garment restored_garment6.png \ 7 | --steps 30 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | git+https://github.com/huggingface/diffusers.git 3 | gradio==5.6.0 4 | gradio_client==1.4.3 5 | tqdm==4.66.5 6 | transformers==4.43.3 7 | numpy==1.26.4 8 | peft==0.13.2 9 | huggingface-hub 10 | spaces 11 | protobuf 12 | torch==2.5.1 13 | torchaudio==2.5.1 14 | torchvision==0.20.1 15 | bitsandbytes==0.44.1 16 | sentencepiece==0.2.0 17 | deepspeed 18 | pandas 19 | wandb 20 | prodigyopt 21 | supervision 22 | pytest 23 | einops 24 | timm 25 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Sync to Hugging Face hub 2 | on: 3 | push: 4 | branches: [main] 5 | 6 | # to run this workflow manually from the Actions tab 7 | workflow_dispatch: 8 | 9 | jobs: 10 | sync-to-hub: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 0 16 | lfs: true 17 | - name: Push to hub 18 | env: 19 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 20 | run: git push https://xiaozaa:$HF_TOKEN@huggingface.co/spaces/xiaozaa/catvton-flux-try-on main -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | enable_cpu_affinity: false 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | 
num_machines: 1 16 | num_processes: 2 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | dist/ 8 | build/ 9 | *.egg-info/ 10 | 11 | # Virtual environments 12 | venv/ 13 | env/ 14 | .env/ 15 | .venv/ 16 | 17 | # IDE specific files 18 | .idea/ 19 | .vscode/ 20 | *.swp 21 | *.swo 22 | 23 | # Unit test / coverage reports 24 | htmlcov/ 25 | .tox/ 26 | .coverage 27 | .coverage.* 28 | coverage.xml 29 | *.cover 30 | 31 | # Jupyter Notebook 32 | .ipynb_checkpoints 33 | 34 | # Local development settings 35 | .env 36 | .env.local 37 | 38 | # Logs 39 | *.log 40 | 41 | # Database files 42 | *.db 43 | *.sqlite3 44 | 45 | # OS generated files 46 | .DS_Store 47 | .DS_Store? 48 | ._* 49 | .Spotlight-V100 50 | .Trashes 51 | ehthumbs.db 52 | Thumbs.db 53 | 54 | # Gradio cache 55 | .gradio/example/github.mp4 56 | 57 | aws/ 58 | checkpoints/ 59 | 60 | .cog/ 61 | 62 | wandb/ 63 | trained-flux-inpaint* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 nftblackmagic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | python_packages: 18 | - torch==2.4.0 19 | - transformers==4.43.3 20 | - datasets==2.20.0 21 | - accelerate==1.3.0 22 | - jupyter==1.0.0 23 | - numpy==1.26.4 24 | - pillow==10.2.0 25 | - peft==0.13.2 26 | - diffusers>=0.32.0 27 | - timm==0.9.16 28 | - torchvision==0.19.0 29 | - tqdm==4.66.5 30 | - numpy==1.26.4 31 | - sentencepiece 32 | - protobuf 33 | 34 | # commands run after the environment is setup 35 | run: 36 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 37 | 38 | # predict.py defines how predictions are run on your model 39 | predict: "predict.py:Predictor" -------------------------------------------------------------------------------- /train_flux_inpaint.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="black-forest-labs/FLUX.1-dev" 2 | export INSTANCE_DIR="dog" 3 | export OUTPUT_DIR="trained-flux-inpaint" 4 | 5 | accelerate launch --config_file accelerate_config.yaml train_flux_inpaint.py \ 6 | --pretrained_model_name_or_path=$MODEL_NAME \ 7 | --pretrained_inpaint_model_name_or_path="xiaozaa/flux1-fill-dev-diffusers" \ 8 | --instance_data_dir=$INSTANCE_DIR \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --mixed_precision="bf16" \ 11 | --train_batch_size=1 \ 12 | --guidance_scale=1 \ 13 | --gradient_accumulation_steps=8 \ 14 | --optimizer="adamw" \ 15 | --use_8bit_adam \ 16 | --learning_rate=2e-5 \ 17 | --lr_scheduler="constant" \ 18 | --lr_warmup_steps=0 \ 19 | --max_train_steps=100000 \ 20 | --validation_epochs=2500 \ 21 | --validation_steps=500 \ 22 | --seed="42" \ 23 | --dataroot="../data/VITON-HD" \ # Adjust the path to your dataset 24 | --train_data_list="train_pairs.txt" \ # Adjust the txt file to your train data list 25 | --train_verification_list="subtrain_1.txt" \ # Adjust the txt file to your train verification list 26 | --validation_data_list="subtest_1.txt" \ # Adjust the txt file to your validation data list 27 | --height=768 \ 28 | --width=576 \ 29 | --max_sequence_length=512 \ 30 | --checkpointing_steps=1000 \ 31 | --report_to="wandb" \ 32 | --train_base_model \ 33 | # --resume_from_checkpoint="latest" \ -------------------------------------------------------------------------------- /script/fid_eval.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | from torchvision.transforms import functional as F 5 | import torch 6 | from torchmetrics.image.fid import FrechetInceptionDistance 7 | 8 | 9 | # Paths setup 10 | generated_dataset_path = "output/tryon_results" 11 | original_dataset_path = "data/VITON-HD/test/image" # Replace with your actual original dataset path 12 | 13 | # Get generated images 14 | image_paths = sorted([os.path.join(generated_dataset_path, x) for x in os.listdir(generated_dataset_path)]) 15 | generated_images = 
[np.array(Image.open(path).convert("RGB")) for path in image_paths] 16 | 17 | # Get corresponding original images 18 | original_images = [] 19 | for gen_path in image_paths: 20 | # Extract the XXXXXX part from "tryon_XXXXXX.jpg" 21 | base_name = os.path.basename(gen_path) # get filename from path 22 | original_id = base_name.replace("tryon_", "") # remove "tryon_" prefix 23 | 24 | # Construct original image path 25 | original_path = os.path.join(original_dataset_path, original_id) 26 | original_images.append(np.array(Image.open(original_path).convert("RGB"))) 27 | 28 | 29 | 30 | def preprocess_image(image): 31 | image = torch.tensor(image).unsqueeze(0) 32 | image = image.permute(0, 3, 1, 2) / 255.0 33 | return F.center_crop(image, (768, 1024)) 34 | 35 | real_images = torch.cat([preprocess_image(image) for image in original_images]) 36 | fake_images = torch.cat([preprocess_image(image) for image in generated_images]) 37 | print(real_images.shape, fake_images.shape) 38 | 39 | fid = FrechetInceptionDistance(normalize=True) 40 | fid.update(real_images, real=True) 41 | fid.update(fake_images, real=False) 42 | 43 | print(f"FID: {float(fid.compute())}") -------------------------------------------------------------------------------- /.gradio/certificate.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw 3 | TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh 4 | cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 5 | WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu 6 | ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY 7 | MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc 8 | h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ 9 | 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U 10 | A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 11 | T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH 12 | B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC 13 | B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv 14 | KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn 15 | OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn 16 | jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw 17 | qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI 18 | rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV 19 | HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq 20 | hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL 21 | ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ 22 | 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK 23 | NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 24 | ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur 25 | TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC 26 | jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc 27 | oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq 28 | 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA 29 | mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d 30 | emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= 31 | -----END CERTIFICATE----- 32 | -------------------------------------------------------------------------------- /tryoff_inference.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from diffusers.utils import load_image, check_min_version 4 | from diffusers import FluxPriorReduxPipeline, FluxFillPipeline 5 | from diffusers import FluxTransformer2DModel 6 | import numpy as np 7 | from torchvision import transforms 8 | 9 | def run_inference( 10 | image_path, 11 | mask_path, 12 | size=(576, 768), 13 | num_steps=50, 14 | guidance_scale=30, 15 | seed=42, 16 | pipe=None 17 | ): 18 | # Build pipeline 19 | if pipe is None: 20 | transformer = FluxTransformer2DModel.from_pretrained( 21 | "xiaozaa/cat-tryoff-flux", 22 | torch_dtype=torch.bfloat16 23 | ) 24 | pipe = FluxFillPipeline.from_pretrained( 25 | "black-forest-labs/FLUX.1-dev", 26 | transformer=transformer, 27 | torch_dtype=torch.bfloat16 28 | ).to("cuda") 29 | else: 30 | pipe.to("cuda") 31 | 32 | pipe.transformer.to(torch.bfloat16) 33 | 34 | # Add transform 35 | transform = transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize([0.5], [0.5]) # For RGB images 38 | ]) 39 | mask_transform = transforms.Compose([ 40 | transforms.ToTensor() 41 | ]) 42 | 43 | # Load and process images 44 | # print("image_path", image_path) 45 | image = load_image(image_path).convert("RGB").resize(size) 46 | mask = load_image(mask_path).convert("RGB").resize(size) 47 | 48 | # Transform images using the new preprocessing 49 | image_tensor = transform(image) 50 | mask_tensor = mask_transform(mask)[:1] # Take only first channel 51 | garment_tensor = torch.zeros_like(image_tensor) 52 | image_tensor = image_tensor * mask_tensor 53 | 54 | # Create concatenated images 55 | inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width 56 | garment_mask = torch.zeros_like(mask_tensor) 57 | extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2) 58 | 59 | prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \ 60 | f"[IMAGE1] Detailed product shot of a clothing" \ 61 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 
62 | 63 | generator = torch.Generator(device="cuda").manual_seed(seed) 64 | 65 | result = pipe( 66 | height=size[1], 67 | width=size[0] * 2, 68 | image=inpaint_image, 69 | mask_image=extended_mask, 70 | num_inference_steps=num_steps, 71 | generator=generator, 72 | max_sequence_length=512, 73 | guidance_scale=guidance_scale, 74 | prompt=prompt, 75 | ).images[0] 76 | 77 | # Split and save results 78 | width = size[0] 79 | garment_result = result.crop((0, 0, width, size[1])) 80 | tryon_result = result.crop((width, 0, width * 2, size[1])) 81 | 82 | return garment_result, tryon_result 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference') 86 | parser.add_argument('--image', required=True, help='Path to the model image') 87 | parser.add_argument('--mask', required=True, help='Path to the agnostic mask') 88 | parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result') 89 | parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result') 90 | parser.add_argument('--steps', type=int, default=50, help='Number of inference steps') 91 | parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale') 92 | parser.add_argument('--seed', type=int, default=0, help='Random seed') 93 | parser.add_argument('--width', type=int, default=576, help='Width') 94 | parser.add_argument('--height', type=int, default=768, help='Height') 95 | 96 | args = parser.parse_args() 97 | 98 | check_min_version("0.30.2") 99 | 100 | garment_result, tryon_result = run_inference( 101 | image_path=args.image, 102 | mask_path=args.mask, 103 | num_steps=args.steps, 104 | guidance_scale=args.guidance_scale, 105 | seed=args.seed, 106 | size=(args.width, args.height) 107 | ) 108 | output_tryon_path=args.output_tryon 109 | output_garment_path=args.output_garment 110 | 111 | tryon_result.save(output_tryon_path) 112 | garment_result.save(output_garment_path) 113 | 114 | print("Successfully saved garment and try-on images") 115 | 116 | if __name__ == "__main__": 117 | main() -------------------------------------------------------------------------------- /tryon_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from diffusers.utils import load_image, check_min_version 4 | from diffusers import FluxPriorReduxPipeline, FluxFillPipeline 5 | from diffusers import FluxTransformer2DModel 6 | import numpy as np 7 | from torchvision import transforms 8 | 9 | def run_inference( 10 | image_path, 11 | mask_path, 12 | garment_path, 13 | size=(576, 768), 14 | num_steps=50, 15 | guidance_scale=30, 16 | seed=42, 17 | pipe=None 18 | ): 19 | # Build pipeline 20 | if pipe is None: 21 | transformer = FluxTransformer2DModel.from_pretrained( 22 | "xiaozaa/catvton-flux-alpha", 23 | torch_dtype=torch.bfloat16 24 | ) 25 | pipe = FluxFillPipeline.from_pretrained( 26 | "black-forest-labs/FLUX.1-dev", 27 | transformer=transformer, 28 | torch_dtype=torch.bfloat16 29 | ).to("cuda") 30 | else: 31 | pipe.to("cuda") 32 | 33 | pipe.transformer.to(torch.bfloat16) 34 | 35 | # Add transform 36 | transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.5], [0.5]) # For RGB images 39 | ]) 40 | mask_transform = transforms.Compose([ 41 | transforms.ToTensor() 42 | ]) 43 | 44 | # Load and process images 45 | # print("image_path", image_path) 46 | image = 
load_image(image_path).convert("RGB").resize(size) 47 | mask = load_image(mask_path).convert("RGB").resize(size) 48 | garment = load_image(garment_path).convert("RGB").resize(size) 49 | 50 | # Transform images using the new preprocessing 51 | image_tensor = transform(image) 52 | mask_tensor = mask_transform(mask)[:1] # Take only first channel 53 | garment_tensor = transform(garment) 54 | 55 | # Create concatenated images 56 | inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width 57 | garment_mask = torch.zeros_like(mask_tensor) 58 | extended_mask = torch.cat([garment_mask, mask_tensor], dim=2) 59 | 60 | prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \ 61 | f"[IMAGE1] Detailed product shot of a clothing" \ 62 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 63 | 64 | generator = torch.Generator(device="cuda").manual_seed(seed) 65 | 66 | result = pipe( 67 | height=size[1], 68 | width=size[0] * 2, 69 | image=inpaint_image, 70 | mask_image=extended_mask, 71 | num_inference_steps=num_steps, 72 | generator=generator, 73 | max_sequence_length=512, 74 | guidance_scale=guidance_scale, 75 | prompt=prompt, 76 | ).images[0] 77 | 78 | # Split and save results 79 | width = size[0] 80 | garment_result = result.crop((0, 0, width, size[1])) 81 | tryon_result = result.crop((width, 0, width * 2, size[1])) 82 | 83 | return garment_result, tryon_result 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference') 87 | parser.add_argument('--image', required=True, help='Path to the model image') 88 | parser.add_argument('--mask', required=True, help='Path to the agnostic mask') 89 | parser.add_argument('--garment', required=True, help='Path to the garment image') 90 | parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result') 91 | parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result') 92 | parser.add_argument('--steps', type=int, default=50, help='Number of inference steps') 93 | parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale') 94 | parser.add_argument('--seed', type=int, default=0, help='Random seed') 95 | parser.add_argument('--width', type=int, default=576, help='Width') 96 | parser.add_argument('--height', type=int, default=768, help='Height') 97 | 98 | args = parser.parse_args() 99 | 100 | check_min_version("0.30.2") 101 | 102 | garment_result, tryon_result = run_inference( 103 | image_path=args.image, 104 | mask_path=args.mask, 105 | garment_path=args.garment, 106 | num_steps=args.steps, 107 | guidance_scale=args.guidance_scale, 108 | seed=args.seed, 109 | size=(args.width, args.height) 110 | ) 111 | output_tryon_path=args.output_tryon 112 | 113 | tryon_result.save(output_tryon_path) 114 | 115 | print("Successfully saved garment and try-on images") 116 | 117 | if __name__ == "__main__": 118 | main() -------------------------------------------------------------------------------- /TrainingNotes.md: -------------------------------------------------------------------------------- 1 | # Mask is Important 2 | 3 | About two months ago, when flux fill was just released, we open-sourced a VTON-based virtual try-on model that achieved promising results. In the following period, we conducted numerous experiments and trials. 
We want to document these training experiences to provide a reference for others who might need it in the future.
4 | 
5 | ## Thoughts on Flux Fill
6 | 
7 | To be honest, flux fill isn't really a mysterious method. Interestingly, even without training, flux fill could already achieve quite good results in most cases. Our training work was more like solving the "last mile" problem. In practice, we found that basically all training could achieve good results with just 5000 steps using batch size=1 and learning rate=1e-5. This made us wonder: how significant was our training work really?
8 | 
9 | ## Comparison between Fine-tuning and LoRA
10 | 
11 | During our experiments, we found notable differences between fine-tuning and LoRA in the final generated images. Although LoRA could accomplish some virtual try-on tasks, it didn't perform as well as fine-tuning when dealing with complex garments, especially in preserving details like text.
12 | 
13 | ## The Importance of Mask Processing
14 | 
15 | When training with VTON, we used the pre-processed inpaint mask regions from the VTON dataset and achieved good results. However, when we tried other datasets, problems emerged. Without mature segmentation methods like VTON's, we discovered that mask selection had a huge impact on the final results.
16 | 
17 | We tried using SAM2 for garment segmentation, but the results weren't ideal. The trained model would stick too closely to the mask shape, leading to a serious problem: long sleeves could only become long sleeves, and short sleeves could only become short sleeves. Worse still, hand-drawn masks would cause severe errors in the final generated images.
18 | 
19 | After repeated experiments, we realized a key point: the mask preprocessing area needs to be as large as possible. The garment mask shouldn't just frame the garment itself; it needs to leave enough drawing space for different garment replacements. However, this brought a new problem: if the redrawn area was too large, it would lead to unstable training.
20 | 
21 | To solve this problem, we ultimately adopted a combination of OpenPose and SAM2 to mask out the human limbs for redrawing (see the short sketch below). We paid special attention to ensuring that the drawn mask completely covered the limbs, in order to counteract the influence of different garment styles. For example, for short-sleeve garments, we needed to ensure their masks showed no trace of whether they were short-sleeve or long-sleeve: if the mask itself contained short-sleeve information, the model would tend to generate short sleeves based on that cue, hurting try-on accuracy.
22 | 
23 | This principle also applies to bottom-wear processing. For the same model, whether the garment is long pants, shorts, or a skirt, the masks should remain basically consistent. If you can tell from the mask whether it's a skirt or pants, then the final generated garment form will be dominated by the mask rather than determined by the garment itself. This is why masks need to be as general as possible, hiding all garment-shape information.
24 | 
25 | Of course, this approach needs to be well balanced. The mask shouldn't be so large that the model has to regenerate unnecessary information, such as faces or backgrounds. This requires continuous adjustment and balance in practice.
26 | 
27 | ## Importance of Datasets
28 | 
29 | We experimented with many datasets and found that VTON and DressCode are excellent datasets.
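To make the masking recipe above concrete, here is a minimal sketch — our illustration, not code from this repository. It assumes you have already produced two binary masks per person image: a garment mask (e.g. from SAM2) and a limb/torso mask (e.g. rendered from OpenPose keypoints). The helper name, mask format, and dilation size are placeholders.

```python
import cv2
import numpy as np

def build_agnostic_mask(garment_mask: np.ndarray,
                        limb_mask: np.ndarray,
                        dilate_px: int = 32) -> np.ndarray:
    """Merge a garment mask and a limb mask into one deliberately loose
    inpainting mask, so the mask itself carries no sleeve/length cues.

    Both inputs are uint8 arrays of shape (H, W) with values in {0, 255},
    where 255 marks pixels to repaint. The output uses the same format.
    """
    # Union: cover both the original garment and the full limb region,
    # so a short-sleeve source can still become a long-sleeve result.
    merged = cv2.bitwise_or(garment_mask, limb_mask)

    # Dilate to leave extra "drawing space" around the union. Larger kernels
    # hide more garment-shape information but make training less stable.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_px, dilate_px))
    merged = cv2.dilate(merged, kernel, iterations=1)

    # Fill internal holes so thin gaps (e.g. between arm and torso) do not
    # leak the original garment silhouette into the conditioning image.
    contours, _ = cv2.findContours(merged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(merged, contours, -1, 255, thickness=cv2.FILLED)

    return merged
```

The dilation radius is the balance knob discussed above: too small and the mask still leaks sleeve-length and silhouette cues; too large and the model has to repaint faces or background, which destabilizes training.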
If a dataset's accuracy doesn't reach the level of these two datasets, using it for training would be devastating for the model. You'll find that the entire model's precision drops off a cliff, and flux fill's own capability gets severely reduced as well. Therefore, if anyone wants to use specific training sets for training, they must pay attention to the accuracy of data preprocessing. 30 | 31 | ## Limitations of Current Methods 32 | 33 | Although flux fill can solve many virtual try-on problems, its handling of complex patterns hasn't reached our expected level. For example, consider this dress with intricate floral patterns: 34 | 35 | ![053417_1](https://github.com/user-attachments/assets/a29b7355-f795-4b9f-b7df-04323a55825a) 36 | 37 | 38 | This type of garment, with its dense, small-scale repeating patterns, is practically a nightmare for all virtual try-on technologies. The complexity and detail of these tiny geometric patterns pose a significant challenge for current methods. Currently, none of the virtual try-on tools in the market can successfully restore these delicate patterns to a highly satisfactory level. The difficulty lies in both preserving the pattern's structure and ensuring its consistent application across the transformed garment. This problem remains a technical challenge that the entire field needs to continue tackling. 39 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from cog import BasePredictor, Input, Path, Secret 4 | from diffusers.utils import load_image 5 | from diffusers import FluxFillPipeline 6 | from diffusers import FluxTransformer2DModel 7 | import torch 8 | from torchvision import transforms 9 | 10 | class Predictor(BasePredictor): 11 | def setup(self) -> None: 12 | """Load part of the model into memory to make running multiple predictions efficient""" 13 | self.try_on_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/catvton-flux-beta", 14 | torch_dtype=torch.bfloat16) 15 | self.try_off_transformer = FluxTransformer2DModel.from_pretrained("xiaozaa/cat-tryoff-flux", 16 | torch_dtype=torch.bfloat16) 17 | 18 | def predict(self, 19 | hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. 
You also need to approve the Flux Dev terms."), 20 | image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"), 21 | mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"), 22 | try_on: bool = Input(False, description="Try on or try off"), 23 | garment: Path = Input(description="Garment file path like https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg", default=None), 24 | num_steps: int = Input(50, description="Number of steps to run the model for"), 25 | guidance_scale: float = Input(30, description="Guidance scale for the model"), 26 | seed: int = Input(0, description="Seed for the model"), 27 | width: int = Input(576, description="Width of the output image"), 28 | height: int = Input(768, description="Height of the output image")) -> List[Path]: 29 | 30 | size = (width, height) 31 | i = load_image(str(image)).convert("RGB").resize(size) 32 | m = load_image(str(mask)).convert("RGB").resize(size) 33 | 34 | if try_on: 35 | g = load_image(str(garment)).convert("RGB").resize(size) 36 | self.transformer = self.try_on_transformer 37 | else: 38 | self.transformer = self.try_off_transformer 39 | 40 | self.pipe = FluxFillPipeline.from_pretrained( 41 | "black-forest-labs/FLUX.1-dev", 42 | transformer=self.transformer, 43 | torch_dtype=torch.bfloat16, 44 | token=hf_token.get_secret_value() 45 | ).to("cuda") 46 | 47 | self.pipe.transformer.to(torch.bfloat16) 48 | 49 | transform = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.5], [0.5]) # For RGB images 52 | ]) 53 | mask_transform = transforms.Compose([ 54 | transforms.ToTensor() 55 | ]) 56 | 57 | # Transform images using the new preprocessing 58 | image_tensor = transform(i) 59 | mask_tensor = mask_transform(m)[:1] # Take only first channel 60 | if try_on: 61 | garment_tensor = transform(g) 62 | else: 63 | garment_tensor = torch.zeros_like(image_tensor) 64 | image_tensor = image_tensor * mask_tensor 65 | 66 | # Create concatenated images 67 | inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width 68 | garment_mask = torch.zeros_like(mask_tensor) 69 | 70 | if try_on: 71 | extended_mask = torch.cat([garment_mask, mask_tensor], dim=2) 72 | else: 73 | extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2) 74 | 75 | prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \ 76 | f"[IMAGE1] Detailed product shot of a clothing" \ 77 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 
78 | 79 | generator = torch.Generator(device="cuda").manual_seed(seed) 80 | result = self.pipe( 81 | height=size[1], 82 | width=size[0] * 2, 83 | image=inpaint_image, 84 | mask_image=extended_mask, 85 | num_inference_steps=num_steps, 86 | generator=generator, 87 | max_sequence_length=512, 88 | guidance_scale=guidance_scale, 89 | prompt=prompt, 90 | ).images[0] 91 | 92 | # Split and save results 93 | width = size[0] 94 | garment_result = result.crop((0, 0, width, size[1])) 95 | try_result = result.crop((width, 0, width * 2, size[1])) 96 | out_path = "/tmp/try.png" 97 | try_result.save(out_path) 98 | garm_out_path = "/tmp/garment.png" 99 | garment_result.save(garm_out_path) 100 | return [Path(out_path), Path(garm_out_path)] 101 | -------------------------------------------------------------------------------- /tryon_inference_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from diffusers.utils import load_image, check_min_version 4 | from diffusers import FluxPriorReduxPipeline, FluxFillPipeline 5 | from diffusers import FluxTransformer2DModel 6 | import numpy as np 7 | from torchvision import transforms 8 | 9 | def run_inference( 10 | image_path, 11 | mask_path, 12 | garment_path, 13 | size=(576, 768), 14 | num_steps=50, 15 | guidance_scale=30, 16 | seed=42, 17 | pipe=None 18 | ): 19 | # Build pipeline 20 | if pipe is None: 21 | transformer = FluxTransformer2DModel.from_pretrained( 22 | "xiaozaa/flux1-fill-dev-diffusers", ## The official Flux-Fill weights 23 | torch_dtype=torch.bfloat16 24 | ) 25 | print("Start loading LoRA weights") 26 | state_dict, network_alphas = FluxFillPipeline.lora_state_dict( 27 | pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", ## The tryon Lora weights 28 | weight_name="pytorch_lora_weights.safetensors", 29 | return_alphas=True 30 | ) 31 | is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) 32 | if not is_correct_format: 33 | raise ValueError("Invalid LoRA checkpoint.") 34 | 35 | FluxFillPipeline.load_lora_into_transformer( 36 | state_dict=state_dict, 37 | network_alphas=network_alphas, 38 | transformer=transformer, 39 | ) 40 | 41 | pipe = FluxFillPipeline.from_pretrained( 42 | "black-forest-labs/FLUX.1-dev", 43 | transformer=transformer, 44 | torch_dtype=torch.bfloat16 45 | ).to("cuda") 46 | else: 47 | pipe.to("cuda") 48 | 49 | pipe.transformer.to(torch.bfloat16) 50 | 51 | # Add transform 52 | transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize([0.5], [0.5]) # For RGB images 55 | ]) 56 | mask_transform = transforms.Compose([ 57 | transforms.ToTensor() 58 | ]) 59 | 60 | # Load and process images 61 | # print("image_path", image_path) 62 | image = load_image(image_path).convert("RGB").resize(size) 63 | mask = load_image(mask_path).convert("RGB").resize(size) 64 | garment = load_image(garment_path).convert("RGB").resize(size) 65 | 66 | # Transform images using the new preprocessing 67 | image_tensor = transform(image) 68 | mask_tensor = mask_transform(mask)[:1] # Take only first channel 69 | garment_tensor = transform(garment) 70 | 71 | # Create concatenated images 72 | inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width 73 | garment_mask = torch.zeros_like(mask_tensor) 74 | extended_mask = torch.cat([garment_mask, mask_tensor], dim=2) 75 | 76 | prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " 
\ 77 | f"[IMAGE1] Detailed product shot of a clothing" \ 78 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 79 | 80 | generator = torch.Generator(device="cuda").manual_seed(seed) 81 | 82 | result = pipe( 83 | height=size[1], 84 | width=size[0] * 2, 85 | image=inpaint_image, 86 | mask_image=extended_mask, 87 | num_inference_steps=num_steps, 88 | generator=generator, 89 | max_sequence_length=512, 90 | guidance_scale=guidance_scale, 91 | prompt=prompt, 92 | ).images[0] 93 | 94 | # Split and save results 95 | width = size[0] 96 | garment_result = result.crop((0, 0, width, size[1])) 97 | tryon_result = result.crop((width, 0, width * 2, size[1])) 98 | 99 | return garment_result, tryon_result 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference') 103 | parser.add_argument('--image', required=True, help='Path to the model image') 104 | parser.add_argument('--mask', required=True, help='Path to the agnostic mask') 105 | parser.add_argument('--garment', required=True, help='Path to the garment image') 106 | parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result') 107 | parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result') 108 | parser.add_argument('--steps', type=int, default=50, help='Number of inference steps') 109 | parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale') 110 | parser.add_argument('--seed', type=int, default=0, help='Random seed') 111 | parser.add_argument('--width', type=int, default=576, help='Width') 112 | parser.add_argument('--height', type=int, default=768, help='Height') 113 | 114 | args = parser.parse_args() 115 | 116 | check_min_version("0.30.2") 117 | 118 | garment_result, tryon_result = run_inference( 119 | image_path=args.image, 120 | mask_path=args.mask, 121 | garment_path=args.garment, 122 | num_steps=args.steps, 123 | guidance_scale=args.guidance_scale, 124 | seed=args.seed, 125 | size=(args.width, args.height) 126 | ) 127 | output_tryon_path=args.output_tryon 128 | 129 | tryon_result.save(output_tryon_path) 130 | 131 | print("Successfully saved garment and try-on images") 132 | 133 | if __name__ == "__main__": 134 | main() -------------------------------------------------------------------------------- /image_datasets/cp_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal 3 | 4 | import json 5 | import os 6 | import os.path as osp 7 | import random 8 | from copy import deepcopy 9 | 10 | import cv2 11 | import numpy as np 12 | import pandas as pd 13 | import PIL 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision.transforms as transforms 17 | from torchvision.transforms import ToPILImage 18 | from PIL import Image, ImageDraw 19 | from torch.utils.data import DataLoader 20 | from transformers import CLIPImageProcessor 21 | 22 | 23 | debug_mode=False 24 | 25 | def tensor_to_image(tensor, image_path): 26 | """ 27 | Convert a torch tensor to an image file. 28 | 29 | Args: 30 | - tensor (torch.Tensor): the input tensor. Shape (C, H, W). 31 | - image_path (str): path where the image should be saved. 32 | 33 | Returns: 34 | - None 35 | """ 36 | if debug_mode: 37 | # Check the tensor dimensions. 
If it's a batch, take the first image 38 | if len(tensor.shape) == 4: 39 | tensor = tensor[0] 40 | 41 | # Check for possible normalization and bring the tensor to 0-1 range if necessary 42 | if tensor.min() < 0 or tensor.max() > 1: 43 | tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) 44 | 45 | # Convert tensor to PIL Image 46 | to_pil = ToPILImage() 47 | img = to_pil(tensor) 48 | 49 | # Save the PIL Image 50 | dir_path = os.path.dirname(image_path) 51 | if not os.path.exists(dir_path): 52 | os.makedirs(dir_path) 53 | img.save(image_path) 54 | 55 | class VitonHDTestDataset(data.Dataset): 56 | def __init__( 57 | self, 58 | dataroot_path: str, 59 | phase: Literal["train", "test"], 60 | order: Literal["paired", "unpaired"] = "paired", 61 | size: Tuple[int, int] = (512, 384), 62 | data_list: Optional[str] = None, 63 | ): 64 | super(VitonHDTestDataset, self).__init__() 65 | self.dataroot = dataroot_path 66 | self.phase = phase 67 | self.height = size[0] 68 | self.width = size[1] 69 | self.size = size 70 | # This code defines a transformation pipeline for image processing 71 | self.transform = transforms.Compose( 72 | [ 73 | # Convert the input image to a PyTorch tensor 74 | transforms.ToTensor(), 75 | # Normalize the tensor values to a range of [-1, 1] 76 | # The first [0.5] is the mean, and the second [0.5] is the standard deviation 77 | # This normalization is applied to each color channel 78 | transforms.Normalize([0.5], [0.5]), 79 | ] 80 | ) 81 | self.toTensor = transforms.ToTensor() 82 | 83 | self.order = order 84 | self.toTensor = transforms.ToTensor() 85 | 86 | im_names = [] 87 | c_names = [] 88 | dataroot_names = [] 89 | 90 | 91 | filename = os.path.join(dataroot_path, data_list) 92 | 93 | with open(filename, "r") as f: 94 | for line in f.readlines(): 95 | if phase == "train": 96 | im_name, _ = line.strip().split() 97 | c_name = im_name 98 | else: 99 | if order == "paired": 100 | im_name, _ = line.strip().split() 101 | c_name = im_name 102 | else: 103 | im_name, c_name = line.strip().split() 104 | 105 | im_names.append(im_name) 106 | c_names.append(c_name) 107 | dataroot_names.append(dataroot_path) 108 | 109 | self.im_names = im_names 110 | self.c_names = c_names 111 | self.dataroot_names = dataroot_names 112 | def __getitem__(self, index): 113 | c_name = self.c_names[index] 114 | im_name = self.im_names[index] 115 | 116 | 117 | cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name)).resize((self.width,self.height)) 118 | cloth_pure = self.transform(cloth) 119 | cloth_mask = Image.open(os.path.join(self.dataroot, self.phase, "cloth-mask", c_name)).resize((self.width,self.height)) 120 | cloth_mask = self.transform(cloth_mask) 121 | 122 | im_pil_big = Image.open( 123 | os.path.join(self.dataroot, self.phase, "image", im_name) 124 | ).resize((self.width,self.height)) 125 | image = self.transform(im_pil_big) 126 | 127 | mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height)) 128 | mask = self.toTensor(mask) 129 | mask = mask[:1] 130 | mask = 1-mask 131 | im_mask = image * mask 132 | 133 | pose_img = Image.open( 134 | os.path.join(self.dataroot, self.phase, "image-densepose", im_name) 135 | ).resize((self.width,self.height)) 136 | pose_img = self.transform(pose_img) # [-1,1] 137 | 138 | result = {} 139 | result["c_name"] = c_name 140 | result["im_name"] = im_name 141 | result["cloth_pure"] = cloth_pure 142 | result["cloth_mask"] = cloth_mask 143 | 144 | # Concatenate 
image and garment along width dimension 145 | inpaint_image = torch.cat([cloth_pure, im_mask], dim=2) # dim=2 is width dimension 146 | result["im_mask"] = inpaint_image 147 | 148 | GT_image = torch.cat([cloth_pure, image], dim=2) # dim=2 is width dimension 149 | result["image"] = GT_image 150 | 151 | # Create extended black mask for garment portion 152 | garment_mask = torch.zeros_like(1-mask) # Create mask of same size as original 153 | extended_mask = torch.cat([garment_mask, 1-mask], dim=2) # Concatenate masks 154 | result["inpaint_mask"] = extended_mask 155 | 156 | return result 157 | 158 | def __len__(self): 159 | # model images + cloth image 160 | return len(self.im_names) 161 | 162 | 163 | if __name__ == "__main__": 164 | dataset = CPDataset("/data/user/gjh/VITON-HD", 512, mode="train", unpaired=False) 165 | loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4) 166 | for data in loader: 167 | pass -------------------------------------------------------------------------------- /app_lora.py: -------------------------------------------------------------------------------- 1 | import spaces 2 | 3 | import gradio as gr 4 | from tryon_inference import run_inference 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | import tempfile 9 | import torch 10 | from diffusers import FluxTransformer2DModel, FluxFillPipeline 11 | import subprocess 12 | 13 | subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True) 14 | dtype = torch.bfloat16 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | print("Start loading LoRA weights") 18 | state_dict, network_alphas = FluxFillPipeline.lora_state_dict( 19 | pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", ## The tryon Lora weights 20 | weight_name="pytorch_lora_weights.safetensors", 21 | return_alphas=True 22 | ) 23 | is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) 24 | if not is_correct_format: 25 | raise ValueError("Invalid LoRA checkpoint.") 26 | print('Loading diffusion model ...') 27 | pipe = FluxFillPipeline.from_pretrained( 28 | "black-forest-labs/FLUX.1-Fill-dev", 29 | torch_dtype=torch.bfloat16 30 | ).to(device) 31 | FluxFillPipeline.load_lora_into_transformer( 32 | state_dict=state_dict, 33 | network_alphas=network_alphas, 34 | transformer=pipe.transformer, 35 | ) 36 | 37 | print('Loading Finished!') 38 | 39 | @spaces.GPU 40 | def gradio_inference( 41 | image_data, 42 | garment, 43 | num_steps=50, 44 | guidance_scale=30.0, 45 | seed=-1, 46 | width=768, 47 | height=1024 48 | ): 49 | """Wrapper function for Gradio interface""" 50 | # Use temporary directory 51 | with tempfile.TemporaryDirectory() as tmp_dir: 52 | # Save inputs to temp directory 53 | temp_image = os.path.join(tmp_dir, "image.png") 54 | temp_mask = os.path.join(tmp_dir, "mask.png") 55 | temp_garment = os.path.join(tmp_dir, "garment.png") 56 | 57 | # Extract image and mask from ImageEditor data 58 | image = image_data["background"] 59 | mask = image_data["layers"][0] # First layer contains the mask 60 | 61 | # Convert to numpy array and process mask 62 | mask_array = np.array(mask) 63 | is_black = np.all(mask_array < 10, axis=2) 64 | mask = Image.fromarray(((~is_black) * 255).astype(np.uint8)) 65 | 66 | # Save files to temp directory 67 | image.save(temp_image) 68 | mask.save(temp_mask) 69 | garment.save(temp_garment) 70 | 71 | try: 72 | # Run inference 73 | _, tryon_result = run_inference( 74 | pipe=pipe, 75 | image_path=temp_image, 76 | 
mask_path=temp_mask, 77 | garment_path=temp_garment, 78 | num_steps=num_steps, 79 | guidance_scale=guidance_scale, 80 | seed=seed, 81 | size=(width, height) 82 | ) 83 | return tryon_result 84 | except Exception as e: 85 | raise gr.Error(f"Error during inference: {str(e)}") 86 | 87 | with gr.Blocks() as demo: 88 | gr.Markdown(""" 89 | # CATVTON FLUX Virtual Try-On Demo (by using LoRA weights) 90 | Upload a model image, draw a mask, and a garment image to generate virtual try-on results. 91 | 92 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha) 93 | [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux) 94 | """) 95 | 96 | # gr.Video("example/github.mp4", label="Demo Video: How to use the tool") 97 | 98 | with gr.Column(): 99 | with gr.Row(): 100 | with gr.Column(): 101 | image_input = gr.ImageMask( 102 | label="Model Image (Click 'Edit' and draw mask over the clothing area)", 103 | type="pil", 104 | height=600, 105 | width=300 106 | ) 107 | gr.Examples( 108 | examples=[ 109 | ["./example/person/00008_00.jpg"], 110 | ["./example/person/00055_00.jpg"], 111 | ["./example/person/00057_00.jpg"], 112 | ["./example/person/00067_00.jpg"], 113 | ["./example/person/00069_00.jpg"], 114 | ], 115 | inputs=[image_input], 116 | label="Person Images", 117 | ) 118 | with gr.Column(): 119 | garment_input = gr.Image(label="Garment Image", type="pil", height=600, width=300) 120 | gr.Examples( 121 | examples=[ 122 | ["./example/garment/04564_00.jpg"], 123 | ["./example/garment/00055_00.jpg"], 124 | ["./example/garment/00396_00.jpg"], 125 | ["./example/garment/00067_00.jpg"], 126 | ["./example/garment/00069_00.jpg"], 127 | ], 128 | inputs=[garment_input], 129 | label="Garment Images", 130 | ) 131 | with gr.Column(): 132 | tryon_output = gr.Image(label="Try-On Result", height=600, width=300) 133 | 134 | with gr.Row(): 135 | num_steps = gr.Slider( 136 | minimum=1, 137 | maximum=100, 138 | value=30, 139 | step=1, 140 | label="Number of Steps" 141 | ) 142 | guidance_scale = gr.Slider( 143 | minimum=1.0, 144 | maximum=50.0, 145 | value=30.0, 146 | step=0.5, 147 | label="Guidance Scale" 148 | ) 149 | seed = gr.Slider( 150 | minimum=-1, 151 | maximum=2147483647, 152 | step=1, 153 | value=-1, 154 | label="Seed (-1 for random)" 155 | ) 156 | width = gr.Slider( 157 | minimum=256, 158 | maximum=1024, 159 | step=64, 160 | value=768, 161 | label="Width" 162 | ) 163 | height = gr.Slider( 164 | minimum=256, 165 | maximum=1024, 166 | step=64, 167 | value=1024, 168 | label="Height" 169 | ) 170 | 171 | 172 | submit_btn = gr.Button("Generate Try-On", variant="primary") 173 | 174 | 175 | with gr.Row(): 176 | gr.Markdown(""" 177 | ### Notes: 178 | - The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation. 179 | - The mask should indicate the region where the garment will be placed. 180 | - The garment image should be on a clean background. 181 | - The model is not perfect. It may generate some artifacts. 182 | - The model is slow. Please be patient. 183 | - The model is just for research purpose. 
184 | """) 185 | 186 | submit_btn.click( 187 | fn=gradio_inference, 188 | inputs=[ 189 | image_input, 190 | garment_input, 191 | num_steps, 192 | guidance_scale, 193 | seed, 194 | width, 195 | height 196 | ], 197 | outputs=[tryon_output], 198 | api_name="try-on" 199 | ) 200 | 201 | 202 | demo.launch() 203 | -------------------------------------------------------------------------------- /app_no_lora.py: -------------------------------------------------------------------------------- 1 | import spaces 2 | 3 | import gradio as gr 4 | from tryon_inference import run_inference 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | import tempfile 9 | import torch 10 | from diffusers import FluxTransformer2DModel, FluxFillPipeline 11 | 12 | import shutil 13 | 14 | def find_cuda(): 15 | # Check if CUDA_HOME or CUDA_PATH environment variables are set 16 | cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') 17 | 18 | if cuda_home and os.path.exists(cuda_home): 19 | return cuda_home 20 | 21 | # Search for the nvcc executable in the system's PATH 22 | nvcc_path = shutil.which('nvcc') 23 | 24 | if nvcc_path: 25 | # Remove the 'bin/nvcc' part to get the CUDA installation path 26 | cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) 27 | return cuda_path 28 | 29 | return None 30 | 31 | cuda_path = find_cuda() 32 | 33 | if cuda_path: 34 | print(f"CUDA installation found at: {cuda_path}") 35 | else: 36 | print("CUDA installation not found") 37 | 38 | device = torch.device('cuda') 39 | 40 | print('Loading diffusion model ...') 41 | transformer = FluxTransformer2DModel.from_pretrained( 42 | "xiaozaa/catvton-flux-alpha", 43 | torch_dtype=torch.bfloat16 44 | ) 45 | pipe = FluxFillPipeline.from_pretrained( 46 | "black-forest-labs/FLUX.1-dev", 47 | transformer=transformer, 48 | torch_dtype=torch.bfloat16 49 | ).to(device) 50 | print('Loading Finished!') 51 | 52 | @spaces.GPU 53 | def gradio_inference( 54 | image_data, 55 | garment, 56 | num_steps=50, 57 | guidance_scale=30.0, 58 | seed=-1, 59 | width=768, 60 | height=1024 61 | ): 62 | """Wrapper function for Gradio interface""" 63 | # Use temporary directory 64 | with tempfile.TemporaryDirectory() as tmp_dir: 65 | # Save inputs to temp directory 66 | temp_image = os.path.join(tmp_dir, "image.png") 67 | temp_mask = os.path.join(tmp_dir, "mask.png") 68 | temp_garment = os.path.join(tmp_dir, "garment.png") 69 | 70 | # Extract image and mask from ImageEditor data 71 | image = image_data["background"] 72 | mask = image_data["layers"][0] # First layer contains the mask 73 | 74 | # Convert to numpy array and process mask 75 | mask_array = np.array(mask) 76 | is_black = np.all(mask_array < 10, axis=2) 77 | mask = Image.fromarray(((~is_black) * 255).astype(np.uint8)) 78 | 79 | # Save files to temp directory 80 | image.save(temp_image) 81 | mask.save(temp_mask) 82 | garment.save(temp_garment) 83 | 84 | try: 85 | # Run inference 86 | _, tryon_result = run_inference( 87 | pipe=pipe, 88 | image_path=temp_image, 89 | mask_path=temp_mask, 90 | garment_path=temp_garment, 91 | num_steps=num_steps, 92 | guidance_scale=guidance_scale, 93 | seed=seed, 94 | size=(width, height) 95 | ) 96 | return tryon_result 97 | except Exception as e: 98 | raise gr.Error(f"Error during inference: {str(e)}") 99 | 100 | with gr.Blocks() as demo: 101 | gr.Markdown(""" 102 | # CATVTON FLUX Virtual Try-On Demo 103 | Upload a model image, draw a mask, and a garment image to generate virtual try-on results. 
104 | 105 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha) 106 | [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux) 107 | """) 108 | 109 | # gr.Video("example/github.mp4", label="Demo Video: How to use the tool") 110 | 111 | with gr.Column(): 112 | with gr.Row(): 113 | with gr.Column(): 114 | image_input = gr.ImageMask( 115 | label="Model Image (Click 'Edit' and draw mask over the clothing area)", 116 | type="pil", 117 | height=600, 118 | width=300 119 | ) 120 | gr.Examples( 121 | examples=[ 122 | ["./example/person/00008_00.jpg"], 123 | ["./example/person/00055_00.jpg"], 124 | ["./example/person/00057_00.jpg"], 125 | ["./example/person/00067_00.jpg"], 126 | ["./example/person/00069_00.jpg"], 127 | ], 128 | inputs=[image_input], 129 | label="Person Images", 130 | ) 131 | with gr.Column(): 132 | garment_input = gr.Image(label="Garment Image", type="pil", height=600, width=300) 133 | gr.Examples( 134 | examples=[ 135 | ["./example/garment/04564_00.jpg"], 136 | ["./example/garment/00055_00.jpg"], 137 | ["./example/garment/00396_00.jpg"], 138 | ["./example/garment/00067_00.jpg"], 139 | ["./example/garment/00069_00.jpg"], 140 | ], 141 | inputs=[garment_input], 142 | label="Garment Images", 143 | ) 144 | with gr.Column(): 145 | tryon_output = gr.Image(label="Try-On Result", height=600, width=300) 146 | 147 | with gr.Row(): 148 | num_steps = gr.Slider( 149 | minimum=1, 150 | maximum=100, 151 | value=30, 152 | step=1, 153 | label="Number of Steps" 154 | ) 155 | guidance_scale = gr.Slider( 156 | minimum=1.0, 157 | maximum=50.0, 158 | value=30.0, 159 | step=0.5, 160 | label="Guidance Scale" 161 | ) 162 | seed = gr.Slider( 163 | minimum=-1, 164 | maximum=2147483647, 165 | step=1, 166 | value=-1, 167 | label="Seed (-1 for random)" 168 | ) 169 | width = gr.Slider( 170 | minimum=256, 171 | maximum=1024, 172 | step=64, 173 | value=768, 174 | label="Width" 175 | ) 176 | height = gr.Slider( 177 | minimum=256, 178 | maximum=1024, 179 | step=64, 180 | value=1024, 181 | label="Height" 182 | ) 183 | 184 | 185 | submit_btn = gr.Button("Generate Try-On", variant="primary") 186 | 187 | 188 | with gr.Row(): 189 | gr.Markdown(""" 190 | ### Notes: 191 | - The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation. 192 | - The mask should indicate the region where the garment will be placed. 193 | - The garment image should be on a clean background. 194 | - The model is not perfect. It may generate some artifacts. 195 | - The model is slow. Please be patient. 196 | - The model is just for research purpose. 
197 | """) 198 | 199 | submit_btn.click( 200 | fn=gradio_inference, 201 | inputs=[ 202 | image_input, 203 | garment_input, 204 | num_steps, 205 | guidance_scale, 206 | seed, 207 | width, 208 | height 209 | ], 210 | outputs=[tryon_output], 211 | api_name="try-on" 212 | ) 213 | 214 | 215 | demo.launch() -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import spaces 2 | 3 | import gradio as gr 4 | from tryon_inference import run_inference 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | import tempfile 9 | import torch 10 | from diffusers import FluxTransformer2DModel, FluxFillPipeline 11 | import subprocess 12 | 13 | subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True) 14 | dtype = torch.bfloat16 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | print('Loading diffusion model ...') 18 | transformer = FluxTransformer2DModel.from_pretrained( 19 | "xiaozaa/catvton-flux-alpha", 20 | torch_dtype=dtype 21 | ) 22 | pipe = FluxFillPipeline.from_pretrained( 23 | "black-forest-labs/FLUX.1-dev", 24 | transformer=transformer, 25 | torch_dtype=dtype 26 | ).to(device) 27 | print('Loading Finished!') 28 | 29 | @spaces.GPU(duration=120) 30 | def gradio_inference( 31 | image_data, 32 | garment, 33 | num_steps=50, 34 | guidance_scale=30.0, 35 | seed=-1, 36 | width=768, 37 | height=1024 38 | ): 39 | """Wrapper function for Gradio interface""" 40 | # Check if mask has been drawn 41 | if image_data is None or "layers" not in image_data or not image_data["layers"]: 42 | raise gr.Error("Please draw a mask over the clothing area before generating!") 43 | 44 | # Check if mask is empty (all black) 45 | mask = image_data["layers"][0] 46 | mask_array = np.array(mask) 47 | if np.all(mask_array < 10): 48 | raise gr.Error("The mask is empty! Please draw over the clothing area you want to replace.") 49 | 50 | # Use temporary directory 51 | with tempfile.TemporaryDirectory() as tmp_dir: 52 | # Save inputs to temp directory 53 | temp_image = os.path.join(tmp_dir, "image.png") 54 | temp_mask = os.path.join(tmp_dir, "mask.png") 55 | temp_garment = os.path.join(tmp_dir, "garment.png") 56 | 57 | # Extract image and mask from ImageEditor data 58 | image = image_data["background"] 59 | mask = image_data["layers"][0] # First layer contains the mask 60 | 61 | # Convert to numpy array and process mask 62 | mask_array = np.array(mask) 63 | is_black = np.all(mask_array < 10, axis=2) 64 | mask = Image.fromarray(((~is_black) * 255).astype(np.uint8)) 65 | 66 | # Save files to temp directory 67 | image.save(temp_image) 68 | mask.save(temp_mask) 69 | garment.save(temp_garment) 70 | 71 | try: 72 | # Run inference 73 | _, tryon_result = run_inference( 74 | pipe=pipe, 75 | image_path=temp_image, 76 | mask_path=temp_mask, 77 | garment_path=temp_garment, 78 | num_steps=num_steps, 79 | guidance_scale=guidance_scale, 80 | seed=seed, 81 | size=(width, height) 82 | ) 83 | return tryon_result 84 | except Exception as e: 85 | raise gr.Error(f"Error during inference: {str(e)}") 86 | 87 | with gr.Blocks() as demo: 88 | gr.Markdown(""" 89 | # CATVTON FLUX Virtual Try-On Demo 90 | Upload a model image, draw a mask, and a garment image to generate virtual try-on results. 
91 | 92 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha) 93 | [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux) 94 | """) 95 | 96 | # gr.Video("example/github.mp4", label="Demo Video: How to use the tool") 97 | 98 | with gr.Column(): 99 | gr.Markdown(""" 100 | ### ⚠️ Important: 101 | 1. Choose a model image or upload your own 102 | 2. Use the Pen tool to draw a mask over the clothing area you want to replace 103 | 3. Choose a garment image or upload your own 104 | """) 105 | 106 | with gr.Row(): 107 | with gr.Column(): 108 | image_input = gr.ImageMask( 109 | label="Model Image (Click 'Edit' and draw mask over the clothing area)", 110 | type="pil", 111 | height=600, 112 | width=300 113 | ) 114 | gr.Examples( 115 | examples=[ 116 | ["./example/person/00008_00.jpg"], 117 | ["./example/person/00055_00.jpg"], 118 | ["./example/person/00057_00.jpg"], 119 | ["./example/person/00067_00.jpg"], 120 | ["./example/person/00069_00.jpg"], 121 | ], 122 | inputs=[image_input], 123 | label="Person Images", 124 | ) 125 | with gr.Column(): 126 | garment_input = gr.Image(label="Garment Image", type="pil", height=600, width=300) 127 | gr.Examples( 128 | examples=[ 129 | ["./example/garment/04564_00.jpg"], 130 | ["./example/garment/00055_00.jpg"], 131 | ["./example/garment/00396_00.jpg"], 132 | ["./example/garment/00067_00.jpg"], 133 | ["./example/garment/00069_00.jpg"], 134 | ], 135 | inputs=[garment_input], 136 | label="Garment Images", 137 | ) 138 | with gr.Column(): 139 | tryon_output = gr.Image(label="Try-On Result", height=600, width=300) 140 | 141 | with gr.Row(): 142 | num_steps = gr.Slider( 143 | minimum=1, 144 | maximum=100, 145 | value=30, 146 | step=1, 147 | label="Number of Steps" 148 | ) 149 | guidance_scale = gr.Slider( 150 | minimum=1.0, 151 | maximum=50.0, 152 | value=30.0, 153 | step=0.5, 154 | label="Guidance Scale" 155 | ) 156 | seed = gr.Slider( 157 | minimum=-1, 158 | maximum=2147483647, 159 | step=1, 160 | value=-1, 161 | label="Seed (-1 for random)" 162 | ) 163 | width = gr.Slider( 164 | minimum=256, 165 | maximum=1024, 166 | step=64, 167 | value=768, 168 | label="Width" 169 | ) 170 | height = gr.Slider( 171 | minimum=256, 172 | maximum=1024, 173 | step=64, 174 | value=1024, 175 | label="Height" 176 | ) 177 | 178 | 179 | submit_btn = gr.Button("Generate Try-On", variant="primary") 180 | 181 | 182 | with gr.Row(): 183 | gr.Markdown(""" 184 | ### Notes: 185 | - The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation. 186 | - The mask should indicate the region where the garment will be placed. 187 | - The garment image should be on a clean background. 188 | - The model is not perfect. It may generate some artifacts. 189 | - The model is slow. Please be patient. 190 | - The model is just for research purpose. 
191 | """) 192 | 193 | submit_btn.click( 194 | fn=gradio_inference, 195 | inputs=[ 196 | image_input, 197 | garment_input, 198 | num_steps, 199 | guidance_scale, 200 | seed, 201 | width, 202 | height 203 | ], 204 | outputs=[tryon_output], 205 | api_name="try-on" 206 | ) 207 | 208 | 209 | demo.launch() 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: catvton-flux 3 | emoji: 🖥️ 4 | colorFrom: yellow 5 | colorTo: pink 6 | sdk: gradio 7 | sdk_version: 5.0.1 8 | app_file: app.py 9 | pinned: false 10 | --- 11 | 12 | 13 | # catvton-flux 14 | 15 | An state-of-the-art virtual try-on solution that combines the power of [CATVTON](https://arxiv.org/abs/2407.15886) (CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models) with Flux fill inpainting model for realistic and accurate clothing transfer. 16 | Also inspired by [In-Context LoRA](https://arxiv.org/abs/2410.23775) for prompt engineering. 17 | 18 | Running it now on website: [CATVTON-FLUX-TRY-ON](https://huggingface.co/spaces/xiaozaa/catvton-flux-try-on) 19 | 20 | ## Update 21 | 22 | --- 23 | **Latest Achievement** 24 | 25 | (2025/1/26): 26 | - Released the training code for Flux inpainting full para fine-tuning. It requires 2xH100 80GB. 27 | 28 | (2025/1/16): 29 | - Released a new version of tryon model [CATVTON-FLUX-BETA](https://huggingface.co/xiaozaa/catvton-flux-beta). This model can handle all kinds of garments. 30 | - Released a training notes. This is a summary of what we found when doing the training. [Here](https://github.com/nftblackmagic/catvton-flux/blob/main/TrainingNotes.md) 31 | 32 | (2024/12/6): 33 | - Released a new weights for tryoff. The model named [cat-tryoff-flux](https://huggingface.co/xiaozaa/cat-tryoff-flux) can extract and reconstruct the front view of clothing items from images of people wearing them. [Showcase examples](#try-off-examples) is here. 34 | - Try-off Hugging Face: 🤗 [CAT-TRYOFF-FLUX](https://huggingface.co/spaces/xiaozaa/cat-try-off-flux) 35 | 36 | (2024/12/1): 37 | - Community comfyui support [here](https://github.com/lujiazho/ComfyUI-CatvtonFluxWrapper). Thanks to [lujiazho](https://github.com/lujiazho) 38 | 39 | (2024/11/26): 40 | - Updated the weights. (Still training on the VITON-HD dataset only.) 41 | - Reduce the fine-tuning weights size (46GB -> 23GB) 42 | - Weights has better performance on garment small details/text. 43 | - Added the huggingface ZeroGPU support. You can run **CATVTON-FLUX-TRY-ON** now on huggingface space [here](https://huggingface.co/spaces/xiaozaa/catvton-flux-try-on) 44 | 45 | (2024/11/25): 46 | - Released lora weights. Lora weights achieved FID: `6.0675811767578125` on VITON-HD dataset. Test configuration: scale 30, step 30. 47 | - Revise gradio demo. Added huggingface spaces support. 48 | - Clean up the requirements.txt. 49 | 50 | (2024/11/24): 51 | - Released FID score and gradio demo 52 | - CatVton-Flux-Alpha achieved **SOTA** performance with FID: `5.593255043029785` on VITON-HD dataset. Test configuration: scale 30, step 30. 
My VITON-HD test inference results are available [here](https://drive.google.com/file/d/1T2W5R1xH_uszGVD8p6UUAtWyx43rxGmI/view?usp=sharing) 53 | 54 | --- 55 | 56 | ## Showcase 57 | 58 | ### Try-on examples 59 | | Original | Garment | Result | 60 | |----------|---------|---------| 61 | | ![Original](example/person/1.jpg) | ![Garment](example/garment/00035_00.jpg) | ![Result](example/result/1.png) | 62 | | ![Original](example/person/1.jpg) | ![Garment](example/garment/04564_00.jpg) | ![Result](example/result/2.png) | 63 | | ![Original](example/person/00008_00.jpg) | ![Garment](example/garment/00034_00.jpg) | ![Result](example/result/3.png) | 64 | 65 | ### Try-off examples 66 | | Original clothed model | Restored garment result | 67 | |------------------------|------------------------| 68 | | ![Original](example/person/00055_00.jpg) | ![Restored garment result](example/tryoff_result/restored_garment2.png) | 69 | | ![Original](example/person/00064_00.jpg) | ![Restored garment result](example/tryoff_result/restored_garment4.png) | 70 | | ![Original](example/person/00069_00.jpg) | ![Restored garment result](example/tryoff_result/restored_garment6.png) | 71 | 72 | 73 | ## Model Weights 74 | ### Tryon 75 | Fine-tuning weights on Hugging Face: 🤗 [catvton-flux-alpha](https://huggingface.co/xiaozaa/catvton-flux-alpha) 76 | 77 | LoRA weights on Hugging Face: 🤗 [catvton-flux-lora-alpha](https://huggingface.co/xiaozaa/catvton-flux-lora-alpha) 78 | 79 | ### Tryoff 80 | Fine-tuning weights on Hugging Face: 🤗 [cat-tryoff-flux](https://huggingface.co/xiaozaa/cat-tryoff-flux) 81 | 82 | ### Dataset 83 | The model weights are trained on the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset. 84 | 85 | ## Prerequisites 86 | Make sure you run the code with at least 40GB of VRAM. (I ran all my experiments on an 80GB GPU; lower VRAM will cause OOM errors. Lower-VRAM support is planned.) 87 | 88 | ```bash 89 | 90 | conda create -n flux python=3.10 91 | conda activate flux 92 | pip install -r requirements.txt 93 | huggingface-cli login 94 | ``` 95 | 96 | ## Usage 97 | 98 | ### Training steps 99 | #### Prepare the dataset 100 | You can download the VITON-HD dataset from [VITON-HD](https://github.com/shadow2496/VITON-HD). 101 | The structure of the dataset directory should be as follows: 102 | 103 | 104 | 105 | ``` 106 | train 107 | |-- ... 108 | 109 | test 110 | |-- image 111 | |-- image-densepose 112 | |-- agnostic-mask 113 | |-- cloth 114 | ``` 115 | 116 | #### Train the model 117 | Run the following command to train the model (make sure you have 2x H100 80GB): 118 | ```bash 119 | bash train_flux_inpaint.sh 120 | ``` 121 | Adjust the dataset path and data-list txt file in the script to match your setup; a sketch of the arguments such a script might pass is shown below.
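For orientation, here is a minimal, hypothetical sketch of what the launch command inside `train_flux_inpaint.sh` might look like, assuming the script runs `train_flux_inpaint.py` through `accelerate` with `accelerate_config.yaml` and the arguments defined in `paser_helper.py`. All paths and hyperparameter values below are placeholders rather than the shipped defaults, and note that the argument parser also requires either `--dataset_name` or `--instance_data_dir`; check the actual script in the repository for the exact flags and values before training.

```bash
# Hypothetical sketch only -- adapt paths and values to your setup and
# compare against the shipped train_flux_inpaint.sh before running.
accelerate launch --config_file accelerate_config.yaml train_flux_inpaint.py \
  --pretrained_model_name_or_path black-forest-labs/FLUX.1-dev \
  --pretrained_inpaint_model_name_or_path xiaozaa/flux1-fill-dev-diffusers \
  --instance_data_dir /path/to/VITON-HD \
  --dataroot /path/to/VITON-HD \
  --train_data_list /path/to/VITON-HD/train_pairs.txt \
  --output_dir ./checkpoints/flux-inpaint-tryon \
  --width 768 \
  --height 1024 \
  --train_batch_size 1 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --mixed_precision bf16 \
  --learning_rate 1e-4 \
  --max_train_steps 10000 \
  --checkpointing_steps 500 \
  --seed 42
```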
122 | 123 | ### Tryoff inference 124 | Run the following command to restore the front side of the garment from the clothed model image: 125 | ```bash 126 | python tryoff_inference.py \ 127 | --image ./example/person/00069_00.jpg \ 128 | --mask ./example/person/00069_00_mask.png \ 129 | --seed 41 \ 130 | --output_tryon test_original.png \ 131 | --output_garment restored_garment6.png \ 132 | --steps 30 133 | ``` 134 | 135 | ### Tryon inference 136 | Run the following command to try on an image: 137 | 138 | LoRA version: 139 | ```bash 140 | python tryon_inference_lora.py \ 141 | --image ./example/person/00008_00.jpg \ 142 | --mask ./example/person/00008_00_mask.png \ 143 | --garment ./example/garment/00034_00.jpg \ 144 | --seed 4096 \ 145 | --output_tryon test_lora.png \ 146 | --steps 30 147 | ``` 148 | 149 | Fine-tuning version: 150 | ```bash 151 | python tryon_inference.py \ 152 | --image ./example/person/00008_00.jpg \ 153 | --mask ./example/person/00008_00_mask.png \ 154 | --garment ./example/garment/00034_00.jpg \ 155 | --seed 42 \ 156 | --output_tryon test.png \ 157 | --steps 30 158 | ``` 159 | 160 | Run the following command to start a gradio demo with LoRA weights: 161 | ```bash 162 | python app_lora.py 163 | ``` 164 | 165 | Run the following command to start a gradio demo without LoRA weights: 166 | ```bash 167 | python app_no_lora.py 168 | ``` 169 | 170 | Gradio demo: 171 | Try-on Hugging Face: 🤗 [CATVTON-FLUX-TRY-ON](https://huggingface.co/spaces/xiaozaa/catvton-flux-try-on) 172 | Try-off Hugging Face: 🤗 [CAT-TRYOFF-FLUX](https://huggingface.co/spaces/xiaozaa/cat-try-off-flux) 173 | 174 | [![Demo](example/github.jpg)](https://upcdn.io/FW25b7k/raw/uploads/github.mp4) 175 | 176 | 177 | ## TODO: 178 | - [x] Release the FID score 179 | - [x] Add Gradio demo 180 | - [x] Release updated weights with better performance 181 | - [x] Train a smaller model 182 | - [x] Support ComfyUI 183 | - [x] Release try-off weights 184 | - [x] Release Flux inpainting full-parameter fine-tuning training code 185 | ## Citation 186 | 187 | ```bibtex 188 | @misc{chong2024catvtonconcatenationneedvirtual, 189 | title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models}, 190 | author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang}, 191 | year={2024}, 192 | eprint={2407.15886}, 193 | archivePrefix={arXiv}, 194 | primaryClass={cs.CV}, 195 | url={https://arxiv.org/abs/2407.15886}, 196 | } 197 | @article{lhhuang2024iclora, 198 | title={In-Context LoRA for Diffusion Transformers}, 199 | author={Huang, Lianghua and Wang, Wei and Wu, Zhi-Fan and Shi, Yupeng and Dou, Huanzhang and Liang, Chen and Feng, Yutong and Liu, Yu and Zhou, Jingren}, 200 | journal={arXiv preprint arxiv:2410.23775}, 201 | year={2024} 202 | } 203 | ``` 204 | 205 | Thanks to [Jim](https://github.com/nom) for insisting on spatial concatenation. 206 | Thanks to [dingkang](https://github.com/dingkwang), [MoonBlvd](https://github.com/MoonBlvd), and [Stevada](https://github.com/Stevada) for the helpful discussions. 207 | 208 | ## License 209 | - The code is licensed under the MIT License. 210 | - The model weights have the same license as Flux.1 Fill and VITON-HD.
211 | -------------------------------------------------------------------------------- /src/flux/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | def prepare_image_with_mask( 5 | image_processor, 6 | mask_processor, 7 | vae, 8 | vae_scale_factor, 9 | image, 10 | mask, 11 | width, 12 | height, 13 | batch_size, 14 | num_images_per_prompt, 15 | device, 16 | dtype, 17 | is_cloth=False, 18 | ): 19 | # Prepare image 20 | if isinstance(image, torch.Tensor): 21 | pass 22 | else: 23 | image = image_processor.preprocess(image, height=height, width=width) 24 | 25 | # print("image.shape", image.shape) 26 | image_batch_size = image.shape[0] 27 | if image_batch_size == 1: 28 | repeat_by = batch_size 29 | else: 30 | # image batch size is the same as prompt batch size 31 | repeat_by = num_images_per_prompt 32 | image = image.repeat_interleave(repeat_by, dim=0) 33 | image = image.to(device=device, dtype=dtype) 34 | 35 | # Prepare mask 36 | if isinstance(mask, torch.Tensor): 37 | pass 38 | else: 39 | mask = mask_processor.preprocess(mask, height=height, width=width) 40 | mask = mask.repeat_interleave(repeat_by, dim=0) 41 | mask = mask.to(device=device, dtype=dtype) 42 | 43 | # Get masked image 44 | masked_image = image.clone() 45 | masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 46 | 47 | # Encode to latents 48 | image_latents = vae.encode(masked_image.to(vae.dtype)).latent_dist.sample() 49 | image_latents = ( 50 | image_latents - vae.config.shift_factor 51 | ) * vae.config.scaling_factor 52 | image_latents = image_latents.to(dtype) 53 | 54 | # print("image_latents.shape", image_latents.shape) 55 | mask = torch.nn.functional.interpolate( 56 | mask, size=(height // vae_scale_factor * 2, width // vae_scale_factor * 2) 57 | ) 58 | if is_cloth: 59 | mask = mask 60 | else: 61 | mask = 1 - mask 62 | 63 | control_image = torch.cat([image_latents, mask], dim=1) 64 | 65 | # Pack cond latents 66 | packed_control_image = pack_latents( 67 | control_image, 68 | batch_size * num_images_per_prompt, 69 | control_image.shape[1], 70 | control_image.shape[2], 71 | control_image.shape[3], 72 | ) 73 | 74 | return packed_control_image, height, width 75 | 76 | def prepare_fill_with_mask( 77 | image_processor, 78 | mask_processor, 79 | vae, 80 | vae_scale_factor, 81 | image, 82 | mask, 83 | width, 84 | height, 85 | batch_size, 86 | num_images_per_prompt, 87 | device, 88 | dtype, 89 | ): 90 | """ 91 | Prepares image and mask for fill operation with proper rearrangement. 92 | Focuses only on image and mask processing. 
93 | """ 94 | # Determine effective batch size 95 | effective_batch_size = batch_size * num_images_per_prompt 96 | 97 | # Prepare image 98 | if isinstance(image, torch.Tensor): 99 | pass 100 | else: 101 | image = image_processor.preprocess(image, height=height, width=width) 102 | 103 | image_batch_size = image.shape[0] 104 | repeat_by = effective_batch_size if image_batch_size == 1 else num_images_per_prompt 105 | image = image.repeat_interleave(repeat_by, dim=0) 106 | image = image.to(device=device, dtype=dtype) 107 | 108 | # Prepare mask with specific processing 109 | if isinstance(mask, torch.Tensor): 110 | pass 111 | else: 112 | mask = mask_processor.preprocess(mask, height=height, width=width) 113 | 114 | mask = mask.repeat_interleave(repeat_by, dim=0) 115 | mask = mask.to(device=device, dtype=dtype) 116 | 117 | # Apply mask to image 118 | masked_image = image.clone() 119 | masked_image = masked_image * (1 - mask) 120 | 121 | # Encode to latents 122 | image_latents = vae.encode(masked_image.to(vae.dtype)).latent_dist.sample() 123 | image_latents = ( 124 | image_latents - vae.config.shift_factor 125 | ) * vae.config.scaling_factor 126 | image_latents = image_latents.to(dtype) 127 | 128 | # Process mask following the example's specific rearrangement 129 | mask = mask[:, 0, :, :] if mask.shape[1] > 1 else mask[:, 0, :, :] 130 | mask = mask.to(torch.bfloat16) 131 | 132 | # First rearrangement: 8x8 patches 133 | mask = rearrange( 134 | mask, 135 | "b (h ph) (w pw) -> b (ph pw) h w", 136 | ph=8, 137 | pw=8, 138 | ) 139 | 140 | # Second rearrangement: 2x2 patches 141 | mask = rearrange( 142 | mask, 143 | "b c (h ph) (w pw) -> b (h w) (c ph pw)", 144 | ph=2, 145 | pw=2 146 | ) 147 | 148 | # Rearrange image latents similarly 149 | image_latents = rearrange( 150 | image_latents, 151 | "b c (h ph) (w pw) -> b (h w) (c ph pw)", 152 | ph=2, 153 | pw=2 154 | ) 155 | 156 | # Combine image and mask 157 | image_cond = torch.cat([image_latents, mask], dim=-1) 158 | 159 | return image_cond, height, width 160 | 161 | def prepare_image_with_mask_sd3( 162 | image_processor, 163 | mask_processor, 164 | vae, 165 | vae_scale_factor, 166 | image, 167 | mask, 168 | width, 169 | height, 170 | batch_size, 171 | num_images_per_prompt, 172 | device, 173 | dtype, 174 | is_cloth=False, 175 | ): 176 | # Prepare image 177 | if isinstance(image, torch.Tensor): 178 | pass 179 | else: 180 | image = image_processor.preprocess(image, height=height, width=width) 181 | 182 | # print("image.shape", image.shape) 183 | image_batch_size = image.shape[0] 184 | if image_batch_size == 1: 185 | repeat_by = batch_size 186 | else: 187 | # image batch size is the same as prompt batch size 188 | repeat_by = num_images_per_prompt 189 | image = image.repeat_interleave(repeat_by, dim=0) 190 | image = image.to(device=device, dtype=dtype) 191 | 192 | # Prepare mask 193 | if isinstance(mask, torch.Tensor): 194 | pass 195 | else: 196 | mask = mask_processor.preprocess(mask, height=height, width=width) 197 | mask = mask.repeat_interleave(repeat_by, dim=0) 198 | mask = mask.to(device=device, dtype=dtype) 199 | 200 | # Get masked image 201 | masked_image = image.clone() 202 | masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 203 | 204 | # Encode to latents 205 | image_latents = vae.encode(masked_image.to(vae.dtype)).latent_dist.sample() 206 | image_latents = ( 207 | image_latents - vae.config.shift_factor 208 | ) * vae.config.scaling_factor 209 | image_latents = image_latents.to(dtype) 210 | 211 | # print("image_latents.shape", 
image_latents.shape) 212 | mask = torch.nn.functional.interpolate( 213 | mask, size=(height // vae_scale_factor, width // vae_scale_factor) 214 | ) 215 | if is_cloth: 216 | mask = mask 217 | else: 218 | mask = 1 - mask 219 | 220 | control_image = torch.cat([image_latents, mask], dim=1) 221 | 222 | return control_image, height, width 223 | 224 | def prepare_image_for_refnet( 225 | image_processor, 226 | vae, 227 | image, 228 | width, 229 | height, 230 | batch_size, 231 | num_images_per_prompt, 232 | device, 233 | dtype, 234 | ): 235 | # Prepare image 236 | if isinstance(image, torch.Tensor): 237 | pass 238 | else: 239 | image = image_processor.preprocess(image, height=height, width=width) 240 | 241 | # print("image.shape", image.shape) 242 | image_batch_size = image.shape[0] 243 | if image_batch_size == 1: 244 | repeat_by = batch_size 245 | else: 246 | # image batch size is the same as prompt batch size 247 | repeat_by = num_images_per_prompt 248 | image = image.repeat_interleave(repeat_by, dim=0) 249 | image = image.to(device=device, dtype=dtype) 250 | 251 | # Encode to latents 252 | image_latents = vae.encode(image.to(vae.dtype)).latent_dist.sample() 253 | image_latents = ( 254 | image_latents - vae.config.shift_factor 255 | ) * vae.config.scaling_factor 256 | image_latents = image_latents.to(dtype) 257 | 258 | # Pack cond latents 259 | packed_image = pack_latents( 260 | image_latents, 261 | batch_size * num_images_per_prompt, 262 | image_latents.shape[1], 263 | image_latents.shape[2], 264 | image_latents.shape[3], 265 | ) 266 | 267 | return packed_image, height, width 268 | 269 | def prepare_image_for_refnet_sd3( 270 | image_processor, 271 | vae, 272 | image, 273 | width, 274 | height, 275 | batch_size, 276 | num_images_per_prompt, 277 | device, 278 | dtype, 279 | do_classifier_free_guidance=False, 280 | guess_mode=False, 281 | ): 282 | if isinstance(image, torch.Tensor): 283 | pass 284 | else: 285 | image = image_processor.preprocess(image, height=height, width=width) 286 | 287 | image_batch_size = image.shape[0] 288 | 289 | # Prepare image 290 | if image_batch_size == 1: 291 | repeat_by = batch_size 292 | else: 293 | # image batch size is the same as prompt batch size 294 | repeat_by = num_images_per_prompt 295 | 296 | image = image.repeat_interleave(repeat_by, dim=0) 297 | 298 | image = image.to(device=device, dtype=dtype) 299 | 300 | # Encode to latents 301 | # print("masked_image.dtype", masked_image.dtype) 302 | image_latents = vae.encode(image.to(vae.dtype)).latent_dist.sample() 303 | image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor 304 | image_latents = image_latents.to(dtype) 305 | 306 | return image_latents 307 | 308 | 309 | # Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents 310 | def pack_latents(latents, batch_size, num_channels_latents, height, width): 311 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 312 | latents = latents.permute(0, 2, 4, 1, 3, 5) 313 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 314 | 315 | return latents 316 | 317 | 318 | def prepare_latent_image_ids(batch_size, height, width, device, dtype): 319 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 320 | latent_image_ids[..., 1] = ( 321 | latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 322 | ) 323 | latent_image_ids[..., 2] = ( 324 | latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 325 | ) 326 | 327 | ( 328 | 
latent_image_id_height, 329 | latent_image_id_width, 330 | latent_image_id_channels, 331 | ) = latent_image_ids.shape 332 | 333 | latent_image_ids = latent_image_ids.reshape( 334 | latent_image_id_height * latent_image_id_width, 335 | latent_image_id_channels, 336 | ) 337 | 338 | return latent_image_ids.to(device=device, dtype=dtype) 339 | 340 | 341 | def prepare_latents( 342 | vae_scale_factor, 343 | batch_size, 344 | height, 345 | width, 346 | dtype, 347 | device, 348 | ): 349 | height = 2 * (int(height) // (vae_scale_factor * 2)) 350 | width = 2 * (int(width) // (vae_scale_factor * 2)) 351 | 352 | 353 | latent_image_ids = prepare_latent_image_ids( 354 | batch_size, height, width, device, dtype 355 | ) 356 | return latent_image_ids 357 | 358 | def decode_packed_image( 359 | packed_control_image, 360 | vae, 361 | vae_scale_factor, 362 | height, 363 | width, 364 | batch_size, 365 | num_images_per_prompt, 366 | device, 367 | dtype, 368 | ): 369 | # Unpack latents 370 | control_image = unpack_latents( 371 | packed_control_image, 372 | batch_size * num_images_per_prompt, 373 | 5, # 4 channels for image_latents + 1 for mask 374 | height // vae_scale_factor * 2, 375 | width // vae_scale_factor * 2, 376 | ) 377 | 378 | # Split control_image into image_latents and mask 379 | image_latents, mask = torch.split(control_image, [4, 1], dim=1) 380 | 381 | # Decode latents 382 | image_latents = image_latents / vae.config.scaling_factor + vae.config.shift_factor 383 | image = vae.decode(image_latents.to(vae.dtype)).sample 384 | 385 | # Interpolate mask back to original size 386 | mask = torch.nn.functional.interpolate(mask, size=(height, width)) 387 | mask = 1 - mask # Invert mask 388 | 389 | # Apply mask to image 390 | masked_image = image.clone() 391 | masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 392 | 393 | return image, masked_image, mask 394 | 395 | # Helper function to unpack latents 396 | def unpack_latents(packed_latents, batch_size, num_channels, height, width): 397 | unpacked = packed_latents.reshape(batch_size, height // 2, width // 2, num_channels, 2, 2) 398 | unpacked = unpacked.permute(0, 3, 1, 4, 2, 5) 399 | unpacked = unpacked.reshape(batch_size, num_channels, height, width) 400 | return unpacked 401 | 402 | 403 | def get_image_proj( 404 | transformer, 405 | image_prompt: torch.Tensor, 406 | device, 407 | ): 408 | if transformer.auto_processor is not None and transformer.image_encoder is not None and transformer.garment_adapter_improj is not None: 409 | # encode image-prompt embeds 410 | # transformer.image_encoder.to(device=device, dtype=torch.float32) 411 | # print("image_prompt.dtype", image_prompt.dtype) 412 | image_prompt = transformer.clip_image_processor( 413 | images=image_prompt, 414 | return_tensors="pt" 415 | ).pixel_values 416 | 417 | image_prompt = image_prompt.to(device) 418 | image_prompt_embeds = transformer.image_encoder( 419 | image_prompt 420 | ).image_embeds.to( 421 | device=device, dtype=torch.bfloat16, 422 | ) 423 | 424 | # encode image 425 | # print("image_prompt_embeds.shape", image_prompt_embeds.shape) 426 | image_proj = transformer.garment_adapter_improj(image_prompt_embeds) 427 | 428 | return image_proj 429 | else: 430 | print("No image projector found") 431 | return None 432 | 433 | def encode_images_to_latents(vae, pixel_values, weight_dtype, height, width, image_processor=None): 434 | if image_processor is not None: 435 | pixel_values = image_processor.preprocess(pixel_values, height=height, width=width).to(dtype=vae.dtype, device=vae.device) 436 | 
model_input = vae.encode(pixel_values).latent_dist.sample() 437 | model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor 438 | model_input = model_input.to(dtype=weight_dtype) 439 | 440 | return model_input 441 | 442 | 443 | @staticmethod 444 | def _unpack_latents(latents, height, width, vae_scale_factor): 445 | batch_size, num_patches, channels = latents.shape 446 | 447 | height = height // vae_scale_factor 448 | width = width // vae_scale_factor 449 | 450 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 451 | latents = latents.permute(0, 3, 1, 4, 2, 5) 452 | 453 | latents = latents.reshape( 454 | batch_size, channels // (2 * 2), height * 2, width * 2 455 | ) 456 | 457 | return latents -------------------------------------------------------------------------------- /paser_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | 6 | def parse_args(input_args=None): 7 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 8 | parser.add_argument( 9 | "--pretrained_model_name_or_path", 10 | type=str, 11 | default=None, 12 | required=True, 13 | help="Path to pretrained model or model identifier from huggingface.co/models.", 14 | ) 15 | parser.add_argument( 16 | "--revision", 17 | type=str, 18 | default=None, 19 | required=False, 20 | help="Revision of pretrained model identifier from huggingface.co/models.", 21 | ) 22 | parser.add_argument( 23 | "--variant", 24 | type=str, 25 | default=None, 26 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 27 | ) 28 | parser.add_argument( 29 | "--dataset_name", 30 | type=str, 31 | default=None, 32 | help=( 33 | "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," 34 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 35 | " or to a folder containing files that 🤗 Datasets can understand." 36 | ), 37 | ) 38 | parser.add_argument( 39 | "--dataset_config_name", 40 | type=str, 41 | default=None, 42 | help="The config of the Dataset, leave as None if there's only one config.", 43 | ) 44 | parser.add_argument( 45 | "--instance_data_dir", 46 | type=str, 47 | default=None, 48 | help=("A folder containing the training data. "), 49 | ) 50 | 51 | parser.add_argument( 52 | "--cache_dir", 53 | type=str, 54 | default=None, 55 | help="The directory where the downloaded models and datasets will be stored.", 56 | ) 57 | 58 | parser.add_argument( 59 | "--image_column", 60 | type=str, 61 | default="image", 62 | help="The column of the dataset containing the target image. By " 63 | "default, the standard Image Dataset maps out 'file_name' " 64 | "to 'image'.", 65 | ) 66 | parser.add_argument( 67 | "--caption_column", 68 | type=str, 69 | default=None, 70 | help="The column of the dataset containing the instance prompt for each image", 71 | ) 72 | 73 | parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") 74 | 75 | parser.add_argument( 76 | "--class_data_dir", 77 | type=str, 78 | default=None, 79 | required=False, 80 | help="A folder containing the training data of class images.", 81 | ) 82 | parser.add_argument( 83 | "--instance_prompt", 84 | type=str, 85 | default=None, 86 | # required=True, 87 | help="The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'", 88 | ) 89 | parser.add_argument( 90 | "--class_prompt", 91 | type=str, 92 | default=None, 93 | help="The prompt to specify images in the same class as provided instance images.", 94 | ) 95 | parser.add_argument( 96 | "--max_sequence_length", 97 | type=int, 98 | default=77, 99 | help="Maximum sequence length to use with with the T5 text encoder", 100 | ) 101 | parser.add_argument( 102 | "--validation_prompt", 103 | type=str, 104 | default=None, 105 | help="A prompt that is used during validation to verify that the model is learning.", 106 | ) 107 | parser.add_argument( 108 | "--num_validation_images", 109 | type=int, 110 | default=4, 111 | help="Number of images that should be generated during validation with `validation_prompt`.", 112 | ) 113 | parser.add_argument( 114 | "--validation_epochs", 115 | type=int, 116 | default=50, 117 | help=( 118 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" 119 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 120 | ), 121 | ) 122 | parser.add_argument( 123 | "--with_prior_preservation", 124 | default=False, 125 | action="store_true", 126 | help="Flag to add prior preservation loss.", 127 | ) 128 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 129 | parser.add_argument( 130 | "--num_class_images", 131 | type=int, 132 | default=100, 133 | help=( 134 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 135 | " class_data_dir, additional images will be sampled with class_prompt." 136 | ), 137 | ) 138 | parser.add_argument( 139 | "--output_dir", 140 | type=str, 141 | default="flux-dreambooth", 142 | help="The output directory where the model predictions and checkpoints will be written.", 143 | ) 144 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 145 | parser.add_argument( 146 | "--resolution", 147 | type=int, 148 | default=512, 149 | help=( 150 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 151 | " resolution" 152 | ), 153 | ) 154 | parser.add_argument( 155 | "--center_crop", 156 | default=False, 157 | action="store_true", 158 | help=( 159 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 160 | " cropped. The images will be resized to the resolution first before cropping." 161 | ), 162 | ) 163 | parser.add_argument( 164 | "--random_flip", 165 | action="store_true", 166 | help="whether to randomly flip images horizontally", 167 | ) 168 | parser.add_argument( 169 | "--train_text_encoder", 170 | action="store_true", 171 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 172 | ) 173 | parser.add_argument( 174 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 175 | ) 176 | parser.add_argument( 177 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 178 | ) 179 | parser.add_argument("--num_train_epochs", type=int, default=1) 180 | parser.add_argument( 181 | "--max_train_steps", 182 | type=int, 183 | default=None, 184 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 185 | ) 186 | parser.add_argument( 187 | "--checkpointing_steps", 188 | type=int, 189 | default=500, 190 | help=( 191 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 192 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 193 | " training using `--resume_from_checkpoint`." 194 | ), 195 | ) 196 | parser.add_argument( 197 | "--checkpoints_total_limit", 198 | type=int, 199 | default=None, 200 | help=("Max number of checkpoints to store."), 201 | ) 202 | parser.add_argument( 203 | "--resume_from_checkpoint", 204 | type=str, 205 | default=None, 206 | help=( 207 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 208 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 209 | ), 210 | ) 211 | parser.add_argument( 212 | "--gradient_accumulation_steps", 213 | type=int, 214 | default=1, 215 | help="Number of updates steps to accumulate before performing a backward/update pass.", 216 | ) 217 | parser.add_argument( 218 | "--gradient_checkpointing", 219 | action="store_true", 220 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 221 | ) 222 | parser.add_argument( 223 | "--learning_rate", 224 | type=float, 225 | default=1e-4, 226 | help="Initial learning rate (after the potential warmup period) to use.", 227 | ) 228 | 229 | parser.add_argument( 230 | "--guidance_scale", 231 | type=float, 232 | default=3.5, 233 | help="the FLUX.1 dev variant is a guidance distilled model", 234 | ) 235 | 236 | parser.add_argument( 237 | "--text_encoder_lr", 238 | type=float, 239 | default=5e-6, 240 | help="Text encoder learning rate to use.", 241 | ) 242 | parser.add_argument( 243 | "--scale_lr", 244 | action="store_true", 245 | default=False, 246 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 247 | ) 248 | parser.add_argument( 249 | "--lr_scheduler", 250 | type=str, 251 | default="constant", 252 | help=( 253 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 254 | ' "constant", "constant_with_warmup"]' 255 | ), 256 | ) 257 | parser.add_argument( 258 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 259 | ) 260 | parser.add_argument( 261 | "--lr_num_cycles", 262 | type=int, 263 | default=1, 264 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 265 | ) 266 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 267 | parser.add_argument( 268 | "--dataloader_num_workers", 269 | type=int, 270 | default=0, 271 | help=( 272 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 273 | ), 274 | ) 275 | parser.add_argument( 276 | "--weighting_scheme", 277 | type=str, 278 | default="none", 279 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], 280 | help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), 281 | ) 282 | parser.add_argument( 283 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." 284 | ) 285 | parser.add_argument( 286 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." 
287 | ) 288 | parser.add_argument( 289 | "--mode_scale", 290 | type=float, 291 | default=1.29, 292 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", 293 | ) 294 | parser.add_argument( 295 | "--optimizer", 296 | type=str, 297 | default="AdamW", 298 | help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), 299 | ) 300 | 301 | parser.add_argument( 302 | "--use_8bit_adam", 303 | action="store_true", 304 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", 305 | ) 306 | 307 | parser.add_argument( 308 | "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." 309 | ) 310 | parser.add_argument( 311 | "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." 312 | ) 313 | parser.add_argument( 314 | "--prodigy_beta3", 315 | type=float, 316 | default=None, 317 | help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " 318 | "uses the value of square root of beta2. Ignored if optimizer is adamW", 319 | ) 320 | parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") 321 | parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") 322 | parser.add_argument( 323 | "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" 324 | ) 325 | 326 | parser.add_argument( 327 | "--adam_epsilon", 328 | type=float, 329 | default=1e-08, 330 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.", 331 | ) 332 | 333 | parser.add_argument( 334 | "--prodigy_use_bias_correction", 335 | type=bool, 336 | default=True, 337 | help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", 338 | ) 339 | parser.add_argument( 340 | "--prodigy_safeguard_warmup", 341 | type=bool, 342 | default=True, 343 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " 344 | "Ignored if optimizer is adamW", 345 | ) 346 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 347 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 348 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 349 | parser.add_argument( 350 | "--hub_model_id", 351 | type=str, 352 | default=None, 353 | help="The name of the repository to keep in sync with the local `output_dir`.", 354 | ) 355 | parser.add_argument( 356 | "--logging_dir", 357 | type=str, 358 | default="logs", 359 | help=( 360 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 361 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 362 | ), 363 | ) 364 | parser.add_argument( 365 | "--allow_tf32", 366 | action="store_true", 367 | help=( 368 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 369 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 370 | ), 371 | ) 372 | parser.add_argument( 373 | "--report_to", 374 | type=str, 375 | default="tensorboard", 376 | help=( 377 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 378 | ' (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.' 379 | ), 380 | ) 381 | parser.add_argument( 382 | "--mixed_precision", 383 | type=str, 384 | default=None, 385 | choices=["no", "fp8", "fp16", "bf16"], 386 | help=( 387 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 388 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 389 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 390 | ), 391 | ) 392 | parser.add_argument( 393 | "--prior_generation_precision", 394 | type=str, 395 | default=None, 396 | choices=["no", "fp32", "fp16", "bf16"], 397 | help=( 398 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 399 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 400 | ), 401 | ) 402 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 403 | 404 | parser.add_argument( 405 | "--dataroot", 406 | type=str, 407 | help=( 408 | 'The dataset dir' 409 | ), 410 | ) 411 | parser.add_argument( 412 | "--train_data_list", 413 | type=str, 414 | default=None, 415 | ) 416 | parser.add_argument( 417 | "--use_local_model", action='store_true', help="Load local model" 418 | ) 419 | parser.add_argument( 420 | "--local_model_path", type=str, default=None, help="Local path to the model checkpoint" 421 | ) 422 | parser.add_argument("--validation_steps", type=int, default=1000) 423 | parser.add_argument("--num_inference_steps", type=int, default=50) 424 | parser.add_argument( 425 | "--validation_data_list", 426 | type=str, 427 | default=None, 428 | ) 429 | parser.add_argument( 430 | "--train_verification_list", 431 | type=str, 432 | default=None, 433 | help="The train verification list" 434 | ) 435 | 436 | parser.add_argument( 437 | "--width", type=int, default=512, help="The width for generated image" 438 | ) 439 | parser.add_argument( 440 | "--height", type=int, default=512, help="The height for generated image" 441 | ) 442 | 443 | parser.add_argument( 444 | "--controlnet_pretrained_model_name_or_path", 445 | type=str, 446 | default=None, 447 | # required=True, 448 | help="Path to pretrained model or model identifier from huggingface.co/models.", 449 | ) 450 | parser.add_argument( 451 | "--dropout_prob", type=float, default=0, help="The dropout probability" 452 | ) 453 | parser.add_argument( 454 | "--train_base_model", action='store_true', help="Train base model" 455 | ) 456 | parser.add_argument( 457 | "--vit_path", type=str, default="openai/clip-vit-base-patch16", help="The path to the vit model" 458 | ) 459 | parser.add_argument( 460 | "--use_extra_residual", action='store_true', help="Insert extra residual" 461 | ) 462 | parser.add_argument( 463 | "--use_extra_residual_concat", action='store_true', help="Insert extra residual" 464 | ) 465 | parser.add_argument( 466 | "--insert_dino_hs", action='store_true', help="Insert DINO hidden states" 467 | ) 468 | parser.add_argument( 469 | "--lora_rank", type=int, default=8, help="The rank for the LoRA layers" 470 | ) 471 | parser.add_argument( 472 | "--lora_layers", type=str, default=None, help="The layers to apply LoRA to" 473 | ) 474 | parser.add_argument( 475 | "--pretrained_lora_path", type=str, default=None, help="The path to the pretrained LoRA weights" 476 | ) 477 | parser.add_argument("--fuse_lora", action="store_true", help="Whether to fuse the LoRA weights into the base 
model") 478 | parser.add_argument("--lora_scale", type=float, default=1.0, help="Scale factor for LoRA weights when fusing") 479 | 480 | parser.add_argument( 481 | "--num_img", type=int, default=1, help="The number of images to generate" 482 | ) 483 | parser.add_argument( 484 | "--pretrained_inpaint_model_name_or_path", type=str, default="xiaozaa/flux1-fill-dev-diffusers", help="The path to the pretrained inpaint model" 485 | ) 486 | parser.add_argument( 487 | "--train_lora", action='store_true', help="Train LoRA" 488 | ) 489 | 490 | if input_args is not None: 491 | args = parser.parse_args(input_args) 492 | else: 493 | args = parser.parse_args() 494 | 495 | if args.dataset_name is None and args.instance_data_dir is None: 496 | raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") 497 | 498 | if args.dataset_name is not None and args.instance_data_dir is not None: 499 | raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") 500 | 501 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 502 | if env_local_rank != -1 and env_local_rank != args.local_rank: 503 | args.local_rank = env_local_rank 504 | 505 | if args.with_prior_preservation: 506 | if args.class_data_dir is None: 507 | raise ValueError("You must specify a data directory for class images.") 508 | if args.class_prompt is None: 509 | raise ValueError("You must specify prompt for class images.") 510 | else: 511 | # logger is not available yet 512 | if args.class_data_dir is not None: 513 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 514 | if args.class_prompt is not None: 515 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 516 | 517 | return args 518 | 519 | -------------------------------------------------------------------------------- /LICENSE-MODEL: -------------------------------------------------------------------------------- 1 | FLUX.1 [dev] Non-Commercial License 2 | Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models, including FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA and FLUX.1 Depth [dev] LoRA, and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). 3 | By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. 
If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. 4 | 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings: 5 | a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License. 6 | b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be. 7 | c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose. 8 | d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters. 9 | e. “you” or “your” means the individual or entity entering into this License with Company. 10 | 2. License Grant. 11 | a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf. 12 | b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai. 13 | c. 
Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License. 14 | d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model. 15 | 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions: 16 | a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License; 17 | b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”): 18 | “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc. 19 | IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.” 20 | c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and 21 | d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions. 22 | e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing. 23 | 4. Restrictions. You will not, and will not permit, assist or cause any third party to 24 | a. 
use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; 25 | b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model; 26 | c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or 27 | d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License. 28 | e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model; 29 | f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. 30 | 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 31 | 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 32 | 7. INDEMNIFICATION 33 | 34 | You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties. 35 | 8. Termination; Survival. 36 | a. This License will automatically terminate upon any breach by you of the terms of this License. 37 | b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. 38 | c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated. 39 | d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11. 40 | 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. 
Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 41 | 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators. 42 | 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company. 43 | 44 | 45 | Copyright (c) 2021, NeStyle Inc. 46 | All rights reserved. 47 | 48 | 49 | Attribution-NonCommercial 4.0 International 50 | 51 | ======================================================================= 52 | 53 | Creative Commons Corporation ("Creative Commons") is not a law firm and 54 | does not provide legal services or legal advice. Distribution of 55 | Creative Commons public licenses does not create a lawyer-client or 56 | other relationship. Creative Commons makes its licenses and related 57 | information available on an "as-is" basis. Creative Commons gives no 58 | warranties regarding its licenses, any material licensed under their 59 | terms and conditions, or any related information. Creative Commons 60 | disclaims all liability for damages resulting from their use to the 61 | fullest extent possible. 62 | 63 | Using Creative Commons Public Licenses 64 | 65 | Creative Commons public licenses provide a standard set of terms and 66 | conditions that creators and other rights holders may use to share 67 | original works of authorship and other material subject to copyright 68 | and certain other rights specified in the public license below. The 69 | following considerations are for informational purposes only, are not 70 | exhaustive, and do not form part of our licenses. 71 | 72 | Considerations for licensors: Our public licenses are 73 | intended for use by those authorized to give the public 74 | permission to use material in ways otherwise restricted by 75 | copyright and certain other rights. Our licenses are 76 | irrevocable. Licensors should read and understand the terms 77 | and conditions of the license they choose before applying it. 
78 | Licensors should also secure all rights necessary before 79 | applying our licenses so that the public can reuse the 80 | material as expected. Licensors should clearly mark any 81 | material not subject to the license. This includes other CC- 82 | licensed material, or material used under an exception or 83 | limitation to copyright. More considerations for licensors: 84 | wiki.creativecommons.org/Considerations_for_licensors 85 | 86 | Considerations for the public: By using one of our public 87 | licenses, a licensor grants the public permission to use the 88 | licensed material under specified terms and conditions. If 89 | the licensor's permission is not necessary for any reason--for 90 | example, because of any applicable exception or limitation to 91 | copyright--then that use is not regulated by the license. Our 92 | licenses grant only permissions under copyright and certain 93 | other rights that a licensor has authority to grant. Use of 94 | the licensed material may still be restricted for other 95 | reasons, including because others have copyright or other 96 | rights in the material. A licensor may make special requests, 97 | such as asking that all changes be marked or described. 98 | Although not required by our licenses, you are encouraged to 99 | respect those requests where reasonable. More_considerations 100 | for the public: 101 | wiki.creativecommons.org/Considerations_for_licensees 102 | 103 | ======================================================================= 104 | 105 | Creative Commons Attribution-NonCommercial 4.0 International Public 106 | License 107 | 108 | By exercising the Licensed Rights (defined below), You accept and agree 109 | to be bound by the terms and conditions of this Creative Commons 110 | Attribution-NonCommercial 4.0 International Public License ("Public 111 | License"). To the extent this Public License may be interpreted as a 112 | contract, You are granted the Licensed Rights in consideration of Your 113 | acceptance of these terms and conditions, and the Licensor grants You 114 | such rights in consideration of benefits the Licensor receives from 115 | making the Licensed Material available under these terms and 116 | conditions. 117 | 118 | 119 | Section 1 -- Definitions. 120 | 121 | a. Adapted Material means material subject to Copyright and Similar 122 | Rights that is derived from or based upon the Licensed Material 123 | and in which the Licensed Material is translated, altered, 124 | arranged, transformed, or otherwise modified in a manner requiring 125 | permission under the Copyright and Similar Rights held by the 126 | Licensor. For purposes of this Public License, where the Licensed 127 | Material is a musical work, performance, or sound recording, 128 | Adapted Material is always produced where the Licensed Material is 129 | synched in timed relation with a moving image. 130 | 131 | b. Adapter's License means the license You apply to Your Copyright 132 | and Similar Rights in Your contributions to Adapted Material in 133 | accordance with the terms and conditions of this Public License. 134 | 135 | c. Copyright and Similar Rights means copyright and/or similar rights 136 | closely related to copyright including, without limitation, 137 | performance, broadcast, sound recording, and Sui Generis Database 138 | Rights, without regard to how the rights are labeled or 139 | categorized. For purposes of this Public License, the rights 140 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 141 | Rights. 142 | d. 
Effective Technological Measures means those measures that, in the 143 | absence of proper authority, may not be circumvented under laws 144 | fulfilling obligations under Article 11 of the WIPO Copyright 145 | Treaty adopted on December 20, 1996, and/or similar international 146 | agreements. 147 | 148 | e. Exceptions and Limitations means fair use, fair dealing, and/or 149 | any other exception or limitation to Copyright and Similar Rights 150 | that applies to Your use of the Licensed Material. 151 | 152 | f. Licensed Material means the artistic or literary work, database, 153 | or other material to which the Licensor applied this Public 154 | License. 155 | 156 | g. Licensed Rights means the rights granted to You subject to the 157 | terms and conditions of this Public License, which are limited to 158 | all Copyright and Similar Rights that apply to Your use of the 159 | Licensed Material and that the Licensor has authority to license. 160 | 161 | h. Licensor means the individual(s) or entity(ies) granting rights 162 | under this Public License. 163 | 164 | i. NonCommercial means not primarily intended for or directed towards 165 | commercial advantage or monetary compensation. For purposes of 166 | this Public License, the exchange of the Licensed Material for 167 | other material subject to Copyright and Similar Rights by digital 168 | file-sharing or similar means is NonCommercial provided there is 169 | no payment of monetary compensation in connection with the 170 | exchange. 171 | 172 | j. Share means to provide material to the public by any means or 173 | process that requires permission under the Licensed Rights, such 174 | as reproduction, public display, public performance, distribution, 175 | dissemination, communication, or importation, and to make material 176 | available to the public including in ways that members of the 177 | public may access the material from a place and at a time 178 | individually chosen by them. 179 | 180 | k. Sui Generis Database Rights means rights other than copyright 181 | resulting from Directive 96/9/EC of the European Parliament and of 182 | the Council of 11 March 1996 on the legal protection of databases, 183 | as amended and/or succeeded, as well as other essentially 184 | equivalent rights anywhere in the world. 185 | 186 | l. You means the individual or entity exercising the Licensed Rights 187 | under this Public License. Your has a corresponding meaning. 188 | 189 | 190 | Section 2 -- Scope. 191 | 192 | a. License grant. 193 | 194 | 1. Subject to the terms and conditions of this Public License, 195 | the Licensor hereby grants You a worldwide, royalty-free, 196 | non-sublicensable, non-exclusive, irrevocable license to 197 | exercise the Licensed Rights in the Licensed Material to: 198 | 199 | a. reproduce and Share the Licensed Material, in whole or 200 | in part, for NonCommercial purposes only; and 201 | 202 | b. produce, reproduce, and Share Adapted Material for 203 | NonCommercial purposes only. 204 | 205 | 2. Exceptions and Limitations. For the avoidance of doubt, where 206 | Exceptions and Limitations apply to Your use, this Public 207 | License does not apply, and You do not need to comply with 208 | its terms and conditions. 209 | 210 | 3. Term. The term of this Public License is specified in Section 211 | 6(a). 212 | 213 | 4. Media and formats; technical modifications allowed. 
The 214 | Licensor authorizes You to exercise the Licensed Rights in 215 | all media and formats whether now known or hereafter created, 216 | and to make technical modifications necessary to do so. The 217 | Licensor waives and/or agrees not to assert any right or 218 | authority to forbid You from making technical modifications 219 | necessary to exercise the Licensed Rights, including 220 | technical modifications necessary to circumvent Effective 221 | Technological Measures. For purposes of this Public License, 222 | simply making modifications authorized by this Section 2(a) 223 | (4) never produces Adapted Material. 224 | 225 | 5. Downstream recipients. 226 | 227 | a. Offer from the Licensor -- Licensed Material. Every 228 | recipient of the Licensed Material automatically 229 | receives an offer from the Licensor to exercise the 230 | Licensed Rights under the terms and conditions of this 231 | Public License. 232 | 233 | b. No downstream restrictions. You may not offer or impose 234 | any additional or different terms or conditions on, or 235 | apply any Effective Technological Measures to, the 236 | Licensed Material if doing so restricts exercise of the 237 | Licensed Rights by any recipient of the Licensed 238 | Material. 239 | 240 | 6. No endorsement. Nothing in this Public License constitutes or 241 | may be construed as permission to assert or imply that You 242 | are, or that Your use of the Licensed Material is, connected 243 | with, or sponsored, endorsed, or granted official status by, 244 | the Licensor or others designated to receive attribution as 245 | provided in Section 3(a)(1)(A)(i). 246 | 247 | b. Other rights. 248 | 249 | 1. Moral rights, such as the right of integrity, are not 250 | licensed under this Public License, nor are publicity, 251 | privacy, and/or other similar personality rights; however, to 252 | the extent possible, the Licensor waives and/or agrees not to 253 | assert any such rights held by the Licensor to the limited 254 | extent necessary to allow You to exercise the Licensed 255 | Rights, but not otherwise. 256 | 257 | 2. Patent and trademark rights are not licensed under this 258 | Public License. 259 | 260 | 3. To the extent possible, the Licensor waives any right to 261 | collect royalties from You for the exercise of the Licensed 262 | Rights, whether directly or through a collecting society 263 | under any voluntary or waivable statutory or compulsory 264 | licensing scheme. In all other cases the Licensor expressly 265 | reserves any right to collect such royalties, including when 266 | the Licensed Material is used other than for NonCommercial 267 | purposes. 268 | 269 | 270 | Section 3 -- License Conditions. 271 | 272 | Your exercise of the Licensed Rights is expressly made subject to the 273 | following conditions. 274 | 275 | a. Attribution. 276 | 277 | 1. If You Share the Licensed Material (including in modified 278 | form), You must: 279 | 280 | a. retain the following if it is supplied by the Licensor 281 | with the Licensed Material: 282 | 283 | i. identification of the creator(s) of the Licensed 284 | Material and any others designated to receive 285 | attribution, in any reasonable manner requested by 286 | the Licensor (including by pseudonym if 287 | designated); 288 | 289 | ii. a copyright notice; 290 | 291 | iii. a notice that refers to this Public License; 292 | 293 | iv. a notice that refers to the disclaimer of 294 | warranties; 295 | 296 | v. 
a URI or hyperlink to the Licensed Material to the 297 | extent reasonably practicable; 298 | 299 | b. indicate if You modified the Licensed Material and 300 | retain an indication of any previous modifications; and 301 | 302 | c. indicate the Licensed Material is licensed under this 303 | Public License, and include the text of, or the URI or 304 | hyperlink to, this Public License. 305 | 306 | 2. You may satisfy the conditions in Section 3(a)(1) in any 307 | reasonable manner based on the medium, means, and context in 308 | which You Share the Licensed Material. For example, it may be 309 | reasonable to satisfy the conditions by providing a URI or 310 | hyperlink to a resource that includes the required 311 | information. 312 | 313 | 3. If requested by the Licensor, You must remove any of the 314 | information required by Section 3(a)(1)(A) to the extent 315 | reasonably practicable. 316 | 317 | 4. If You Share Adapted Material You produce, the Adapter's 318 | License You apply must not prevent recipients of the Adapted 319 | Material from complying with this Public License. 320 | 321 | 322 | Section 4 -- Sui Generis Database Rights. 323 | 324 | Where the Licensed Rights include Sui Generis Database Rights that 325 | apply to Your use of the Licensed Material: 326 | 327 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 328 | to extract, reuse, reproduce, and Share all or a substantial 329 | portion of the contents of the database for NonCommercial purposes 330 | only; 331 | 332 | b. if You include all or a substantial portion of the database 333 | contents in a database in which You have Sui Generis Database 334 | Rights, then the database in which You have Sui Generis Database 335 | Rights (but not its individual contents) is Adapted Material; and 336 | 337 | c. You must comply with the conditions in Section 3(a) if You Share 338 | all or a substantial portion of the contents of the database. 339 | 340 | For the avoidance of doubt, this Section 4 supplements and does not 341 | replace Your obligations under this Public License where the Licensed 342 | Rights include other Copyright and Similar Rights. 343 | 344 | 345 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 346 | 347 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 348 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 349 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 350 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 351 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 352 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 353 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 354 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 355 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 356 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 357 | 358 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 359 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 360 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 361 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 362 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 363 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 364 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 365 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 366 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 367 | 368 | c. The disclaimer of warranties and limitation of liability provided 369 | above shall be interpreted in a manner that, to the extent 370 | possible, most closely approximates an absolute disclaimer and 371 | waiver of all liability. 372 | 373 | 374 | Section 6 -- Term and Termination. 375 | 376 | a. This Public License applies for the term of the Copyright and 377 | Similar Rights licensed here. However, if You fail to comply with 378 | this Public License, then Your rights under this Public License 379 | terminate automatically. 380 | 381 | b. Where Your right to use the Licensed Material has terminated under 382 | Section 6(a), it reinstates: 383 | 384 | 1. automatically as of the date the violation is cured, provided 385 | it is cured within 30 days of Your discovery of the 386 | violation; or 387 | 388 | 2. upon express reinstatement by the Licensor. 389 | 390 | For the avoidance of doubt, this Section 6(b) does not affect any 391 | right the Licensor may have to seek remedies for Your violations 392 | of this Public License. 393 | 394 | c. For the avoidance of doubt, the Licensor may also offer the 395 | Licensed Material under separate terms or conditions or stop 396 | distributing the Licensed Material at any time; however, doing so 397 | will not terminate this Public License. 398 | 399 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 400 | License. 401 | 402 | 403 | Section 7 -- Other Terms and Conditions. 404 | 405 | a. The Licensor shall not be bound by any additional or different 406 | terms or conditions communicated by You unless expressly agreed. 407 | 408 | b. Any arrangements, understandings, or agreements regarding the 409 | Licensed Material not stated herein are separate from and 410 | independent of the terms and conditions of this Public License. 411 | 412 | 413 | Section 8 -- Interpretation. 414 | 415 | a. For the avoidance of doubt, this Public License does not, and 416 | shall not be interpreted to, reduce, limit, restrict, or impose 417 | conditions on any use of the Licensed Material that could lawfully 418 | be made without permission under this Public License. 419 | 420 | b. To the extent possible, if any provision of this Public License is 421 | deemed unenforceable, it shall be automatically reformed to the 422 | minimum extent necessary to make it enforceable. If the provision 423 | cannot be reformed, it shall be severed from this Public License 424 | without affecting the enforceability of the remaining terms and 425 | conditions. 426 | 427 | c. No term or condition of this Public License will be waived and no 428 | failure to comply consented to unless expressly agreed to by the 429 | Licensor. 430 | 431 | d. Nothing in this Public License constitutes or may be interpreted 432 | as a limitation upon, or waiver of, any privileges and immunities 433 | that apply to the Licensor or You, including from the legal 434 | processes of any jurisdiction or authority. 435 | 436 | ======================================================================= 437 | 438 | Creative Commons is not a party to its public 439 | licenses. Notwithstanding, Creative Commons may elect to apply one of 440 | its public licenses to material it publishes and in those instances 441 | will be considered the "Licensor." The text of the Creative Commons 442 | public licenses is dedicated to the public domain under the CC0 Public 443 | Domain Dedication. 
Except for the limited purpose of indicating that 444 | material is shared under a Creative Commons public license or as 445 | otherwise permitted by the Creative Commons policies published at 446 | creativecommons.org/policies, Creative Commons does not authorize the 447 | use of the trademark "Creative Commons" or any other trademark or logo 448 | of Creative Commons without its prior written consent including, 449 | without limitation, in connection with any unauthorized modifications 450 | to any of its public licenses or any other arrangements, 451 | understandings, or agreements concerning use of licensed material. For 452 | the avoidance of doubt, this paragraph does not form part of the 453 | public licenses. 454 | 455 | Creative Commons may be contacted at creativecommons.org. 456 | -------------------------------------------------------------------------------- /train_flux_inpaint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import copy 18 | import gc 19 | import itertools 20 | import logging 21 | import math 22 | import os 23 | import random 24 | import shutil 25 | import warnings 26 | from contextlib import nullcontext 27 | from pathlib import Path 28 | 29 | import numpy as np 30 | import torch 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed, DistributedType 36 | from huggingface_hub import create_repo, upload_folder 37 | from huggingface_hub.utils import insecure_hashlib 38 | from PIL import Image 39 | from PIL.ImageOps import exif_transpose 40 | from torch.utils.data import Dataset 41 | from torchvision import transforms 42 | from torchvision.transforms.functional import crop 43 | from tqdm.auto import tqdm 44 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast 45 | from image_datasets.cp_dataset import VitonHDTestDataset 46 | from paser_helper import parse_args 47 | from src.flux.train_utils import prepare_fill_with_mask, prepare_latents, encode_images_to_latents 48 | from diffusers import FluxTransformer2DModel, FluxFillPipeline 49 | # from src.flux.pipeline_flux_inpaint import FluxInpaintingPipeline 50 | from diffusers.image_processor import VaeImageProcessor 51 | from deepspeed.runtime.engine import DeepSpeedEngine 52 | 53 | import diffusers 54 | from diffusers import ( 55 | AutoencoderKL, 56 | FlowMatchEulerDiscreteScheduler, 57 | ) 58 | from diffusers.optimization import get_scheduler 59 | from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 60 | from diffusers.utils import ( 61 | check_min_version, 62 | is_wandb_available, 63 | 
load_image, 64 | ) 65 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 66 | from diffusers.utils.torch_utils import is_compiled_module 67 | 68 | 69 | if is_wandb_available(): 70 | import wandb 71 | 72 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 73 | check_min_version("0.30.2") 74 | 75 | logger = get_logger(__name__) 76 | 77 | 78 | def save_model_card( 79 | repo_id: str, 80 | images=None, 81 | base_model: str = None, 82 | train_text_encoder=False, 83 | instance_prompt=None, 84 | repo_folder=None, 85 | ): 86 | widget_dict = [] 87 | if images is not None: 88 | for i, image in enumerate(images): 89 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 90 | widget_dict.append( 91 | {"text": " ", "output": {"url": f"image_{i}.png"}} 92 | ) 93 | 94 | model_description = f""" 95 | # Flux [dev] DreamBooth - {repo_id} 96 | 97 | 98 | 99 | ## Model description 100 | 101 | These are {repo_id} DreamBooth weights for {base_model}. 102 | 103 | The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). 104 | 105 | Was the text encoder fine-tuned? {train_text_encoder}. 106 | 107 | ## Trigger words 108 | 109 | You should use `{instance_prompt}` to trigger the image generation. 110 | 111 | ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) 112 | 113 | ```py 114 | from diffusers import AutoPipelineForText2Image 115 | import torch 116 | pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.bfloat16).to('cuda') 117 | ``` 118 | 119 | ## License 120 | 121 | Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 122 | """ 123 | model_card = load_or_create_model_card( 124 | repo_id_or_path=repo_id, 125 | from_training=True, 126 | license="other", 127 | base_model=base_model, 128 | prompt=instance_prompt, 129 | model_description=model_description, 130 | widget=widget_dict, 131 | ) 132 | tags = [ 133 | "text-to-image", 134 | "diffusers-training", 135 | "diffusers", 136 | "flux", 137 | "flux-diffusers", 138 | "template:sd-lora", 139 | ] 140 | 141 | model_card = populate_model_card(model_card, tags=tags) 142 | model_card.save(os.path.join(repo_folder, "README.md")) 143 | 144 | 145 | def load_text_encoders(class_one, class_two): 146 | text_encoder_one = class_one.from_pretrained( 147 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 148 | ) 149 | text_encoder_two = class_two.from_pretrained( 150 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant 151 | ) 152 | return text_encoder_one, text_encoder_two 153 | 154 | 155 | from tqdm import tqdm 156 | 157 | def log_validation( 158 | pipeline, 159 | args, 160 | accelerator, 161 | epoch, 162 | dataloader, 163 | tag, 164 | is_final_validation=False, 165 | ): 166 | logger.info( 167 | f"Running {tag}... 
\n " 168 | ) 169 | 170 | pipeline = pipeline.to(accelerator.device) 171 | # pipeline.set_progress_bar_config(disable=True) 172 | 173 | # run inference 174 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None 175 | # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() 176 | autocast_ctx = nullcontext() 177 | 178 | with autocast_ctx: 179 | images = [] 180 | prompts = [] 181 | control_images = [] 182 | control_masks = [] 183 | for batch in dataloader: 184 | 185 | # prompt = batch['caption_cloth'] 186 | prompt = ["" 187 | f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " 188 | f"[IMAGE1] Detailed product shot of a clothing" 189 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 190 | # "[IMAGE1] A sleek black long-sleeved top is displayed against a neutral backdrop, featuring distinctive elbow pads, " 191 | # "a classic round neckline, and a cropped silhouette that combines 90s-inspired design with modern minimalism. " 192 | # "The garment showcases clean lines and a fitted cut, embracing realistic body proportions; " 193 | # "[IMAGE2] The same top is worn by a model in a lifestyle setting, where the versatile black fabric drapes naturally, " 194 | # "emphasizing its superflat construction and thin material. The styling creates a retro-contemporary fusion, " 195 | # "reminiscent of 60s fashion while maintaining a timeless cloud jumper aesthetic, all captured in a sophisticated black box presentation." 196 | ] * len(batch['image']) 197 | control_image = batch['image'] 198 | control_mask = batch['inpaint_mask'] 199 | 200 | 201 | height = args.height 202 | width = args.width*2 203 | 204 | result = pipeline( 205 | prompt=prompt, 206 | height=height, 207 | width=width, 208 | image=control_image, 209 | mask_image=control_mask, 210 | num_inference_steps=28, 211 | generator=generator, 212 | guidance_scale=30, 213 | ).images 214 | 215 | images.extend(result) 216 | prompts.extend(prompt) 217 | control_images.extend(control_image) 218 | control_masks.extend(control_mask) 219 | 220 | for tracker in accelerator.trackers: 221 | phase_name = tag 222 | if tracker.name == "wandb": 223 | tracker.log( 224 | { 225 | phase_name: [ 226 | wandb.Image( 227 | image, 228 | caption=f"{i} {prompt}", 229 | ) 230 | for i, (image, prompt, control_mask) in enumerate(zip(images, prompts, control_masks)) 231 | ], 232 | f"{phase_name}_control_images": [ 233 | wandb.Image(control_image, caption=f"{i} Control Image") 234 | for i, control_image in enumerate(control_images) 235 | ], 236 | } 237 | ) 238 | 239 | del pipeline 240 | if torch.cuda.is_available(): 241 | torch.cuda.empty_cache() 242 | 243 | return images 244 | 245 | 246 | def import_model_class_from_model_name_or_path( 247 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 248 | ): 249 | text_encoder_config = PretrainedConfig.from_pretrained( 250 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 251 | ) 252 | model_class = text_encoder_config.architectures[0] 253 | if model_class == "CLIPTextModel": 254 | from transformers import CLIPTextModel 255 | 256 | return CLIPTextModel 257 | elif model_class == "T5EncoderModel": 258 | from transformers import T5EncoderModel 259 | 260 | return T5EncoderModel 261 | else: 262 | raise ValueError(f"{model_class} is not supported.") 263 | 264 | 265 | def tokenize_prompt(tokenizer, prompt, max_sequence_length): 266 | 
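    # tokenize_prompt pads/truncates `prompt` to `max_sequence_length` and returns only the
    # token ids; no attention mask is returned, because the encoder helpers below call the
    # text encoders with ids alone. Illustrative use with the T5 tokenizer (values are
    # examples only, not taken from this script's defaults):
    #   ids = tokenize_prompt(tokenizer, "photo of a shirt", max_sequence_length=512)
    #   ids.shape  # -> torch.Size([1, 512])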
text_inputs = tokenizer( 267 | prompt, 268 | padding="max_length", 269 | max_length=max_sequence_length, 270 | truncation=True, 271 | return_length=False, 272 | return_overflowing_tokens=False, 273 | return_tensors="pt", 274 | ) 275 | text_input_ids = text_inputs.input_ids 276 | return text_input_ids 277 | 278 | 279 | def _encode_prompt_with_t5( 280 | text_encoder, 281 | tokenizer, 282 | max_sequence_length=512, 283 | prompt=None, 284 | num_images_per_prompt=1, 285 | device=None, 286 | text_input_ids=None, 287 | ): 288 | prompt = [prompt] if isinstance(prompt, str) else prompt 289 | batch_size = len(prompt) 290 | 291 | if tokenizer is not None: 292 | text_inputs = tokenizer( 293 | prompt, 294 | padding="max_length", 295 | max_length=max_sequence_length, 296 | truncation=True, 297 | return_length=False, 298 | return_overflowing_tokens=False, 299 | return_tensors="pt", 300 | ) 301 | text_input_ids = text_inputs.input_ids 302 | else: 303 | if text_input_ids is None: 304 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 305 | 306 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 307 | 308 | dtype = text_encoder.dtype 309 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 310 | 311 | _, seq_len, _ = prompt_embeds.shape 312 | 313 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 314 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 315 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 316 | 317 | return prompt_embeds 318 | 319 | 320 | def _encode_prompt_with_clip( 321 | text_encoder, 322 | tokenizer, 323 | prompt: str, 324 | device=None, 325 | text_input_ids=None, 326 | num_images_per_prompt: int = 1, 327 | ): 328 | prompt = [prompt] if isinstance(prompt, str) else prompt 329 | batch_size = len(prompt) 330 | 331 | if tokenizer is not None: 332 | text_inputs = tokenizer( 333 | prompt, 334 | padding="max_length", 335 | max_length=77, 336 | truncation=True, 337 | return_overflowing_tokens=False, 338 | return_length=False, 339 | return_tensors="pt", 340 | ) 341 | 342 | text_input_ids = text_inputs.input_ids 343 | else: 344 | if text_input_ids is None: 345 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 346 | 347 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 348 | 349 | # Use pooled output of CLIPTextModel 350 | prompt_embeds = prompt_embeds.pooler_output 351 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 352 | 353 | # duplicate text embeddings for each generation per prompt, using mps friendly method 354 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 355 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 356 | 357 | return prompt_embeds 358 | 359 | 360 | def encode_prompt( 361 | text_encoders, 362 | tokenizers, 363 | prompt: str, 364 | max_sequence_length, 365 | device=None, 366 | num_images_per_prompt: int = 1, 367 | text_input_ids_list=None, 368 | ): 369 | prompt = [prompt] if isinstance(prompt, str) else prompt 370 | dtype = text_encoders[0].dtype 371 | device = device if device is not None else text_encoders[1].device 372 | pooled_prompt_embeds = _encode_prompt_with_clip( 373 | text_encoder=text_encoders[0], 374 | tokenizer=tokenizers[0], 375 | prompt=prompt, 376 | device=device, 377 | num_images_per_prompt=num_images_per_prompt, 378 | 
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, 379 | ) 380 | 381 | prompt_embeds = _encode_prompt_with_t5( 382 | text_encoder=text_encoders[1], 383 | tokenizer=tokenizers[1], 384 | max_sequence_length=max_sequence_length, 385 | prompt=prompt, 386 | num_images_per_prompt=num_images_per_prompt, 387 | device=device, 388 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, 389 | ) 390 | 391 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 392 | 393 | return prompt_embeds, pooled_prompt_embeds, text_ids 394 | 395 | 396 | def main(args): 397 | if args.report_to == "wandb" and args.hub_token is not None: 398 | raise ValueError( 399 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 400 | " Please use `huggingface-cli login` to authenticate with the Hub." 401 | ) 402 | 403 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16": 404 | # due to pytorch#99272, MPS does not yet support bfloat16. 405 | raise ValueError( 406 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 407 | ) 408 | 409 | logging_dir = Path(args.output_dir, args.logging_dir) 410 | 411 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 412 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 413 | accelerator = Accelerator( 414 | gradient_accumulation_steps=args.gradient_accumulation_steps, 415 | mixed_precision=args.mixed_precision, 416 | log_with=args.report_to, 417 | project_config=accelerator_project_config, 418 | kwargs_handlers=[kwargs], 419 | ) 420 | 421 | # Disable AMP for MPS. 422 | if torch.backends.mps.is_available(): 423 | accelerator.native_amp = False 424 | 425 | if args.report_to == "wandb": 426 | if not is_wandb_available(): 427 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 428 | 429 | # Make one log on every process with the configuration for debugging. 430 | logging.basicConfig( 431 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 432 | datefmt="%m/%d/%Y %H:%M:%S", 433 | level=logging.INFO, 434 | ) 435 | logger.info(accelerator.state, main_process_only=False) 436 | if accelerator.is_local_main_process: 437 | transformers.utils.logging.set_verbosity_warning() 438 | diffusers.utils.logging.set_verbosity_info() 439 | else: 440 | transformers.utils.logging.set_verbosity_error() 441 | diffusers.utils.logging.set_verbosity_error() 442 | 443 | # If passed along, set the training seed now. 
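    # (set_seed makes the noise and timestep sampling below reproducible on every
    # accelerator process; leaving --seed unset keeps training non-deterministic.)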
444 | if args.seed is not None: 445 | set_seed(args.seed) 446 | # Handle the repository creation 447 | if accelerator.is_main_process: 448 | if args.output_dir is not None: 449 | os.makedirs(args.output_dir, exist_ok=True) 450 | 451 | if args.push_to_hub: 452 | repo_id = create_repo( 453 | repo_id=args.hub_model_id or Path(args.output_dir).name, 454 | exist_ok=True, 455 | ).repo_id 456 | 457 | # Load the tokenizers 458 | tokenizer_one = CLIPTokenizer.from_pretrained( 459 | args.pretrained_model_name_or_path, 460 | subfolder="tokenizer", 461 | revision=args.revision, 462 | ) 463 | tokenizer_two = T5TokenizerFast.from_pretrained( 464 | args.pretrained_model_name_or_path, 465 | subfolder="tokenizer_2", 466 | revision=args.revision, 467 | ) 468 | 469 | # import correct text encoder classes 470 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 471 | args.pretrained_model_name_or_path, args.revision 472 | ) 473 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 474 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 475 | ) 476 | 477 | # Load scheduler and models 478 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 479 | args.pretrained_model_name_or_path, subfolder="scheduler" 480 | ) 481 | noise_scheduler_copy = copy.deepcopy(noise_scheduler) 482 | text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) 483 | vae = AutoencoderKL.from_pretrained( 484 | args.pretrained_model_name_or_path, 485 | subfolder="vae", 486 | revision=args.revision, 487 | variant=args.variant, 488 | ) 489 | 490 | vae_scale_factor = ( 491 | 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8 492 | ) 493 | image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True) 494 | mask_processor = VaeImageProcessor( 495 | vae_scale_factor=vae_scale_factor, 496 | do_resize=True, 497 | do_convert_grayscale=True, 498 | do_normalize=False, 499 | do_binarize=True, 500 | ) 501 | transformer = FluxTransformer2DModel.from_pretrained( 502 | args.pretrained_inpaint_model_name_or_path, revision=args.revision, variant=args.variant 503 | ) 504 | 505 | 506 | 507 | transformer.requires_grad_(False) 508 | vae.requires_grad_(False) 509 | text_encoder_one.requires_grad_(False) 510 | text_encoder_two.requires_grad_(False) 511 | 512 | grad_params = [ 513 | "transformer_blocks.0.", 514 | "transformer_blocks.1.", 515 | "transformer_blocks.2.", 516 | "transformer_blocks.3.", 517 | "transformer_blocks.4.", 518 | "transformer_blocks.5.", 519 | "transformer_blocks.6.", 520 | "transformer_blocks.7.", 521 | "transformer_blocks.8.", 522 | "transformer_blocks.9.", 523 | "transformer_blocks.10.", 524 | "transformer_blocks.11.", 525 | "transformer_blocks.12.", 526 | "transformer_blocks.13.", 527 | "transformer_blocks.14.", 528 | "transformer_blocks.15.", 529 | "transformer_blocks.16.", 530 | "transformer_blocks.17.", 531 | "transformer_blocks.18.", 532 | "single_transformer_blocks.0.", 533 | "single_transformer_blocks.1.", 534 | "single_transformer_blocks.2.", 535 | "single_transformer_blocks.3.", 536 | "single_transformer_blocks.4.", 537 | "single_transformer_blocks.5.", 538 | "single_transformer_blocks.6.", 539 | "single_transformer_blocks.7.", 540 | "single_transformer_blocks.8.", 541 | "single_transformer_blocks.9.", 542 | "single_transformer_blocks.10.", 543 | "single_transformer_blocks.13.", 544 | "single_transformer_blocks.14.", 545 | 
"single_transformer_blocks.15.", 546 | "single_transformer_blocks.16.", 547 | "single_transformer_blocks.17.", 548 | "single_transformer_blocks.18.", 549 | "single_transformer_blocks.19.", 550 | "single_transformer_blocks.20.", 551 | "single_transformer_blocks.21.", 552 | "single_transformer_blocks.22.", 553 | "single_transformer_blocks.23.", 554 | "single_transformer_blocks.24.", 555 | "single_transformer_blocks.25.", 556 | "single_transformer_blocks.26.", 557 | "single_transformer_blocks.27.", 558 | "single_transformer_blocks.28.", 559 | "single_transformer_blocks.29.", 560 | "single_transformer_blocks.30.", 561 | "single_transformer_blocks.31.", 562 | "single_transformer_blocks.32.", 563 | "single_transformer_blocks.33.", 564 | "single_transformer_blocks.34.", 565 | "single_transformer_blocks.35.", 566 | "single_transformer_blocks.36.", 567 | "single_transformer_blocks.37.", 568 | ] 569 | 570 | if args.train_base_model: 571 | transformer.requires_grad_(False) # Set all parameters to not require gradients by default 572 | 573 | for name, param in transformer.named_parameters(): 574 | if any(grad_param in name for grad_param in grad_params): 575 | if ("attn" in name): 576 | param.requires_grad = True 577 | print(f"Enabling gradients for: {name}") 578 | 579 | else: 580 | transformer.requires_grad_(False) 581 | 582 | 583 | # #you can train your own layers 584 | # for n, param in transformer.named_parameters(): 585 | # print(n) 586 | # if 'single_transformer_blocks' in n: 587 | # param.requires_grad = False 588 | # elif 'transformer_blocks' in n and '1.attn' in n: 589 | # param.requires_grad = True 590 | # else: 591 | # param.requires_grad = False 592 | 593 | print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'transformer parameters') 594 | 595 | # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision 596 | # as these weights are only used for inference, keeping weights in full precision is not required. 597 | weight_dtype = torch.float32 598 | if accelerator.mixed_precision == "fp16": 599 | weight_dtype = torch.float16 600 | elif accelerator.mixed_precision == "bf16": 601 | weight_dtype = torch.bfloat16 602 | 603 | if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: 604 | # due to pytorch#99272, MPS does not yet support bfloat16. 605 | raise ValueError( 606 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
607 | ) 608 | 609 | vae.to(accelerator.device, dtype=weight_dtype) 610 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 611 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 612 | transformer.to(accelerator.device, dtype=weight_dtype) 613 | 614 | 615 | if args.gradient_checkpointing: 616 | if args.train_base_model: 617 | transformer.enable_gradient_checkpointing() 618 | 619 | 620 | def unwrap_model(model): 621 | model = accelerator.unwrap_model(model) 622 | model = model._orig_mod if is_compiled_module(model) else model 623 | return model 624 | 625 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 626 | def save_model_hook(models, weights, output_dir): 627 | if accelerator.is_main_process: 628 | for i, model in enumerate(models): 629 | if isinstance(model, DeepSpeedEngine): 630 | # For DeepSpeed models, we need to get the underlying model 631 | model = model.module 632 | if isinstance(unwrap_model(model), FluxTransformer2DModel): 633 | unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) 634 | elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): 635 | if isinstance(unwrap_model(model), CLIPTextModelWithProjection): 636 | unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) 637 | else: 638 | unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) 639 | else: 640 | raise ValueError(f"Wrong model supplied: {type(model)=}.") 641 | 642 | # make sure to pop weight so that corresponding model is not saved again 643 | if weights: 644 | weights.pop() 645 | else: 646 | print('no weights') 647 | 648 | def load_model_hook(models, input_dir): 649 | for _ in range(len(models)): 650 | # pop models so that they are not loaded again 651 | model = models.pop() 652 | 653 | # load diffusers style into model 654 | if isinstance(unwrap_model(model), FluxTransformer2DModel): 655 | load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") 656 | model.register_to_config(**load_model.config) 657 | 658 | model.load_state_dict(load_model.state_dict()) 659 | elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): 660 | try: 661 | load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") 662 | model(**load_model.config) 663 | model.load_state_dict(load_model.state_dict()) 664 | except Exception: 665 | try: 666 | load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2") 667 | model(**load_model.config) 668 | model.load_state_dict(load_model.state_dict()) 669 | except Exception: 670 | raise ValueError(f"Couldn't load the model of type: ({type(model)}).") 671 | else: 672 | raise ValueError(f"Unsupported model found: {type(model)=}") 673 | 674 | del load_model 675 | 676 | accelerator.register_save_state_pre_hook(save_model_hook) 677 | accelerator.register_load_state_pre_hook(load_model_hook) 678 | 679 | # Enable TF32 for faster training on Ampere GPUs, 680 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 681 | if args.allow_tf32 and torch.cuda.is_available(): 682 | torch.backends.cuda.matmul.allow_tf32 = True 683 | 684 | if args.scale_lr: 685 | args.learning_rate = ( 686 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 687 | ) 688 | 689 | # Optimization parameters 690 | if args.train_base_model: 691 | 
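        # With --train_base_model, all transformer parameters are handed to the optimizer in a
        # single group at args.learning_rate; parameters left with requires_grad=False above
        # receive no gradients and are therefore never updated. An equivalent formulation that
        # passes only the trainable parameters (illustrative sketch, not used here):
        #   params_to_optimize = [
        #       {"params": [p for p in transformer.parameters() if p.requires_grad],
        #        "lr": args.learning_rate},
        #   ]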
transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} 692 | 693 | params_to_optimize = [transformer_parameters_with_lr] 694 | 695 | # Optimizer creation 696 | if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): 697 | logger.warning( 698 | f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." 699 | "Defaulting to adamW" 700 | ) 701 | args.optimizer = "adamw" 702 | 703 | if args.use_8bit_adam and not args.optimizer.lower() == "adamw": 704 | logger.warning( 705 | f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " 706 | f"set to {args.optimizer.lower()}" 707 | ) 708 | 709 | if args.optimizer.lower() == "adamw": 710 | if args.use_8bit_adam: 711 | try: 712 | import bitsandbytes as bnb 713 | except ImportError: 714 | raise ImportError( 715 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 716 | ) 717 | 718 | optimizer_class = bnb.optim.AdamW8bit 719 | else: 720 | optimizer_class = torch.optim.AdamW 721 | 722 | optimizer = optimizer_class( 723 | params_to_optimize, 724 | betas=(args.adam_beta1, args.adam_beta2), 725 | weight_decay=args.adam_weight_decay, 726 | eps=args.adam_epsilon, 727 | ) 728 | 729 | if args.optimizer.lower() == "prodigy": 730 | try: 731 | import prodigyopt 732 | except ImportError: 733 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") 734 | 735 | optimizer_class = prodigyopt.Prodigy 736 | 737 | if args.learning_rate <= 0.1: 738 | logger.warning( 739 | "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" 740 | ) 741 | 742 | optimizer = optimizer_class( 743 | params_to_optimize, 744 | lr=args.learning_rate, 745 | betas=(args.adam_beta1, args.adam_beta2), 746 | beta3=args.prodigy_beta3, 747 | weight_decay=args.adam_weight_decay, 748 | eps=args.adam_epsilon, 749 | decouple=args.prodigy_decouple, 750 | use_bias_correction=args.prodigy_use_bias_correction, 751 | safeguard_warmup=args.prodigy_safeguard_warmup, 752 | ) 753 | 754 | # Dataset and DataLoaders creation: 755 | train_dataset = VitonHDTestDataset( 756 | dataroot_path=args.dataroot, 757 | phase="train", 758 | order="paired", 759 | size=(args.height, args.width), 760 | data_list=args.train_data_list, 761 | ) 762 | 763 | train_verification_dataset = VitonHDTestDataset( 764 | dataroot_path=args.dataroot, 765 | phase="train", 766 | order="paired", 767 | size=(args.height, args.width), 768 | data_list=args.train_verification_list, 769 | ) 770 | 771 | validation_dataset = VitonHDTestDataset( 772 | dataroot_path=args.dataroot, 773 | phase="test", 774 | order="paired", 775 | size=(args.height, args.width), 776 | data_list=args.validation_data_list, 777 | ) 778 | 779 | train_dataloader = torch.utils.data.DataLoader( 780 | train_dataset, 781 | shuffle=False, 782 | batch_size=args.train_batch_size, 783 | ) 784 | 785 | train_verification_dataloader = torch.utils.data.DataLoader( 786 | train_verification_dataset, 787 | shuffle=False, 788 | batch_size=args.train_batch_size, 789 | ) 790 | 791 | validation_dataloader = torch.utils.data.DataLoader( 792 | validation_dataset, 793 | shuffle=False, 794 | batch_size=args.train_batch_size, 795 | ) 796 | 797 | tokenizers = [tokenizer_one, tokenizer_two] 798 | text_encoders = [text_encoder_one, text_encoder_two] 799 | 800 | def compute_text_embeddings(prompt, text_encoders, tokenizers): 801 | with torch.no_grad(): 
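            # The frozen CLIP and T5 encoders are run without gradient tracking. Assuming the
            # default FLUX.1 [dev] encoders loaded above, this yields prompt_embeds of shape
            # [batch, max_sequence_length, 4096] (T5), pooled_prompt_embeds of shape
            # [batch, 768] (CLIP pooled output), and zero-initialized text_ids, all of which
            # are moved onto the accelerator device below.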
            prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                text_encoders, tokenizers, prompt, args.max_sequence_length
            )
            prompt_embeds = prompt_embeds.to(accelerator.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
            text_ids = text_ids.to(accelerator.device)
        return prompt_embeds, pooled_prompt_embeds, text_ids

    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
    # the redundant encoding.

    # Handle class prompt for prior-preservation.

    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
    # pack the statically computed variables appropriately here. This is so that we don't
    # have to pass them to the dataloader.

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    if args.train_base_model:
        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            transformer, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = "dreambooth-flux-inpaint"
        accelerator.init_trackers(tracker_name, config=vars(args), init_kwargs={"wandb": {"settings": wandb.Settings(code_dir=".")}})

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
parallel, distributed & accumulation) = {total_batch_size}") 866 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 867 | logger.info(f" Total optimization steps = {args.max_train_steps}") 868 | global_step = 0 869 | first_epoch = 0 870 | epoch = first_epoch 871 | 872 | 873 | # Potentially load in the weights and states from a previous save 874 | if args.resume_from_checkpoint: 875 | if args.resume_from_checkpoint != "latest": 876 | path = os.path.basename(args.resume_from_checkpoint) 877 | else: 878 | # Get the mos recent checkpoint 879 | dirs = os.listdir(args.output_dir) 880 | dirs = [d for d in dirs if d.startswith("checkpoint")] 881 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 882 | path = dirs[-1] if len(dirs) > 0 else None 883 | 884 | if path is None: 885 | accelerator.print( 886 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 887 | ) 888 | args.resume_from_checkpoint = None 889 | initial_global_step = 0 890 | else: 891 | accelerator.print(f"Resuming from checkpoint {path}") 892 | accelerator.load_state(os.path.join(args.output_dir, path)) 893 | global_step = int(path.split("-")[1]) 894 | 895 | initial_global_step = global_step 896 | first_epoch = global_step // num_update_steps_per_epoch 897 | 898 | else: 899 | initial_global_step = 0 900 | 901 | progress_bar = tqdm( 902 | range(0, args.max_train_steps), 903 | initial=initial_global_step, 904 | desc="Steps", 905 | # Only show the progress bar once on each machine. 906 | disable=not accelerator.is_local_main_process, 907 | ) 908 | 909 | def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): 910 | sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) 911 | schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) 912 | timesteps = timesteps.to(accelerator.device) 913 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 914 | 915 | sigma = sigmas[step_indices].flatten() 916 | while len(sigma.shape) < n_dim: 917 | sigma = sigma.unsqueeze(-1) 918 | return sigma 919 | 920 | for epoch in range(first_epoch, args.num_train_epochs): 921 | if args.train_base_model: 922 | transformer.train() 923 | 924 | for step, batch in enumerate(train_dataloader): 925 | if args.train_base_model: 926 | models_to_accumulate = [transformer] 927 | 928 | 929 | with accelerator.accumulate(models_to_accumulate): 930 | # vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) 931 | batch_size = batch["image"].shape[0] 932 | pixel_values = batch["image"].to(dtype=vae.dtype) 933 | # prompts = batch["caption_cloth"] 934 | prompts = ["" 935 | f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " 936 | f"[IMAGE1] Detailed product shot of a clothing" 937 | f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." 938 | # "[IMAGE1] A sleek black long-sleeved top is displayed against a neutral backdrop, featuring distinctive elbow pads, " 939 | # "a classic round neckline, and a cropped silhouette that combines 90s-inspired design with modern minimalism. " 940 | # "The garment showcases clean lines and a fitted cut, embracing realistic body proportions; " 941 | # "[IMAGE2] The same top is worn by a model in a lifestyle setting, where the versatile black fabric drapes naturally, " 942 | # "emphasizing its superflat construction and thin material. 
                control_mask = batch["inpaint_mask"].to(dtype=vae.dtype)
                control_image = batch["im_mask"].to(dtype=vae.dtype)
                garment_image = batch["cloth_pure"]
                # garment_image_0_1 = (batch["cloth_pure"] + 1.0) / 2
                garment_image = garment_image.to(dtype=vae.dtype)

                # print("image_proj.shape", image_proj.shape)

                # Encode the batch prompts (the same prompt is used for every image in the batch).
                prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
                    prompts, text_encoders, tokenizers
                )

                inpaint_cond, _, _ = prepare_fill_with_mask(
                    image_processor=image_processor,
                    mask_processor=mask_processor,
                    vae=vae,
                    vae_scale_factor=vae_scale_factor,
                    image=control_image,
                    mask=control_mask,
                    width=args.width * 2,
                    height=args.height,
                    batch_size=batch_size,
                    num_images_per_prompt=1,
                    device=accelerator.device,
                    dtype=weight_dtype,
                )

                # TODO: controlnet dropout might cause instability, need to run more experiments
                if args.dropout_prob > 0:
                    dropout = torch.nn.Dropout(p=args.dropout_prob)
                    inpaint_cond = dropout(inpaint_cond)

                model_input = encode_images_to_latents(vae, pixel_values, weight_dtype, args.height, args.width * 2)

                latent_image_ids = prepare_latents(
                    vae_scale_factor,
                    batch_size,
                    args.height,
                    args.width * 2,
                    weight_dtype,
                    accelerator.device,
                )

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                # Add noise according to flow matching.
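                # Rectified-flow forward process: the noisy latent is a linear interpolation
                # x_t = (1 - sigma_t) * x_0 + sigma_t * noise, where sigma_t comes from the
                # copied noise scheduler. get_sigmas() looks up sigma for each sampled timestep
                # and unsqueezes it so it broadcasts over the latent tensor: at sigma_t = 1 the
                # input is pure noise, at sigma_t = 0 it is the clean latent.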
                # zt = (1 - texp) * x + texp * z1
                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                packed_noisy_model_input = FluxFillPipeline._pack_latents(
                    noisy_model_input,
                    batch_size=model_input.shape[0],
                    num_channels_latents=model_input.shape[1],
                    height=model_input.shape[2],
                    width=model_input.shape[3],
                )

                # handle guidance
                # guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
                guidance = guidance.expand(model_input.shape[0])

                # print("before concat packed_noisy_model_input.shape", packed_noisy_model_input.shape, "inpaint_cond.shape", inpaint_cond.shape)

                if inpaint_cond is not None:
                    packed_noisy_model_input = torch.cat([packed_noisy_model_input, inpaint_cond], dim=-1)

                # print("guidance", guidance, "pooled_prompt_embeds.shape", pooled_prompt_embeds.shape, "prompt_embeds.shape", prompt_embeds.shape)

                # Predict the noise residual
                model_pred = transformer(
                    hidden_states=packed_noisy_model_input,
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model
                    # (we should not keep it, but I want to keep the inputs the same for the model for testing)
                    timestep=timesteps / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    return_dict=False,
                )[0]

                # print("model_pred.shape", model_pred.shape, "prompt_embeds.shape", prompt_embeds.shape, "packed_noisy_model_input.shape", packed_noisy_model_input.shape, "refnet_image.shape", refnet_image.shape)
                # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
                model_pred = FluxFillPipeline._unpack_latents(
                    model_pred,
                    height=args.height,
                    width=args.width * 2,
                    vae_scale_factor=vae_scale_factor,
                )

                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                # flow matching loss
                target = noise - model_input

                # Compute regular loss.
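                # The transformer is trained to predict the flow-matching velocity
                # (noise - x_0). The squared error below is scaled by the SD3-style
                # per-sigma weighting, averaged over all latent dimensions of each sample,
                # and then averaged over the batch.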
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if args.train_base_model:
                        params_to_clip = transformer.parameters()
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
                        checkpoints = os.listdir(args.output_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= args.checkpoints_total_limit:
                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if accelerator.sync_gradients:
                if global_step % args.validation_steps == 1:
                    pipeline = FluxFillPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        transformer=accelerator.unwrap_model(transformer),
                        torch_dtype=weight_dtype,
                        vae=vae,
                        tokenizer=tokenizer_one,
                        tokenizer_2=tokenizer_two,
                        text_encoder=text_encoder_one,
                        text_encoder_2=text_encoder_two,
                    )

                    log_validation(
                        pipeline=pipeline,
                        args=args,
                        accelerator=accelerator,
                        dataloader=train_verification_dataloader,
                        tag="train verification",
                        epoch=epoch,
                    )

                    log_validation(
                        pipeline=pipeline,
                        args=args,
                        accelerator=accelerator,
                        dataloader=validation_dataloader,
                        tag="validation",
                        epoch=epoch,
                    )

            if global_step >= args.max_train_steps:
                break

    # Save the trained transformer as a full FluxFill pipeline (this script fine-tunes the base model, not LoRA layers)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)

        pipeline = FluxFillPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer)

        # save the pipeline
        pipeline.save_pretrained(args.output_dir)

        # Final inference
        # Load previous pipeline
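        # Reloading from args.output_dir confirms that the just-saved weights can be
        # loaded back for inference before training is wrapped up.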
        pipeline = FluxFillPipeline.from_pretrained(
            args.output_dir,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------