├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── app_pixart_dmd.py └── app_pixart_sigma.py ├── asset ├── PixArt.svg ├── docs │ ├── convert_image2json.md │ ├── data_feature_extraction.md │ ├── pixart.md │ ├── pixart_dmd.md │ └── pixart_lora.md ├── examples.py ├── imgs │ ├── dmd.png │ ├── lora_512.png │ └── noise_snr.png ├── logo-sigma.png ├── logo.png └── samples.txt ├── configs ├── PixArt_xl2_internal.py ├── pixart_alpha_config │ ├── PixArt_xl2_img1024_dreambooth.py │ ├── PixArt_xl2_img1024_internal.py │ ├── PixArt_xl2_img1024_internalms.py │ ├── PixArt_xl2_img256_internal.py │ ├── PixArt_xl2_img512_internal.py │ └── PixArt_xl2_img512_internalms.py ├── pixart_app_config │ └── PixArt-DMD_xl2_img512_internalms.py └── pixart_sigma_config │ ├── PixArt_sigma_xl2_img1024_internalms.py │ ├── PixArt_sigma_xl2_img1024_internalms_kvcompress.py │ ├── PixArt_sigma_xl2_img1024_lcm.py │ ├── PixArt_sigma_xl2_img256_internal.py │ ├── PixArt_sigma_xl2_img2K_internalms_kvcompress.py │ └── PixArt_sigma_xl2_img512_internalms.py ├── diffusion ├── __init__.py ├── data │ ├── __init__.py │ ├── builder.py │ ├── datasets │ │ ├── InternalData.py │ │ ├── InternalData_ms.py │ │ ├── __init__.py │ │ ├── dmd.py │ │ └── utils.py │ └── transforms.py ├── dpm_solver.py ├── iddpm.py ├── lcm_scheduler.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── diffusion_utils.py │ ├── dpm_solver.py │ ├── edm_sample.py │ ├── gaussian_diffusion.py │ ├── llava │ │ ├── __init__.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── nets │ │ ├── PixArt.py │ │ ├── PixArtMS.py │ │ ├── PixArt_blocks.py │ │ └── __init__.py │ ├── respace.py │ ├── sa_solver.py │ ├── t5.py │ ├── timestep_sampler.py │ └── utils.py ├── sa_sampler.py ├── sa_solver_diffusers.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── data_sampler.py │ ├── dist_utils.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── misc.py │ └── optimizer.py ├── environment.yml ├── notebooks ├── PixArt_xl2_img512_internal_for_pokemon_sample_training.py ├── convert-checkpoint-to-diffusers.ipynb ├── infer.ipynb └── train.ipynb ├── requirements.txt ├── scripts ├── DMD │ └── transformer_train │ │ ├── args.py │ │ ├── attention_processor.py │ │ ├── generate.py │ │ └── utils.py ├── diffusers_patches.py ├── inference.py ├── inference_pipeline.py ├── interface.py ├── run_pixart_dmd.py └── style.css ├── tools ├── convert_diffusers_to_pipeline.py ├── convert_diffusers_to_pixart.py ├── convert_images_to_json.py ├── convert_pixart_to_diffusers.py ├── download.py ├── extract_features.py ├── generate_dmd_data_noise_pairs.py └── merge_transformers.py └── train_scripts ├── train.py ├── train_dmd.sh ├── train_dreambooth_lora.py ├── train_pixart_dmd.py ├── train_pixart_lcm.py ├── train_pixart_lora.sh └── train_pixart_lora_hf.py /.dockerignore: -------------------------------------------------------------------------------- 1 | docker -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug 139 | .idea/ 140 | cloud_tools/ 141 | output/ 142 | output_cv/ 143 | 144 | # added by ylw 145 | *.pt 146 | *.pth 147 | *mj* 148 | s3helper/ 149 | TODO.md 150 | pretrained_models 151 | work_dir 152 | #demo.py 153 | develop/ 154 | tmp.py 155 | output_cv/ 156 | output_all/ 157 | output_demo/ 158 | output_debug/ 159 | output/ 160 | 161 | #cache for docker 162 | docker/cache/gradio 163 | docker/cache/huggingface -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # This is a sample Dockefile that builds a runtime container and runs the sample Gradio app. 2 | # Note, you must pass in the pretrained models when you run the container. 3 | 4 | FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 5 | 6 | WORKDIR /workspace 7 | 8 | RUN apt-get update && \ 9 | apt-get install -y \ 10 | git \ 11 | python3 \ 12 | python-is-python3 \ 13 | python3-pip \ 14 | python3.10-venv \ 15 | libgl1 \ 16 | libgl1-mesa-glx \ 17 | libglib2.0-0 \ 18 | && rm -rf /var/lib/apt/lists/* 19 | 20 | ADD requirements.txt . 
21 | 22 | RUN pip install -r requirements.txt 23 | 24 | ADD . . 25 | 26 | RUN chmod a+x docker-entrypoint.sh 27 | 28 | ENV DEMO_PORT=12345 29 | ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ] -------------------------------------------------------------------------------- /asset/PixArt.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 15 | 17 | 26 | 31 | 34 | 38 | 42 | 46 | 50 | 54 | 58 | 62 | 66 | 70 | 74 | 78 | 82 | 86 | 90 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /asset/docs/convert_image2json.md: -------------------------------------------------------------------------------- 1 | ## Tools 2 | 3 | # tools/convert_images_to_json.py 4 | 5 | This script converts a folder containing images and caption files into the dataset folder structure required by the PixArt-Sigma training scripts. 6 | 7 | Before: 8 | - root 9 | - image1.png 10 | - image1.txt 11 | - image2.jpg 12 | - image2.txt 13 | - image3.webp 14 | - image3.txt 15 | 16 | After: 17 | - out 18 | - InternData 19 | - data_info.json 20 | - InternImgs 21 | - image1.png 22 | - image2.jpg 23 | - image3.webp 24 | 25 | The script detects every image with a paired caption file, copies it into the InternImgs folder, and writes its prompt into data_info.json. 26 | 27 | The usage is as follows: 28 | 29 | `python tools/convert_images_to_json.py [params] images_path output_path` 30 | 31 | The caption file extension is `.txt` by default, but it can be changed with an argument such as `--caption_extension .caption`. -------------------------------------------------------------------------------- /asset/docs/data_feature_extraction.md: -------------------------------------------------------------------------------- 1 | # 📕 Data Preparation 2 | 3 | ### 1. Downloading the toy dataset 4 | 5 | Download the [toy dataset](https://huggingface.co/datasets/PixArt-alpha/pixart-sigma-toy-dataset) first. 6 | The dataset structure for training is: 7 | 8 | ``` 9 | cd your_project_path/pixart-sigma-toy-dataset 10 | 11 | Dataset Structure 12 | ├──InternImgs/ (images are saved here) 13 | │ ├──000000000000.png 14 | │ ├──000000000001.png 15 | │ ├──...... 16 | ├──InternData/ 17 | │ ├──data_info.json (meta data) 18 | Optional(👇) 19 | │ ├──img_sdxl_vae_features_512resolution_ms_new (run tools/extract_features.py to generate image VAE features, same name as images except .npy extension) 20 | │ │ ├──000000000000.npy 21 | │ │ ├──000000000001.npy 22 | │ │ ├──...... 23 | │ ├──caption_features_new (run tools/extract_features.py to generate caption T5 features, same name as images except .npz extension) 24 | │ │ ├──000000000000.npz 25 | │ │ ├──000000000001.npz 26 | │ │ ├──...... 27 | │ ├──sharegpt4v_caption_features_new (run tools/extract_features.py to generate caption T5 features, same name as images except .npz extension) 28 | │ │ ├──000000000000.npz 29 | │ │ ├──000000000001.npz 30 | │ │ ├──...... 31 | ``` 32 | ### You are now ready to run the [training code](https://github.com/PixArt-alpha/PixArt-sigma#12-download-pretrained-checkpoint) 33 | 34 | --- 35 | ## Optional(👇) 36 | > [!IMPORTANT] 37 | > You don't have to extract the following features to do the training, BUT 38 | > 39 | > if you want to train with **faster speed** and **lower GPU occupancy**, you can pre-process all the VAE & T5 features. 40 | 41 | ### 2.
Extract VAE features 42 | 43 | ```bash 44 | python tools/extract_features.py --run_vae_feature_extract \ 45 | --multi_scale \ 46 | --img_size=512 \ 47 | --dataset_root=pixart-sigma-toy-dataset/InternData \ 48 | --vae_json_file=data_info.json \ 49 | --vae_models_dir=madebyollin/sdxl-vae-fp16-fix \ 50 | --vae_save_root=pixart-sigma-toy-dataset/InternData 51 | ``` 52 | **SDXL-VAE** features will be saved at: `pixart-sigma-toy-dataset/InternData/img_sdxl_vae_features_512resolution_ms_new` 53 | as shown in the [DataTree](#1downloading-the-toy-dataset). 54 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L242) 55 | 56 | ### 3. Extract T5 features (prompt) 57 | 58 | ```bash 59 | python tools/extract_features.py --run_t5_feature_extract \ 60 | --max_length=300 \ 61 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json \ 62 | --t5_models_dir=PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers \ 63 | --caption_label=prompt \ 64 | --t5_save_root=pixart-sigma-toy-dataset/InternData 65 | ``` 66 | **T5 features** will be saved at: `pixart-sigma-toy-dataset/InternData/caption_features_new` 67 | as shown in the [DataTree](#1downloading-the-toy-dataset). 68 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L227) 69 | 70 | --- 71 | > [!TIP] 72 | > Ignore it if you don't have `sharegpt4v` in your data_info.json 73 | 74 | ### 3.1. Extract T5 features (sharegpt4v) 75 | 76 | ```bash 77 | python tools/extract_features.py --run_t5_feature_extract \ 78 | --max_length=300 \ 79 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json \ 80 | --t5_models_dir=PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers \ 81 | --caption_label=sharegpt4v \ 82 | --t5_save_root=pixart-sigma-toy-dataset/InternData 83 | ``` 84 | **T5 features** will be saved at: `pixart-sigma-toy-dataset/InternData/caption_features_new` 85 | as shown in the [DataTree](#1downloading-the-toy-dataset). 86 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L234) 87 | 88 | -------------------------------------------------------------------------------- /asset/docs/pixart.md: -------------------------------------------------------------------------------- 1 | 12 | 13 | [//]: # ((reference from [hugging Face](https://github.com/huggingface/diffusers/blob/docs/8bit-inference-pixart/docs/source/en/api/pipelines/pixart.md))) 14 | 15 | ## Running the `PixArtAlphaPipeline` in under 8GB GPU VRAM 16 | 17 | It is possible to run the [`PixArtAlphaPipeline`] under 8GB GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a full-fledged example. 
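Note that the snippets below use `torch` directly (for `torch.no_grad()`, `torch.float16`, and clearing the CUDA cache) but never import it; if you copy them into a standalone script, start with: ```python import torch ``` so they run as written.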
18 | 19 | First, install the `bitsandbytes` library: 20 | 21 | ```bash 22 | pip install -U bitsandbytes 23 | ``` 24 | 25 | Then load the text encoder in 8-bit: 26 | 27 | ```python 28 | from transformers import T5EncoderModel 29 | from diffusers import PixArtAlphaPipeline 30 | 31 | text_encoder = T5EncoderModel.from_pretrained( 32 | "PixArt-alpha/PixArt-XL-2-1024-MS", 33 | subfolder="text_encoder", 34 | load_in_8bit=True, 35 | device_map="auto", 36 | 37 | ) 38 | pipe = PixArtAlphaPipeline.from_pretrained( 39 | "PixArt-alpha/PixArt-XL-2-1024-MS", 40 | text_encoder=text_encoder, 41 | transformer=None, 42 | device_map="auto" 43 | ) 44 | ``` 45 | 46 | Now, use the `pipe` to encode a prompt: 47 | 48 | ```python 49 | with torch.no_grad(): 50 | prompt = "cute cat" 51 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt) 52 | 53 | del text_encoder 54 | del pipe 55 | flush() 56 | ``` 57 | 58 | `flush()` is just a utility function to clear the GPU VRAM and is implemented like so: 59 | 60 | ```python 61 | import gc 62 | 63 | def flush(): 64 | gc.collect() 65 | torch.cuda.empty_cache() 66 | ``` 67 | 68 | Then compute the latents providing the prompt embeddings as inputs: 69 | 70 | ```python 71 | pipe = PixArtAlphaPipeline.from_pretrained( 72 | "PixArt-alpha/PixArt-XL-2-1024-MS", 73 | text_encoder=None, 74 | torch_dtype=torch.float16, 75 | ).to("cuda") 76 | 77 | latents = pipe( 78 | negative_prompt=None, 79 | prompt_embeds=prompt_embeds, 80 | negative_prompt_embeds=negative_embeds, 81 | prompt_attention_mask=prompt_attention_mask, 82 | negative_prompt_attention_mask=negative_prompt_attention_mask, 83 | num_images_per_prompt=1, 84 | output_type="latent", 85 | ).images 86 | 87 | del pipe.transformer 88 | flush() 89 | ``` 90 | 91 | Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded. 92 | 93 | Once the latents are computed, pass it off the VAE to decode into a real image: 94 | 95 | ```python 96 | with torch.no_grad(): 97 | image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] 98 | image = pipe.image_processor.postprocess(image, output_type="pil") 99 | image.save("cat.png") 100 | ``` 101 | 102 | All of this, put together, should allow you to run [`PixArtAlphaPipeline`] under 8GB GPU VRAM. 103 | 104 | ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png) 105 | 106 | Find the script [here](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e) that can be run end-to-end to report the memory being used. 107 | 108 | 109 | 110 | Text embeddings computed in 8-bit can have an impact on the quality of the generated images because of the information loss in the representation space induced by the reduced precision. It's recommended to compare the outputs with and without 8-bit. 111 | 112 | -------------------------------------------------------------------------------- /asset/docs/pixart_dmd.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | 3 | **We combine [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) and [DMD](https://arxiv.org/abs/2311.18828) 4 | to achieve one step image generation. 
This document will guide you through training and testing.** 5 | 6 | ![compare samples](../imgs/noise_snr.png) 7 | 8 | > [!IMPORTANT] 9 | > Due to the differences between DiT & Stable Diffusion, 10 | > 11 | > we find that the setting of the `start timestep` of the student model is very important for DMD training. 12 | 13 | Refer to the Supplementary Material of the PixArt-Sigma paper for more details: https://arxiv.org/abs/2403.04692 14 | 15 | --- 16 | ## How to Train 17 | 18 | ### 1. Environment 19 | Refer to the PixArt-Sigma Environment [Here](https://github.com/PixArt-alpha/PixArt-sigma/tree/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0?tab=readme-ov-file#-dependencies-and-installation). 20 | To use FSDP, make sure you have: 21 | - [PyTorch >= 2.0.1+cu11.7](https://pytorch.org/) 22 | 23 | ### 2. Data preparation (Image-Noise pairs) 24 | 25 | #### Generate image-noise pairs 26 | ```bash 27 | python tools/generate_dmd_data_noise_pairs.py \ 28 | --pipeline_load_from=PixArt-alpha/PixArt-XL-2-512x512 \ 29 | --model_path=PixArt-alpha/PixArt-XL-2-512x512 \ 30 | --save_img # (optional) 31 | ``` 32 | 33 | #### Extract features in advance (skip if you already have them) 34 | ```bash 35 | python tools/extract_features.py --run_t5_feature_extract \ 36 | --t5_models_dir=PixArt-alpha/PixArt-XL-2-512x512 \ 37 | --t5_save_root=pixart-sigma-toy-dataset/InternData \ 38 | --caption_label=prompt \ 39 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json 40 | ``` 41 | 42 | ### 3. 🔥 Run 43 | ```bash 44 | bash train_scripts/train_dmd.sh 45 | ``` 46 | 47 | --- 48 | ## How to Test 49 | ### PixArt-DMD Demo 50 | ```bash 51 | pip install git+https://github.com/huggingface/diffusers 52 | 53 | # PixArt-Sigma one-step sampler (DMD) 54 | DEMO_PORT=12345 python app/app_pixart_dmd.py 55 | ``` 56 | Then try a simple example at `http://your-server-ip:12345`. 57 | 58 | --- 59 | ## Samples 60 | ![compare samples](../imgs/dmd.png) 61 | 62 | -------------------------------------------------------------------------------- /asset/docs/pixart_lora.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | 3 | **We adapt the LoRA training code from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) 4 | to achieve Transformer-LoRA fine-tuning. This document will guide you through training and testing.** 5 | 6 | > [!IMPORTANT] 7 | > Due to the current implementation of `diffusers` and `transformers`, 8 | > LoRA training for `transformers` can only be done in FP32. 9 | > 10 | > We welcome any help in solving this issue.
11 | 12 | ## How to Train 13 | ### 🔥 Run 14 | ```bahs 15 | bash train_scripts/train_pixart_lora.sh 16 | ``` 17 | 18 | Details👇: 19 | 20 | ```bash 21 | pip install -U peft 22 | 23 | dataset_id=svjack/pokemon-blip-captions-en-zh 24 | model_id=PixArt-alpha/PixArt-XL-2-512x512 25 | 26 | accelerate launch --num_processes=1 --main_process_port=36667 train_scripts/train_pixart_lora_hf.py \ 27 | --mixed_precision="fp16" \ 28 | --pretrained_model_name_or_path=$model_id \ 29 | --dataset_name=$dataset_id \ 30 | --caption_column="text" \ 31 | --resolution=512 \ 32 | --random_flip \ 33 | --train_batch_size=16 \ 34 | --num_train_epochs=80 \ 35 | --checkpointing_steps=1000 \ 36 | --learning_rate=1e-05 \ 37 | --lr_scheduler="constant" \ 38 | --lr_warmup_steps=0 \ 39 | --seed=42 \ 40 | --output_dir="output/pixart-pokemon-model" \ 41 | --validation_prompt="cute dragon creature" \ 42 | --report_to="tensorboard" \ 43 | --gradient_checkpointing \ 44 | --checkpoints_total_limit=10 \ 45 | --validation_epochs=5 \ 46 | --max_token_length=120 \ # chang to 300 for Sigma 47 | --rank=16 48 | ``` 49 | 50 | ## How to Test 51 | 52 | ```python 53 | import torch 54 | from diffusers import PixArtAlphaPipeline, Transformer2DModel 55 | from peft import PeftModel 56 | 57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | 59 | # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-1024-MS" too. 60 | MODEL_ID = "PixArt-alpha/PixArt-XL-2-512x512" 61 | 62 | # LoRA model 63 | transformer = Transformer2DModel.from_pretrained(MODEL_ID, subfolder="transformer", torch_dtype=torch.float16) 64 | transformer = PeftModel.from_pretrained(transformer, "Your-LoRA-Model-Path") 65 | 66 | # Pipeline 67 | pipe = PixArtAlphaPipeline.from_pretrained(MODEL_ID, transformer=transformer, torch_dtype=torch.float16) 68 | del transformer 69 | 70 | pipe.to(device) 71 | 72 | prompt = "a drawing of a green pokemon with red eyes" 73 | image = pipe(prompt).images[0] 74 | image.save("./pokemon_with_red_eyes.png") 75 | ``` 76 | 77 | --- 78 | ## Samples 79 | ![compare samples](../imgs/lora_512.png) -------------------------------------------------------------------------------- /asset/examples.py: -------------------------------------------------------------------------------- 1 | 2 | examples = [ 3 | [ 4 | "A small cactus with a happy face in the Sahara desert.", 5 | "dpm-solver", 20, 4.5, 6 | ], 7 | [ 8 | "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history" 9 | "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits " 10 | "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret " 11 | "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile " 12 | "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and " 13 | "the Parisian streets and city in the background, depth of field, cinematic 35mm film.", 14 | "dpm-solver", 20, 4.5, 15 | ], 16 | [ 17 | "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. " 18 | "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. " 19 | "The quote 'Find the universe within you' is etched in bold letters across the horizon." 
20 | "blue and pink, brilliantly illuminated in the background.", 21 | "dpm-solver", 20, 4.5, 22 | ], 23 | [ 24 | "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.", 25 | "dpm-solver", 20, 4.5, 26 | ], 27 | [ 28 | "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.", 29 | "dpm-solver", 20, 4.5, 30 | ], 31 | [ 32 | "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, " 33 | "national geographic photo, 8k resolution, crayon art, interactive artwork", 34 | "dpm-solver", 20, 4.5, 35 | ] 36 | ] 37 | -------------------------------------------------------------------------------- /asset/imgs/dmd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/dmd.png -------------------------------------------------------------------------------- /asset/imgs/lora_512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/lora_512.png -------------------------------------------------------------------------------- /asset/imgs/noise_snr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/noise_snr.png -------------------------------------------------------------------------------- /asset/logo-sigma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/logo-sigma.png -------------------------------------------------------------------------------- /asset/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/logo.png -------------------------------------------------------------------------------- /configs/PixArt_xl2_internal.py: -------------------------------------------------------------------------------- 1 | data_root = '/data/data' 2 | data = dict(type='InternalData', root='images', image_list_json=['data_info.json'], transform='default_train', load_vae_feat=True, load_t5_feat=True) 3 | image_size = 256 # the generated image resolution 4 | train_batch_size = 32 5 | eval_batch_size = 16 6 | use_fsdp=False # if use FSDP mode 7 | valid_num=0 # take as valid aspect-ratio when sample number >= valid_num 8 | fp32_attention = True 9 | # model setting 10 | model = 'PixArt_XL_2' 11 | aspect_ratio_type = None # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 12 | multi_scale = False # if use multiscale dataset model training 13 | pe_interpolation = 1.0 # positional embedding interpolation 14 | # qk norm 15 | qk_norm = False 16 | # kv token compression 17 | kv_compress = False 18 | kv_compress_config = { 19 | 'sampling': None, 20 | 'scale_factor': 1, 21 | 'kv_compress_layer': [], 22 | } 23 | 24 | # training setting 25 | num_workers=4 26 | train_sampling_steps = 1000 27 | visualize=False 28 | # Keep the same seed during validation 29 | deterministic_validation = False 30 | eval_sampling_steps = 250 31 | model_max_length = 120 32 | lora_rank = 4 33 | num_epochs = 
80 34 | gradient_accumulation_steps = 1 35 | grad_checkpointing = False 36 | gradient_clip = 1.0 37 | gc_step = 1 38 | auto_lr = dict(rule='sqrt') 39 | validation_prompts = [ 40 | "dog", 41 | "portrait photo of a girl, photograph, highly detailed face, depth of field", 42 | "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", 43 | "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 44 | "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", 45 | ] 46 | 47 | # we use different weight decay with the official implementation since it results better result 48 | optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10) 49 | lr_schedule = 'constant' 50 | lr_schedule_args = dict(num_warmup_steps=500) 51 | 52 | save_image_epochs = 1 53 | save_model_epochs = 1 54 | save_model_steps=1000000 55 | 56 | sample_posterior = True 57 | mixed_precision = 'fp16' 58 | scale_factor = 0.18215 # ldm vae: 0.18215; sdxl vae: 0.13025 59 | ema_rate = 0.9999 60 | tensorboard_mox_interval = 50 61 | log_interval = 50 62 | cfg_scale = 4 63 | mask_type='null' 64 | num_group_tokens=0 65 | mask_loss_coef=0. 66 | load_mask_index=False # load prepared mask_type index 67 | # load model settings 68 | vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema" 69 | load_from = None 70 | resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True) 71 | snr_loss=False 72 | real_prompt_ratio = 1.0 73 | # classifier free guidance 74 | class_dropout_prob = 0.1 75 | # work dir settings 76 | work_dir = '/cache/exps/' 77 | s3_work_dir = None 78 | micro_condition = False 79 | seed = 43 80 | skip_step=0 81 | 82 | # LCM 83 | loss_type = 'huber' 84 | huber_c = 0.001 85 | num_ddim_timesteps=50 86 | w_max = 15.0 87 | w_min = 3.0 88 | ema_decay = 0.95 89 | 90 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data/dreambooth/dataset' 3 | 4 | data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True) 5 | image_size = 1024 6 | 7 | # model setting 8 | model = 'PixArtMS_XL_2' # model for multi-scale training 9 | fp32_attention = True 10 | load_from = 'Path/to/PixArt-XL-2-1024-MS.pth' 11 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 12 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 13 | multi_scale = True # if use multiscale dataset model training 14 | pe_interpolation = 2.0 15 | 16 | # training setting 17 | num_workers=1 18 | train_batch_size = 1 19 | num_epochs = 200 20 | gradient_accumulation_steps = 1 21 | grad_checkpointing = True 22 | gradient_clip = 0.01 23 | optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10) 24 | lr_schedule_args = dict(num_warmup_steps=0) 25 | auto_lr = None 26 | 27 | log_interval = 1 28 | save_model_epochs=10000 29 | save_model_steps=100 30 | work_dir = 'output/debug' 31 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json',] 4 | 5 | data = 
dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 6 | image_size = 1024 7 | 8 | # model setting 9 | model = 'PixArt_XL_2' 10 | fp32_attention = True 11 | load_from = None 12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 13 | pe_interpolation = 2.0 14 | 15 | # training setting 16 | num_workers=10 17 | train_batch_size = 2 # 32 18 | num_epochs = 200 # 3 19 | gradient_accumulation_steps = 1 20 | grad_checkpointing = True 21 | gradient_clip = 0.01 22 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 23 | lr_schedule_args = dict(num_warmup_steps=1000) 24 | 25 | eval_sampling_steps = 200 26 | log_interval = 20 27 | save_model_epochs=1 28 | save_model_steps=2000 29 | work_dir = 'output/debug' 30 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json',] 4 | 5 | data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 6 | image_size = 1024 7 | 8 | # model setting 9 | model = 'PixArtMS_XL_2' # model for multi-scale training 10 | fp32_attention = True 11 | load_from = None 12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 13 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 14 | multi_scale = True # if use multiscale dataset model training 15 | pe_interpolation = 2.0 16 | 17 | # training setting 18 | num_workers=10 19 | train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint 20 | num_epochs = 10 # 3 21 | gradient_accumulation_steps = 1 22 | grad_checkpointing = True 23 | gradient_clip = 0.01 24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 25 | lr_schedule_args = dict(num_warmup_steps=1000) 26 | save_model_epochs=1 27 | save_model_steps=2000 28 | 29 | log_interval = 20 30 | eval_sampling_steps = 200 31 | work_dir = 'output/debug' 32 | micro_condition = True 33 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img256_internal.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json',] 4 | 5 | data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 6 | image_size = 256 7 | 8 | # model setting 9 | model = 'PixArt_XL_2' 10 | fp32_attention = True 11 | load_from = None 12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 13 | # training setting 14 | eval_sampling_steps = 200 15 | 16 | num_workers=10 17 | train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint 18 | num_epochs = 200 # 3 19 | gradient_accumulation_steps = 1 20 | grad_checkpointing = True 21 | gradient_clip = 0.01 22 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 23 | lr_schedule_args = dict(num_warmup_steps=1000) 24 | 25 | log_interval = 20 26 | save_model_epochs=5 27 | work_dir = 'output/debug' 28 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img512_internal.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json',] 4 | 5 | data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 6 | image_size = 512 7 | 8 | # model setting 9 | model = 'PixArt_XL_2' 10 | fp32_attention = True 11 | load_from = None 12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 13 | pe_interpolation = 1.0 14 | 15 | # training setting 16 | use_fsdp=False # if use FSDP mode 17 | num_workers=10 18 | train_batch_size = 38 # 32 19 | num_epochs = 200 # 3 20 | gradient_accumulation_steps = 1 21 | grad_checkpointing = True 22 | gradient_clip = 0.01 23 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 24 | lr_schedule_args = dict(num_warmup_steps=1000) 25 | 26 | eval_sampling_steps = 200 27 | log_interval = 20 28 | save_model_epochs=1 29 | work_dir = 'output/debug' 30 | -------------------------------------------------------------------------------- /configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json',] 4 | 5 | data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 6 | image_size = 512 7 | 8 | # model setting 9 | model = 'PixArtMS_XL_2' # model for multi-scale training 10 | fp32_attention = True 11 | load_from = None 12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 13 | aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 14 | multi_scale = True # if use multiscale dataset model training 15 | pe_interpolation = 1.0 16 | 17 | # training setting 18 | num_workers=10 19 | train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint 20 | num_epochs = 20 # 3 21 | gradient_accumulation_steps = 1 22 | grad_checkpointing = True 23 | gradient_clip = 0.01 24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 25 | lr_schedule_args = dict(num_warmup_steps=1000) 26 | save_model_epochs=1 27 | save_model_steps=2000 28 | 29 | log_interval = 20 30 | eval_sampling_steps = 200 31 | work_dir = 'output/debug' 32 | -------------------------------------------------------------------------------- /configs/pixart_app_config/PixArt-DMD_xl2_img512_internalms.py: -------------------------------------------------------------------------------- 1 | # Config for PixArt-DMD 2 | _base_ = ['../PixArt_xl2_internal.py'] 3 | data_root = 'pixart-sigma-toy-dataset' 4 | 5 | image_list_json = ['data_info.json'] 6 | 7 | data = dict( 8 | type='DMD', root='InternData', image_list_json=image_list_json, transform='default_train', 9 | load_vae_feat=True, load_t5_feat=True 10 | ) 11 | image_size = 512 12 | 13 | # model setting 14 | model = 'PixArtMS_XL_2' # model for multi-scale training 15 | fp32_attention = True 16 | load_from = "PixArt-alpha/PixArt-XL-2-512x512" 17 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 18 | tiny_vae_pretrained = "output/pretrained_models/tinyvae" 19 | aspect_ratio_type = 'ASPECT_RATIO_512' 20 | multi_scale = True # if use multiscale dataset model training 21 | pe_interpolation = 1.0 22 | 23 | # training setting 24 | num_workers = 10 25 | train_batch_size = 1 # max 40 for PixArt-xL/2 when 
grad_checkpoint 26 | num_epochs = 10 # 3 27 | gradient_accumulation_steps = 1 28 | grad_checkpointing = True 29 | gradient_clip = 0.01 30 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 31 | lr_schedule_args = dict(num_warmup_steps=1000) 32 | 33 | log_interval = 20 34 | save_model_epochs=1 35 | save_model_steps=2000 36 | work_dir = 'output/debug' 37 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'pixart-sigma-toy-dataset' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=False, load_t5_feat=False 8 | ) 9 | image_size = 1024 10 | 11 | # model setting 12 | model = 'PixArtMS_XL_2' 13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] 14 | fp32_attention = True 15 | load_from = None 16 | resume_from = None 17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 18 | aspect_ratio_type = 'ASPECT_RATIO_1024' 19 | multi_scale = True # if use multiscale dataset model training 20 | pe_interpolation = 2.0 21 | 22 | # training setting 23 | num_workers = 10 24 | train_batch_size = 2 # 3 for w.o feature extraction; 12 for feature extraction 25 | num_epochs = 2 # 3 26 | gradient_accumulation_steps = 1 27 | grad_checkpointing = True 28 | gradient_clip = 0.01 29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 30 | lr_schedule_args = dict(num_warmup_steps=1000) 31 | 32 | eval_sampling_steps = 500 33 | visualize = True 34 | log_interval = 20 35 | save_model_epochs = 1 36 | save_model_steps = 1000 37 | work_dir = 'output/debug' 38 | 39 | # pixart-sigma 40 | scale_factor = 0.13025 41 | real_prompt_ratio = 0.5 42 | model_max_length = 300 43 | class_dropout_prob = 0.1 44 | 45 | qk_norm = False 46 | skip_step = 0 # skip steps during data loading 47 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=False, load_t5_feat=False 8 | ) 9 | image_size = 1024 10 | 11 | # model setting 12 | model = 'PixArtMS_XL_2' 13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] 14 | fp32_attention = True 15 | load_from = None 16 | resume_from = None 17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 18 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 19 | multi_scale = True # if use multiscale dataset model training 20 | pe_interpolation = 2.0 21 | 22 | # training setting 23 | num_workers = 10 24 | train_batch_size = 4 # 16 25 | num_epochs = 2 # 3 26 | gradient_accumulation_steps = 1 27 | grad_checkpointing = True 28 | gradient_clip = 0.01 29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), 
eps=(1e-30, 1e-16)) 30 | lr_schedule_args = dict(num_warmup_steps=500) 31 | 32 | eval_sampling_steps = 250 33 | visualize = True 34 | log_interval = 10 35 | save_model_epochs = 1 36 | save_model_steps = 1000 37 | work_dir = 'output/debug' 38 | 39 | # pixart-sigma 40 | scale_factor = 0.13025 41 | real_prompt_ratio = 0.5 42 | model_max_length = 300 43 | class_dropout_prob = 0.1 44 | kv_compress = True 45 | kv_compress_config = { 46 | 'sampling': 'conv', # ['conv', 'uniform', 'ave'] 47 | 'scale_factor': 2, 48 | 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], 49 | } 50 | qk_norm = False 51 | skip_step = 0 # skip steps during data loading 52 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'pixart-sigma-toy-dataset' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=True, load_t5_feat=True, 8 | ) 9 | image_size = 1024 10 | 11 | # model setting 12 | model = 'PixArtMS_XL_2' # model for multi-scale training 13 | fp32_attention = False 14 | load_from = None 15 | resume_from = None 16 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 17 | aspect_ratio_type = 'ASPECT_RATIO_1024' 18 | multi_scale = True # if use multiscale dataset model training 19 | pe_interpolation = 2.0 20 | 21 | # training setting 22 | num_workers = 4 23 | train_batch_size = 12 # max 12 for PixArt-xL/2 when grad_checkpoint 24 | num_epochs = 10 # 3 25 | gradient_accumulation_steps = 1 26 | grad_checkpointing = True 27 | gradient_clip = 0.01 28 | optimizer = dict(type='CAMEWrapper', lr=1e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 29 | lr_schedule_args = dict(num_warmup_steps=100) 30 | save_model_epochs = 10 31 | save_model_steps = 1000 32 | valid_num = 0 # take as valid aspect-ratio when sample number >= valid_num 33 | 34 | log_interval = 10 35 | eval_sampling_steps = 5 36 | visualize = True 37 | work_dir = 'output/debug' 38 | 39 | # pixart-sigma 40 | scale_factor = 0.13025 41 | real_prompt_ratio = 0.5 42 | model_max_length = 300 43 | class_dropout_prob = 0.1 44 | 45 | # LCM 46 | loss_type = 'huber' 47 | huber_c = 0.001 48 | num_ddim_timesteps = 50 49 | w_max = 15.0 50 | w_min = 3.0 51 | ema_decay = 0.95 52 | cfg_scale = 4.5 53 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'pixart-sigma-toy-dataset' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=False, load_t5_feat=False, 8 | ) 9 | image_size = 256 10 | 11 | # model setting 12 | model = 'PixArt_XL_2' 13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] 14 | fp32_attention = True 15 | load_from = "output/pretrained_models/PixArt-Sigma-XL-2-256x256.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma 16 | resume_from = None 17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 18 | 
multi_scale = False # if use multiscale dataset model training 19 | pe_interpolation = 0.5 20 | 21 | # training setting 22 | num_workers = 10 23 | train_batch_size = 64 # 64 as default 24 | num_epochs = 200 # 3 25 | gradient_accumulation_steps = 1 26 | grad_checkpointing = True 27 | gradient_clip = 0.01 28 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 29 | lr_schedule_args = dict(num_warmup_steps=1000) 30 | 31 | eval_sampling_steps = 500 32 | log_interval = 20 33 | save_model_epochs = 5 34 | save_model_steps = 2500 35 | work_dir = 'output/debug' 36 | 37 | # pixart-sigma 38 | scale_factor = 0.13025 39 | real_prompt_ratio = 0.5 40 | model_max_length = 300 41 | class_dropout_prob = 0.1 42 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'data' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=False, load_t5_feat=False 8 | ) 9 | image_size = 2048 10 | 11 | # model setting 12 | model = 'PixArtMS_XL_2' 13 | mixed_precision = 'fp16' 14 | fp32_attention = True 15 | load_from = None 16 | resume_from = None 17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 18 | aspect_ratio_type = 'ASPECT_RATIO_2048' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] 19 | multi_scale = True # if use multiscale dataset model training 20 | pe_interpolation = 4.0 21 | 22 | # training setting 23 | num_workers = 10 24 | train_batch_size = 4 # 48 25 | num_epochs = 10 # 3 26 | gradient_accumulation_steps = 1 27 | grad_checkpointing = True 28 | gradient_clip = 0.01 29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 30 | lr_schedule_args = dict(num_warmup_steps=100) 31 | 32 | eval_sampling_steps = 100 33 | visualize = True 34 | log_interval = 10 35 | save_model_epochs = 10 36 | save_model_steps = 100 37 | work_dir = 'output/debug' 38 | 39 | # pixart-sigma 40 | scale_factor = 0.13025 41 | real_prompt_ratio = 0.5 42 | model_max_length = 300 43 | class_dropout_prob = 0.1 44 | kv_compress = False 45 | kv_compress_config = { 46 | 'sampling': 'conv', 47 | 'scale_factor': 2, 48 | 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], 49 | } 50 | -------------------------------------------------------------------------------- /configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../PixArt_xl2_internal.py'] 2 | data_root = 'pixart-sigma-toy-dataset' 3 | image_list_json = ['data_info.json'] 4 | 5 | data = dict( 6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', 7 | load_vae_feat=False, load_t5_feat=False, 8 | ) 9 | image_size = 512 10 | 11 | # model setting 12 | model = 'PixArtMS_XL_2' 13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] 14 | fp32_attention = True 15 | load_from = "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma 16 | resume_from = None 17 | vae_pretrained = 
"output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae 18 | aspect_ratio_type = 'ASPECT_RATIO_512' 19 | multi_scale = True # if use multiscale dataset model training 20 | pe_interpolation = 1.0 21 | 22 | # training setting 23 | num_workers = 10 24 | train_batch_size = 2 # 48 as default 25 | num_epochs = 10 # 3 26 | gradient_accumulation_steps = 1 27 | grad_checkpointing = True 28 | gradient_clip = 0.01 29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) 30 | lr_schedule_args = dict(num_warmup_steps=1000) 31 | 32 | eval_sampling_steps = 500 33 | visualize = True 34 | log_interval = 20 35 | save_model_epochs = 5 36 | save_model_steps = 2500 37 | work_dir = 'output/debug' 38 | 39 | # pixart-sigma 40 | scale_factor = 0.13025 41 | real_prompt_ratio = 0.5 42 | model_max_length = 300 43 | class_dropout_prob = 0.1 44 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from .iddpm import IDDPM 7 | from .dpm_solver import DPMS 8 | from .sa_sampler import SASolverSampler 9 | -------------------------------------------------------------------------------- /diffusion/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .transforms import get_transform 3 | -------------------------------------------------------------------------------- /diffusion/data/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from mmcv import Registry, build_from_cfg 5 | from torch.utils.data import DataLoader 6 | 7 | from diffusion.data.transforms import get_transform 8 | from diffusion.utils.logger import get_root_logger 9 | 10 | DATASETS = Registry('datasets') 11 | 12 | DATA_ROOT = '/cache/data' 13 | 14 | 15 | def set_data_root(data_root): 16 | global DATA_ROOT 17 | DATA_ROOT = data_root 18 | 19 | 20 | def get_data_path(data_dir): 21 | if os.path.isabs(data_dir): 22 | return data_dir 23 | global DATA_ROOT 24 | return os.path.join(DATA_ROOT, data_dir) 25 | 26 | 27 | def get_data_root_and_path(data_dir): 28 | if os.path.isabs(data_dir): 29 | return data_dir 30 | global DATA_ROOT 31 | return DATA_ROOT, os.path.join(DATA_ROOT, data_dir) 32 | 33 | 34 | def build_dataset(cfg, resolution=224, **kwargs): 35 | logger = get_root_logger() 36 | 37 | dataset_type = cfg.get('type') 38 | logger.info(f"Constructing dataset {dataset_type}...") 39 | t = time.time() 40 | transform = cfg.pop('transform', 'default_train') 41 | transform = get_transform(transform, resolution) 42 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs)) 43 | logger.info(f"Dataset {dataset_type} constructed. 
time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}") 44 | return dataset 45 | 46 | 47 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): 48 | if 'batch_sampler' in kwargs: 49 | dataloader = DataLoader(dataset, batch_sampler=kwargs['batch_sampler'], num_workers=num_workers, pin_memory=True) 50 | else: 51 | dataloader = DataLoader(dataset, 52 | batch_size=batch_size, 53 | shuffle=shuffle, 54 | num_workers=num_workers, 55 | pin_memory=True, 56 | **kwargs) 57 | return dataloader 58 | -------------------------------------------------------------------------------- /diffusion/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .InternalData import InternalData, InternalDataSigma 2 | from .InternalData_ms import InternalDataMS, InternalDataSigma 3 | from .dmd import DMD 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /diffusion/data/datasets/dmd.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import numpy as np 4 | import glob, torch 5 | from diffusion.data.builder import get_data_root_and_path 6 | from PIL import Image 7 | import json 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from torchvision import transforms as T 10 | from torchvision.datasets.folder import default_loader 11 | 12 | 13 | def read_prompt(prompt_path): 14 | with open(prompt_path, 'r') as f: 15 | prompt_list = f.readlines() 16 | return prompt_list 17 | 18 | 19 | # Dataloader for pixart-dmd model 20 | class DMD(Dataset): 21 | ## rewrite the dataloader to avoid data loading bugs 22 | def __init__(self, 23 | root, 24 | transform=None, 25 | image_list_json='data_info.json', 26 | resolution=512, 27 | load_vae_feat=False, 28 | load_t5_feat=False, 29 | max_samples=None, 30 | max_length=120, 31 | ): 32 | ''' 33 | :param root: the root of saving txt features ./data/data/ 34 | :param latent_root: the root of saving latent image pairs 35 | :param image_list_json: 36 | :param resolution: 37 | :param max_samples: 38 | :param offset: 39 | :param kwargs: 40 | ''' 41 | super().__init__() 42 | DATA_ROOT, root = get_data_root_and_path(root) 43 | self.root = root 44 | self.transform = transform 45 | self.load_vae_feat = load_vae_feat 46 | self.load_t5_feat = load_t5_feat 47 | self.DATA_ROOT = DATA_ROOT 48 | self.ori_imgs_nums = 0 49 | self.resolution = resolution 50 | self.max_samples = max_samples 51 | self.max_lenth = max_length 52 | self.meta_data_clean = [] 53 | self.img_samples = [] 54 | self.txt_feat_samples = [] 55 | self.noise_samples = [] 56 | self.vae_feat_samples = [] 57 | self.gen_image_samples = [] 58 | self.txt_samples = [] 59 | 60 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 61 | for json_file in image_list_json: 62 | meta_data = self.load_json(os.path.join(self.root, json_file)) 63 | self.ori_imgs_nums += len(meta_data) 64 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5] 65 | self.meta_data_clean.extend(meta_data_clean) 66 | self.img_samples.extend([ 67 | os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean 68 | ]) 69 | self.gen_image_samples.extend([ 70 | os.path.join(self.root, 'InternImgs_DMD_images', item['path']) for item in meta_data_clean 71 | ]) 72 | self.txt_samples.extend([item['prompt'] 
for item in meta_data_clean]) 73 | self.txt_feat_samples.extend([ 74 | os.path.join( 75 | self.root, 76 | 'caption_features_new', 77 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npz') 78 | ) for item in meta_data_clean 79 | ]) 80 | self.noise_samples.extend([ 81 | os.path.join( 82 | self.root, 83 | 'InternImgs_DMD_noises', 84 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npy') 85 | ) for item in meta_data_clean 86 | ]) 87 | self.vae_feat_samples.extend( 88 | [ 89 | os.path.join( 90 | self.root, 91 | 'InternImgs_DMD_latents', 92 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npy') 93 | ) for item in meta_data_clean 94 | ]) 95 | 96 | # Set loader and extensions 97 | if load_vae_feat: 98 | self.transform = None 99 | self.loader = self.latent_feat_loader 100 | else: 101 | self.loader = default_loader 102 | 103 | def __len__(self): 104 | return min(self.max_samples, len(self.img_samples)) 105 | 106 | @staticmethod 107 | def vae_feat_loader(path): 108 | # [mean, std] 109 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 110 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 111 | return mean + std * sample 112 | 113 | @staticmethod 114 | def latent_feat_loader(path): 115 | return torch.from_numpy(np.load(path)) 116 | 117 | def load_ori_img(self, img_path): 118 | # 加载图像并转换为Tensor 119 | transform = T.Compose([ 120 | T.Resize(512), # Image.BICUBIC 121 | T.CenterCrop(512), 122 | T.ToTensor(), 123 | ]) 124 | img = transform(Image.open(img_path)) 125 | img = img * 2.0 - 1.0 126 | return img 127 | 128 | def load_json(self, file_path): 129 | with open(file_path, 'r') as f: 130 | meta_data = json.load(f) 131 | 132 | return meta_data 133 | 134 | def getdata(self, index): 135 | img_gt_path = self.img_samples[index] 136 | gen_img_path = self.gen_image_samples[index] 137 | npz_path = self.txt_feat_samples[index] 138 | txt = self.txt_samples[index] 139 | npy_path = self.vae_feat_samples[index] 140 | data_info = { 141 | 'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32), 142 | 'aspect_ratio': torch.tensor(1.) 
143 | } 144 | 145 | if self.load_vae_feat: 146 | gen_img = self.loader(npy_path) 147 | else: 148 | gen_img = self.loader(gen_img_path) 149 | 150 | attention_mask = torch.ones(1, 1, self.max_lenth) # 1x1xT 151 | if self.load_t5_feat: 152 | txt_info = np.load(npz_path) 153 | txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096 154 | if 'attention_mask' in txt_info.keys(): 155 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 156 | if txt_fea.shape[1] < self.max_lenth: 157 | txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1) 158 | attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) 159 | elif txt_fea.shape[1] > self.max_lenth: 160 | txt_fea = txt_fea[:, :self.max_lenth] 161 | attention_mask = attention_mask[:, :, :self.max_lenth] 162 | else: 163 | txt_fea = txt 164 | 165 | noise = torch.from_numpy(np.load(self.noise_samples[index])) 166 | img_gt = self.load_ori_img(img_gt_path) 167 | 168 | return {'noise': noise, 169 | 'base_latent': gen_img, 170 | 'latent_path': self.vae_feat_samples[index], 171 | 'text': txt, 172 | 'data_info': data_info, 173 | 'txt_fea': txt_fea, 174 | 'attention_mask': attention_mask, 175 | 'img_gt': img_gt} 176 | 177 | def __getitem__(self, idx): 178 | data = self.getdata(idx) 179 | return data 180 | # for _ in range(20): 181 | # try: 182 | # data = self.getdata(idx) 183 | # return data 184 | # except Exception as e: 185 | # print(f"Error details: {str(e)}") 186 | # idx = np.random.randint(len(self)) 187 | # raise RuntimeError('Too many bad data.') 188 | -------------------------------------------------------------------------------- /diffusion/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | ASPECT_RATIO_2880 = { 3 | '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0], 4 | '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0], 5 | '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0], 6 | '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0], 7 | '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0], 8 | '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0], 9 | '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0], 10 | '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0], 11 | '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0], 12 | '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0] 13 | } 14 | 15 | ASPECT_RATIO_2048 = { 16 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0], 17 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], 18 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], 19 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], 20 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], 21 | '1.0': [2048.0, 
2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], 22 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], 23 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], 24 | '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0], 25 | '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0] 26 | } 27 | 28 | ASPECT_RATIO_1024 = { 29 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], 30 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 31 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 32 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 33 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 34 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 35 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 36 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 37 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], 38 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], 39 | } 40 | 41 | ASPECT_RATIO_512 = { 42 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], 43 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 44 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 45 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 46 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 47 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 48 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 49 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 50 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], 51 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] 52 | } 53 | 54 | ASPECT_RATIO_256 = { 55 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0], 56 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 57 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 58 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 59 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 60 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 61 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 62 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 63 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], 64 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], 
'3.88': [496.0, 128.0], '4.0': [512.0, 128.0] 65 | } 66 | 67 | ASPECT_RATIO_256_TEST = { 68 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0], 69 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 70 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 71 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 72 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 73 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 74 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 75 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 76 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0], 77 | '4.0': [512.0, 128.0] 78 | } 79 | 80 | ASPECT_RATIO_512_TEST = { 81 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0], 82 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 83 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 84 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 85 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 86 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 87 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 88 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 89 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0], 90 | '4.0': [1024.0, 256.0] 91 | } 92 | 93 | ASPECT_RATIO_1024_TEST = { 94 | '0.25': [512., 2048.], '0.28': [512., 1856.], 95 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 96 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 97 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 98 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 99 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 100 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 101 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 102 | '2.5': [1600., 640.], '3.0': [1728., 576.], 103 | '4.0': [2048., 512.], 104 | } 105 | 106 | ASPECT_RATIO_2048_TEST = { 107 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], 108 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], 109 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], 110 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], 111 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], 112 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], 113 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], 114 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], 115 | '2.5': [3200.0, 1280.0], '3.0': 
[3456.0, 1152.0], 116 | '4.0': [4096.0, 1024.0] 117 | } 118 | 119 | ASPECT_RATIO_2880_TEST = { 120 | '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0], 121 | '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0], 122 | '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0], 123 | '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0], 124 | '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0], 125 | '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0], 126 | '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0], 127 | '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0], 128 | '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0], 129 | '4.0': [8192.0, 2048.0], 130 | } 131 | 132 | def get_chunks(lst, n): 133 | for i in range(0, len(lst), n): 134 | yield lst[i:i + n] 135 | -------------------------------------------------------------------------------- /diffusion/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | TRANSFORMS = dict() 4 | 5 | 6 | def register_transform(transform): 7 | name = transform.__name__ 8 | if name in TRANSFORMS: 9 | raise RuntimeError(f'Transform {name} has already registered.') 10 | TRANSFORMS.update({name: transform}) 11 | 12 | 13 | def get_transform(type, resolution): 14 | transform = TRANSFORMS[type](resolution) 15 | transform = T.Compose(transform) 16 | transform.image_size = resolution 17 | return transform 18 | 19 | 20 | @register_transform 21 | def default_train(n_px): 22 | transform = [ 23 | T.Lambda(lambda img: img.convert('RGB')), 24 | T.Resize(n_px), # Image.BICUBIC 25 | T.CenterCrop(n_px), 26 | # T.RandomHorizontalFlip(), 27 | T.ToTensor(), 28 | T.Normalize([.5], [.5]), 29 | ] 30 | return transform 31 | -------------------------------------------------------------------------------- /diffusion/dpm_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .model import gaussian_diffusion as gd 3 | from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP 4 | 5 | 6 | def DPMS( 7 | model, 8 | condition, 9 | uncondition, 10 | cfg_scale, 11 | model_type='noise', # or "x_start" or "v" or "score" 12 | noise_schedule="linear", 13 | guidance_type='classifier-free', 14 | model_kwargs={}, 15 | diffusion_steps=1000 16 | ): 17 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) 18 | 19 | ## 1. Define the noise schedule. 20 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas) 21 | 22 | ## 2. Convert your discrete-time `model` to the continuous-time 23 | ## noise prediction model. Here is an example for a diffusion model 24 | ## `model` with the noise prediction type ("noise") . 25 | model_fn = model_wrapper( 26 | model, 27 | noise_schedule, 28 | model_type=model_type, 29 | model_kwargs=model_kwargs, 30 | guidance_type=guidance_type, 31 | condition=condition, 32 | unconditional_condition=uncondition, 33 | guidance_scale=cfg_scale, 34 | ) 35 | ## 3. Define dpm-solver and sample by multistep DPM-Solver. 
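    # Illustrative downstream usage (a sketch, not part of the original file): the solver
    # returned below is usually driven through its `sample` method, e.g.
    #   dpms = DPMS(model.forward_with_dpmsolver, condition=caption_embs, uncondition=null_y, cfg_scale=4.5)
    #   samples = dpms.sample(z, steps=20, order=2, skip_type="time_uniform", method="multistep")
    # where `z` is Gaussian noise of the target latent shape, and `forward_with_dpmsolver`,
    # `caption_embs` and `null_y` are assumed to be supplied by the PixArt model and the T5 embedder.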
36 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") -------------------------------------------------------------------------------- /diffusion/iddpm.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | from diffusion.model.respace import SpacedDiffusion, space_timesteps 6 | from .model import gaussian_diffusion as gd 7 | 8 | 9 | def IDDPM( 10 | timestep_respacing, 11 | noise_schedule="linear", 12 | use_kl=False, 13 | sigma_small=False, 14 | predict_xstart=False, 15 | learn_sigma=True, 16 | pred_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000, 19 | snr=False, 20 | return_startx=False, 21 | ): 22 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 23 | if use_kl: 24 | loss_type = gd.LossType.RESCALED_KL 25 | elif rescale_learned_sigmas: 26 | loss_type = gd.LossType.RESCALED_MSE 27 | else: 28 | loss_type = gd.LossType.MSE 29 | if timestep_respacing is None or timestep_respacing == "": 30 | timestep_respacing = [diffusion_steps] 31 | return SpacedDiffusion( 32 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 33 | betas=betas, 34 | model_mean_type=( 35 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 36 | ), 37 | model_var_type=( 38 | (( 39 | gd.ModelVarType.FIXED_LARGE 40 | if not sigma_small 41 | else gd.ModelVarType.FIXED_SMALL 42 | ) 43 | if not learn_sigma 44 | else gd.ModelVarType.LEARNED_RANGE 45 | ) 46 | if pred_sigma 47 | else None 48 | ), 49 | loss_type=loss_type, 50 | snr=snr, 51 | return_startx=return_startx, 52 | # rescale_timesteps=rescale_timesteps, 53 | ) -------------------------------------------------------------------------------- /diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .nets import * 2 | -------------------------------------------------------------------------------- /diffusion/model/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv import Registry 2 | 3 | from diffusion.model.utils import set_grad_checkpoint 4 | 5 | MODELS = Registry('models') 6 | 7 | 8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): 9 | if isinstance(cfg, str): 10 | cfg = dict(type=cfg) 11 | model = MODELS.build(cfg, default_args=kwargs) 12 | if use_grad_checkpoint: 13 | set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step) 14 | return model 15 | -------------------------------------------------------------------------------- /diffusion/model/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL 
divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /diffusion/model/edm_sample.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from diffusion.model.utils import * 6 | 7 | 8 | # ---------------------------------------------------------------------------- 9 | # Proposed EDM sampler (Algorithm 2). 10 | 11 | def edm_sampler( 12 | net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like, 13 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 14 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs 15 | ): 16 | # Adjust noise levels based on what's supported by the network. 
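    # The clamps below keep the requested sigma range inside [net.sigma_min, net.sigma_max];
    # the discretization that follows interpolates from sigma_max to sigma_min in
    # sigma**(1/rho) space (the EDM "rho" schedule) and appends a terminal sigma of 0.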
17 | sigma_min = max(sigma_min, net.sigma_min) 18 | sigma_max = min(sigma_max, net.sigma_max) 19 | 20 | # Time step discretization. 21 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 22 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( 23 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 24 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 25 | 26 | # Main sampling loop. 27 | x_next = latents.to(torch.float64) * t_steps[0] 28 | for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1 29 | x_cur = x_next 30 | 31 | # Increase noise temporarily. 32 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 33 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 34 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 35 | 36 | # Euler step. 37 | denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 38 | d_cur = (x_hat - denoised) / t_hat 39 | x_next = x_hat + (t_next - t_hat) * d_cur 40 | 41 | # Apply 2nd order correction. 42 | if i < num_steps - 1: 43 | denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 44 | d_prime = (x_next - denoised) / t_next 45 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 46 | 47 | return x_next 48 | 49 | 50 | # ---------------------------------------------------------------------------- 51 | # Generalized ablation sampler, representing the superset of all sampling 52 | # methods discussed in the paper. 53 | 54 | def ablation_sampler( 55 | net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like, 56 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 57 | solver='heun', discretization='edm', schedule='linear', scaling='none', 58 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, 59 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 60 | ): 61 | assert solver in ['euler', 'heun'] 62 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 63 | assert schedule in ['vp', 've', 'linear'] 64 | assert scaling in ['vp', 'none'] 65 | 66 | # Helper functions for VP & VE noise level schedules. 67 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 68 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 69 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * ( 70 | sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 71 | ve_sigma = lambda t: t.sqrt() 72 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 73 | ve_sigma_inv = lambda sigma: sigma ** 2 74 | 75 | # Select default noise level range based on the specified time step discretization. 76 | if sigma_min is None: 77 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) 78 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 79 | if sigma_max is None: 80 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) 81 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 82 | 83 | # Adjust noise levels based on what's supported by the network. 84 | sigma_min = max(sigma_min, net.sigma_min) 85 | sigma_max = min(sigma_max, net.sigma_max) 86 | 87 | # Compute corresponding betas for VP. 
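    # beta_d and beta_min are recovered by inverting vp_sigma, i.e. solving
    # 0.5*beta_d*t**2 + beta_min*t = log(sigma(t)**2 + 1) with sigma(epsilon_s) = sigma_min
    # and sigma(1) = sigma_max, so the VP schedule reproduces the requested noise range.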
88 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 89 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 90 | 91 | # Define time steps in terms of noise level. 92 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 93 | if discretization == 'vp': 94 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 95 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 96 | elif discretization == 've': 97 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 98 | sigma_steps = ve_sigma(orig_t_steps) 99 | elif discretization == 'iddpm': 100 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 101 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 102 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 103 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 104 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 105 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 106 | else: 107 | assert discretization == 'edm' 108 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( 109 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 110 | 111 | # Define noise level schedule. 112 | if schedule == 'vp': 113 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 114 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 115 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 116 | elif schedule == 've': 117 | sigma = ve_sigma 118 | sigma_deriv = ve_sigma_deriv 119 | sigma_inv = ve_sigma_inv 120 | else: 121 | assert schedule == 'linear' 122 | sigma = lambda t: t 123 | sigma_deriv = lambda t: 1 124 | sigma_inv = lambda sigma: sigma 125 | 126 | # Define scaling schedule. 127 | if scaling == 'vp': 128 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 129 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 130 | else: 131 | assert scaling == 'none' 132 | s = lambda t: 1 133 | s_deriv = lambda t: 0 134 | 135 | # Compute final time steps based on the corresponding noise levels. 136 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 137 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 138 | 139 | # Main sampling loop. 140 | t_next = t_steps[0] 141 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 142 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 143 | x_cur = x_next 144 | 145 | # Increase noise temporarily. 146 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 147 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) 148 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s( 149 | t_hat) * S_noise * randn_like(x_cur) 150 | 151 | # Euler step. 152 | h = t_next - t_hat 153 | denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to( 154 | torch.float64) 155 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s( 156 | t_hat) / sigma(t_hat) * denoised 157 | x_prime = x_hat + alpha * h * d_cur 158 | t_prime = t_hat + alpha * h 159 | 160 | # Apply 2nd order correction. 
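        # With the 'euler' solver, or on the final step, only the first-order update below is
        # used; otherwise the derivative is re-evaluated at (x_prime, t_prime) and blended with
        # d_cur using weights (1 - 1/(2*alpha)) and 1/(2*alpha), which is Heun's method for alpha=1.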
161 | if solver == 'euler' or i == num_steps - 1: 162 | x_next = x_hat + h * d_cur 163 | else: 164 | assert solver == 'heun' 165 | denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to( 166 | torch.float64) 167 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv( 168 | t_prime) * s(t_prime) / sigma(t_prime) * denoised 169 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 170 | 171 | return x_next 172 | -------------------------------------------------------------------------------- /diffusion/model/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, past_key_value) 
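# Illustrative usage sketch (added; not part of the original blocks.py). A minimal smoke test
# of MPTBlock with the plain PyTorch attention implementation; the attn_config values shown
# are assumptions that mirror the defaults above.
import torch
from diffusion.model.llava.mpt.blocks import MPTBlock

block = MPTBlock(
    d_model=256, n_heads=8, expansion_ratio=4,
    attn_config={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'torch',
                 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False,
                 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8},
)
x = torch.randn(2, 16, 256)            # (batch, seq_len, d_model)
y, _ = block(x, is_causal=True)        # output keeps the input shape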
-------------------------------------------------------------------------------- /diffusion/model/llava/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 
41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | 
dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /diffusion/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .PixArt import PixArt, PixArt_XL_2 2 | from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock -------------------------------------------------------------------------------- /diffusion/model/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 
17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 
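    Example (added for illustration): a 25-step DDIM-style schedule over a 1000-step base
    process can be built as
        SpacedDiffusion(use_timesteps=space_timesteps(1000, "ddim25"), **base_kwargs)
    where base_kwargs carries the usual GaussianDiffusion arguments (betas, model_mean_type,
    model_var_type, loss_type, ...); the IDDPM factory in diffusion/iddpm.py constructs
    SpacedDiffusion in exactly this way.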
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def training_losses_diffusers( 100 | self, model, *args, **kwargs 101 | ): # pylint: disable=signature-differs 102 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) 103 | 104 | def condition_mean(self, cond_fn, *args, **kwargs): 105 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 106 | 107 | def condition_score(self, cond_fn, *args, **kwargs): 108 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 109 | 110 | def _wrap_model(self, model): 111 | if isinstance(model, _WrappedModel): 112 | return model 113 | return _WrappedModel( 114 | model, self.timestep_map, self.original_num_steps 115 | ) 116 | 117 | def _scale_timesteps(self, t): 118 | # Scaling is done by the wrapped model. 
119 | return t 120 | 121 | 122 | class _WrappedModel: 123 | def __init__(self, model, timestep_map, original_num_steps): 124 | self.model = model 125 | self.timestep_map = timestep_map 126 | # self.rescale_timesteps = rescale_timesteps 127 | self.original_num_steps = original_num_steps 128 | 129 | def __call__(self, x, timestep, **kwargs): 130 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) 131 | new_ts = map_tensor[timestep] 132 | # if self.rescale_timesteps: 133 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 134 | return self.model(x, timestep=new_ts, **kwargs) 135 | -------------------------------------------------------------------------------- /diffusion/model/t5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | import html 5 | import urllib.parse as ul 6 | 7 | import ftfy 8 | import torch 9 | from bs4 import BeautifulSoup 10 | from transformers import T5EncoderModel, AutoTokenizer 11 | from huggingface_hub import hf_hub_download 12 | 13 | class T5Embedder: 14 | 15 | available_models = ['t5-v1_1-xxl'] 16 | bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa 17 | 18 | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, 19 | t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): 20 | self.device = torch.device(device) 21 | self.torch_dtype = torch_dtype or torch.bfloat16 22 | if t5_model_kwargs is None: 23 | t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} 24 | if use_offload_folder is not None: 25 | t5_model_kwargs['offload_folder'] = use_offload_folder 26 | t5_model_kwargs['device_map'] = { 27 | 'shared': self.device, 28 | 'encoder.embed_tokens': self.device, 29 | 'encoder.block.0': self.device, 30 | 'encoder.block.1': self.device, 31 | 'encoder.block.2': self.device, 32 | 'encoder.block.3': self.device, 33 | 'encoder.block.4': self.device, 34 | 'encoder.block.5': self.device, 35 | 'encoder.block.6': self.device, 36 | 'encoder.block.7': self.device, 37 | 'encoder.block.8': self.device, 38 | 'encoder.block.9': self.device, 39 | 'encoder.block.10': self.device, 40 | 'encoder.block.11': self.device, 41 | 'encoder.block.12': 'disk', 42 | 'encoder.block.13': 'disk', 43 | 'encoder.block.14': 'disk', 44 | 'encoder.block.15': 'disk', 45 | 'encoder.block.16': 'disk', 46 | 'encoder.block.17': 'disk', 47 | 'encoder.block.18': 'disk', 48 | 'encoder.block.19': 'disk', 49 | 'encoder.block.20': 'disk', 50 | 'encoder.block.21': 'disk', 51 | 'encoder.block.22': 'disk', 52 | 'encoder.block.23': 'disk', 53 | 'encoder.final_layer_norm': 'disk', 54 | 'encoder.dropout': 'disk', 55 | } 56 | else: 57 | t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} 58 | 59 | self.use_text_preprocessing = use_text_preprocessing 60 | self.hf_token = hf_token 61 | self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') 62 | self.dir_or_name = dir_or_name 63 | tokenizer_path, path = dir_or_name, dir_or_name 64 | if local_cache: 65 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 66 | tokenizer_path, path = cache_dir, cache_dir 67 | elif dir_or_name in self.available_models: 68 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 69 | for filename in [ 70 | 'config.json', 'special_tokens_map.json', 'spiece.model', 
'tokenizer_config.json', 71 | 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin' 72 | ]: 73 | hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, 74 | force_filename=filename, token=self.hf_token) 75 | tokenizer_path, path = cache_dir, cache_dir 76 | else: 77 | cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') 78 | for filename in [ 79 | 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', 80 | ]: 81 | hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, 82 | force_filename=filename, token=self.hf_token) 83 | tokenizer_path = cache_dir 84 | 85 | print(tokenizer_path) 86 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 87 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() 88 | self.model_max_length = model_max_length 89 | 90 | def get_text_embeddings(self, texts): 91 | texts = [self.text_preprocessing(text) for text in texts] 92 | 93 | text_tokens_and_mask = self.tokenizer( 94 | texts, 95 | max_length=self.model_max_length, 96 | padding='max_length', 97 | truncation=True, 98 | return_attention_mask=True, 99 | add_special_tokens=True, 100 | return_tensors='pt' 101 | ) 102 | 103 | text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] 104 | text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] 105 | 106 | with torch.no_grad(): 107 | text_encoder_embs = self.model( 108 | input_ids=text_tokens_and_mask['input_ids'].to(self.device), 109 | attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), 110 | )['last_hidden_state'].detach() 111 | return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) 112 | 113 | def text_preprocessing(self, text): 114 | if self.use_text_preprocessing: 115 | # The exact text cleaning as was in the training stage: 116 | text = self.clean_caption(text) 117 | text = self.clean_caption(text) 118 | return text 119 | else: 120 | return text.lower().strip() 121 | 122 | @staticmethod 123 | def basic_clean(text): 124 | text = ftfy.fix_text(text) 125 | text = html.unescape(html.unescape(text)) 126 | return text.strip() 127 | 128 | def clean_caption(self, caption): 129 | caption = str(caption) 130 | caption = ul.unquote_plus(caption) 131 | caption = caption.strip().lower() 132 | caption = re.sub('', 'person', caption) 133 | # urls: 134 | caption = re.sub( 135 | r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa 136 | '', caption) # regex for urls 137 | caption = re.sub( 138 | r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa 139 | '', caption) # regex for urls 140 | # html: 141 | caption = BeautifulSoup(caption, features='html.parser').text 142 | 143 | # @ 144 | caption = re.sub(r'@[\w\d]+\b', '', caption) 145 | 146 | # 31C0—31EF CJK Strokes 147 | # 31F0—31FF Katakana Phonetic Extensions 148 | # 3200—32FF Enclosed CJK Letters and Months 149 | # 3300—33FF CJK Compatibility 150 | # 3400—4DBF CJK Unified Ideographs Extension A 151 | # 4DC0—4DFF Yijing Hexagram Symbols 152 | # 4E00—9FFF CJK Unified Ideographs 153 | caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) 154 | caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) 155 | caption = re.sub(r'[\u3200-\u32ff]+', '', caption) 156 | caption = re.sub(r'[\u3300-\u33ff]+', '', caption) 157 | caption = 
re.sub(r'[\u3400-\u4dbf]+', '', caption) 158 | caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) 159 | caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) 160 | ####################################################### 161 | 162 | # все виды тире / all types of dash --> "-" 163 | caption = re.sub( 164 | r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa 165 | '-', caption) 166 | 167 | # кавычки к одному стандарту 168 | caption = re.sub(r'[`´«»“”¨]', '"', caption) 169 | caption = re.sub(r'[‘’]', "'", caption) 170 | 171 | # " 172 | caption = re.sub(r'"?', '', caption) 173 | # & 174 | caption = re.sub(r'&', '', caption) 175 | 176 | # ip adresses: 177 | caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) 178 | 179 | # article ids: 180 | caption = re.sub(r'\d:\d\d\s+$', '', caption) 181 | 182 | # \n 183 | caption = re.sub(r'\\n', ' ', caption) 184 | 185 | # "#123" 186 | caption = re.sub(r'#\d{1,3}\b', '', caption) 187 | # "#12345.." 188 | caption = re.sub(r'#\d{5,}\b', '', caption) 189 | # "123456.." 190 | caption = re.sub(r'\b\d{6,}\b', '', caption) 191 | # filenames: 192 | caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) 193 | 194 | # 195 | caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" 196 | caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" 197 | 198 | caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT 199 | caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " 200 | 201 | # this-is-my-cute-cat / this_is_my_cute_cat 202 | regex2 = re.compile(r'(?:\-|\_)') 203 | if len(re.findall(regex2, caption)) > 3: 204 | caption = re.sub(regex2, ' ', caption) 205 | 206 | caption = self.basic_clean(caption) 207 | 208 | caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 209 | caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc 210 | caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 211 | 212 | caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) 213 | caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) 214 | caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) 215 | caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) 216 | caption = re.sub(r'\bpage\s+\d+\b', '', caption) 217 | 218 | caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... 
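        # The pattern below strips size strings such as "1920x1080"; its character class
        # matches the Latin "x", the Cyrillic "х" and the multiplication sign "×".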
219 | 220 | caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) 221 | 222 | caption = re.sub(r'\b\s+\:\s+', r': ', caption) 223 | caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) 224 | caption = re.sub(r'\s+', ' ', caption) 225 | 226 | caption.strip() 227 | 228 | caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) 229 | caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) 230 | caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) 231 | caption = re.sub(r'^\.\S+$', '', caption) 232 | 233 | return caption.strip() 234 | -------------------------------------------------------------------------------- /diffusion/model/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs, device=local_ts.device) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs, device=local_losses.device) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int is removed in recent NumPy
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term. 
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /diffusion/sa_sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver 7 | from .model import gaussian_diffusion as gd 8 | 9 | 10 | class SASolverSampler(object): 11 | def __init__(self, model, 12 | noise_schedule="linear", 13 | diffusion_steps=1000, 14 | device='cpu', 15 | ): 16 | super().__init__() 17 | self.model = model 18 | self.device = device 19 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) 20 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) 21 | alphas = 1.0 - betas 22 | self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0))) 23 | 24 | def register_buffer(self, name, attr): 25 | if type(attr) == torch.Tensor: 26 | if attr.device != torch.device("cuda"): 27 | attr = attr.to(torch.device("cuda")) 28 | setattr(self, name, attr) 29 | 30 | @torch.no_grad() 31 | def sample(self, 32 | S, 33 | batch_size, 34 | shape, 35 | conditioning=None, 36 | callback=None, 37 | normals_sequence=None, 38 | img_callback=None, 39 | quantize_x0=False, 40 | eta=0., 41 | mask=None, 42 | x0=None, 43 | temperature=1., 44 | noise_dropout=0., 45 | score_corrector=None, 46 | corrector_kwargs=None, 47 | verbose=True, 48 | x_T=None, 49 | log_every_t=100, 50 | unconditional_guidance_scale=1., 51 | unconditional_conditioning=None, 52 | model_kwargs={}, 53 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
54 | **kwargs 55 | ): 56 | if conditioning is not None: 57 | if isinstance(conditioning, dict): 58 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 59 | if cbs != batch_size: 60 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 61 | else: 62 | if conditioning.shape[0] != batch_size: 63 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 64 | 65 | # sampling 66 | C, H, W = shape 67 | size = (batch_size, C, H, W) 68 | 69 | device = self.device 70 | if x_T is None: 71 | img = torch.randn(size, device=device) 72 | else: 73 | img = x_T 74 | 75 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 76 | 77 | model_fn = model_wrapper( 78 | self.model, 79 | ns, 80 | model_type="noise", 81 | guidance_type="classifier-free", 82 | condition=conditioning, 83 | unconditional_condition=unconditional_conditioning, 84 | guidance_scale=unconditional_guidance_scale, 85 | model_kwargs=model_kwargs, 86 | ) 87 | 88 | sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction") 89 | 90 | tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0 91 | 92 | x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False) 93 | 94 | return x.to(device), None -------------------------------------------------------------------------------- /diffusion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/diffusion/utils/__init__.py -------------------------------------------------------------------------------- /diffusion/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | 5 | from diffusion.utils.logger import get_root_logger 6 | 7 | 8 | def save_checkpoint(work_dir, 9 | epoch, 10 | model, 11 | model_ema=None, 12 | optimizer=None, 13 | lr_scheduler=None, 14 | keep_last=False, 15 | step=None, 16 | ): 17 | os.makedirs(work_dir, exist_ok=True) 18 | state_dict = dict(state_dict=model.state_dict()) 19 | if model_ema is not None: 20 | state_dict['state_dict_ema'] = model_ema.state_dict() 21 | if optimizer is not None: 22 | state_dict['optimizer'] = optimizer.state_dict() 23 | if lr_scheduler is not None: 24 | state_dict['scheduler'] = lr_scheduler.state_dict() 25 | if epoch is not None: 26 | state_dict['epoch'] = epoch 27 | file_path = os.path.join(work_dir, f"epoch_{epoch}.pth") 28 | if step is not None: 29 | file_path = file_path.split('.pth')[0] + f"_step_{step}.pth" 30 | logger = get_root_logger() 31 | torch.save(state_dict, file_path) 32 | logger.info(f'Saved checkpoint of epoch {epoch} to {file_path.format(epoch)}.') 33 | if keep_last: 34 | for i in range(epoch): 35 | previous_ckgt = file_path.format(i) 36 | if os.path.exists(previous_ckgt): 37 | os.remove(previous_ckgt) 38 | 39 | 40 | def load_checkpoint(checkpoint, 41 | model, 42 | model_ema=None, 43 | optimizer=None, 44 | lr_scheduler=None, 45 | load_ema=False, 46 | resume_optimizer=True, 47 | resume_lr_scheduler=True, 48 | max_length=120, 49 | ): 50 | assert isinstance(checkpoint, str) 51 | ckpt_file = checkpoint 52 | checkpoint = torch.load(ckpt_file, map_location="cpu") 53 | 54 | state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed'] 55 | for key in state_dict_keys: 56 | if key in 
checkpoint['state_dict']: 57 | del checkpoint['state_dict'][key] 58 | if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']: 59 | del checkpoint['state_dict_ema'][key] 60 | break 61 | 62 | if load_ema: 63 | state_dict = checkpoint['state_dict_ema'] 64 | else: 65 | state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint 66 | 67 | null_embed = torch.load(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', map_location='cpu') 68 | state_dict['y_embedder.y_embedding'] = null_embed['uncond_prompt_embeds'][0] 69 | 70 | missing, unexpect = model.load_state_dict(state_dict, strict=False) 71 | if model_ema is not None: 72 | model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False) 73 | if optimizer is not None and resume_optimizer: 74 | optimizer.load_state_dict(checkpoint['optimizer']) 75 | if lr_scheduler is not None and resume_lr_scheduler: 76 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 77 | logger = get_root_logger() 78 | if optimizer is not None: 79 | epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*.pth', ckpt_file).group()[0]) 80 | logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, ' 81 | f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.') 82 | return epoch, missing, unexpect 83 | logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.') 84 | return missing, unexpect 85 | -------------------------------------------------------------------------------- /diffusion/utils/data_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | from typing import Sequence 4 | from torch.utils.data import BatchSampler, Sampler, Dataset 5 | from random import shuffle, choice 6 | from copy import deepcopy 7 | from diffusion.utils.logger import get_root_logger 8 | 9 | 10 | class AspectRatioBatchSampler(BatchSampler): 11 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch. 12 | 13 | Args: 14 | sampler (Sampler): Base sampler. 15 | dataset (Dataset): Dataset providing data information. 16 | batch_size (int): Size of mini-batch. 17 | drop_last (bool): If ``True``, the sampler will drop the last batch if 18 | its size would be less than ``batch_size``. 19 | aspect_ratios (dict): The predefined aspect ratios. 
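config (optional): Global training config; used only to locate ``work_dir`` for the log file.
valid_num (int): Minimum number of samples a ratio must have (according to the ``ratio_nums`` kwarg) to be kept as a usable bucket.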
20 | """ 21 | 22 | def __init__(self, 23 | sampler: Sampler, 24 | dataset: Dataset, 25 | batch_size: int, 26 | aspect_ratios: dict, 27 | drop_last: bool = False, 28 | config=None, 29 | valid_num=0, # take as valid aspect-ratio when sample number >= valid_num 30 | **kwargs) -> None: 31 | if not isinstance(sampler, Sampler): 32 | raise TypeError('sampler should be an instance of ``Sampler``, ' 33 | f'but got {sampler}') 34 | if not isinstance(batch_size, int) or batch_size <= 0: 35 | raise ValueError('batch_size should be a positive integer value, ' 36 | f'but got batch_size={batch_size}') 37 | self.sampler = sampler 38 | self.dataset = dataset 39 | self.batch_size = batch_size 40 | self.aspect_ratios = aspect_ratios 41 | self.drop_last = drop_last 42 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 43 | self.config = config 44 | assert self.ratio_nums_gt 45 | # buckets for each aspect ratio 46 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios.keys()} 47 | self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num] 48 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 49 | logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") 50 | 51 | def __iter__(self) -> Sequence[int]: 52 | for idx in self.sampler: 53 | data_info = self.dataset.get_data_info(idx) 54 | height, width = data_info['height'], data_info['width'] 55 | ratio = height / width 56 | # find the closest aspect ratio 57 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) 58 | if closest_ratio not in self.current_available_bucket_keys: 59 | continue 60 | bucket = self._aspect_ratio_buckets[closest_ratio] 61 | bucket.append(idx) 62 | # yield a batch of indices in the same aspect ratio group 63 | if len(bucket) == self.batch_size: 64 | yield bucket[:] 65 | del bucket[:] 66 | 67 | # yield the rest data and reset the buckets 68 | for bucket in self._aspect_ratio_buckets.values(): 69 | while len(bucket) > 0: 70 | if len(bucket) <= self.batch_size: 71 | if not self.drop_last: 72 | yield bucket[:] 73 | bucket = [] 74 | else: 75 | yield bucket[:self.batch_size] 76 | bucket = bucket[self.batch_size:] 77 | 78 | 79 | class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler): 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | # Assign samples to each bucket 83 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 84 | assert self.ratio_nums_gt 85 | self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()} 86 | self.original_buckets = {} 87 | self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000] 88 | self.all_available_keys = deepcopy(self.current_available_bucket_keys) 89 | self.exhausted_bucket_keys = [] 90 | self.total_batches = len(self.sampler) // self.batch_size 91 | self._aspect_ratio_count = {} 92 | for k in self.all_available_keys: 93 | self._aspect_ratio_count[float(k)] = 0 94 | self.original_buckets[float(k)] = [] 95 | logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log')) 96 | logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") 97 | 98 | def __iter__(self) -> Sequence[int]: 99 | i = 0 100 | for idx in self.sampler: 101 | data_info = self.dataset.get_data_info(idx) 
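# Fill per-ratio buckets from the sampled indices below; ratio keys rotate between
# 'available' and 'exhausted' so successive batches cycle through different aspect ratios.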
102 | height, width = data_info['height'], data_info['width'] 103 | ratio = height / width 104 | closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))) 105 | if closest_ratio not in self.all_available_keys: 106 | continue 107 | if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]: 108 | self._aspect_ratio_count[closest_ratio] += 1 109 | self._aspect_ratio_buckets[closest_ratio].append(idx) 110 | self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket 111 | if not self.current_available_bucket_keys: 112 | self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, [] 113 | 114 | if closest_ratio not in self.current_available_bucket_keys: 115 | continue 116 | key = closest_ratio 117 | bucket = self._aspect_ratio_buckets[key] 118 | if len(bucket) == self.batch_size: 119 | yield bucket[:self.batch_size] 120 | del bucket[:self.batch_size] 121 | i += 1 122 | self.exhausted_bucket_keys.append(key) 123 | self.current_available_bucket_keys.remove(key) 124 | 125 | for _ in range(self.total_batches - i): 126 | key = choice(self.all_available_keys) 127 | bucket = self._aspect_ratio_buckets[key] 128 | if len(bucket) >= self.batch_size: 129 | yield bucket[:self.batch_size] 130 | del bucket[:self.batch_size] 131 | 132 | # If a bucket is exhausted 133 | if not bucket: 134 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:]) 135 | shuffle(self._aspect_ratio_buckets[key]) 136 | else: 137 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:]) 138 | shuffle(self._aspect_ratio_buckets[key]) 139 | -------------------------------------------------------------------------------- /diffusion/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch.distributed as dist 4 | from datetime import datetime 5 | from .dist_utils import is_local_master 6 | from mmcv.utils.logging import logger_initialized 7 | 8 | 9 | def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'): 10 | """Get root logger. 11 | 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str): logger name 17 | Returns: 18 | :obj:`logging.Logger`: The obtained logger 19 | """ 20 | if log_file is None: 21 | log_file = '/dev/null' 22 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 23 | return logger 24 | 25 | 26 | def get_logger(name, log_file=None, log_level=logging.INFO): 27 | """Initialize and get a logger by name. 28 | 29 | If the logger has not been initialized, this method will initialize the 30 | logger by adding one or two handlers, otherwise the initialized logger will 31 | be directly returned. During initialization, a StreamHandler will always be 32 | added. If `log_file` is specified and the process rank is 0, a FileHandler 33 | will also be added. 34 | 35 | Args: 36 | name (str): Logger name. 37 | log_file (str | None): The log filename. If specified, a FileHandler 38 | will be added to the logger. 39 | log_level (int): The logger level. Note that only the process of 40 | rank 0 is affected, and other processes will set the level to 41 | "Error" thus be silent most of the time. 42 | 43 | Returns: 44 | logging.Logger: The expected logger. 
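Note:
The configured name is recorded in mmcv's ``logger_initialized`` registry, so later calls with the same name (or a child name such as ``name.submodule``) return the existing logger without attaching duplicate handlers.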
45 | """
46 | logger = logging.getLogger(name)
47 | logger.propagate = False # disable root logger to avoid duplicate logging
48 |
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | stream_handler = logging.StreamHandler()
59 | handlers = [stream_handler]
60 |
61 | if dist.is_available() and dist.is_initialized():
62 | rank = dist.get_rank()
63 | else:
64 | rank = 0
65 |
66 | # only rank 0 will add a FileHandler
67 | if rank == 0 and log_file is not None:
68 | file_handler = logging.FileHandler(log_file, 'w')
69 | handlers.append(file_handler)
70 |
71 | formatter = logging.Formatter(
72 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
73 | for handler in handlers:
74 | handler.setFormatter(formatter)
75 | handler.setLevel(log_level)
76 | logger.addHandler(handler)
77 |
78 | # only rank0 for each node will print logs
79 | log_level = log_level if is_local_master() else logging.ERROR
80 | logger.setLevel(log_level)
81 |
82 | logger_initialized[name] = True
83 |
84 | return logger
85 |
86 | def rename_file_with_creation_time(file_path):
87 | # get the file's creation time
88 | creation_time = os.path.getctime(file_path)
89 | creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
90 |
91 | # build the new file name
92 | dir_name, file_name = os.path.split(file_path)
93 | name, ext = os.path.splitext(file_name)
94 | new_file_name = f"{name}_{creation_time_str}{ext}"
95 | new_file_path = os.path.join(dir_name, new_file_name)
96 |
97 | # rename the file
98 | os.rename(file_path, new_file_path)
99 | print(f"File renamed to: {new_file_path}")
100 |
--------------------------------------------------------------------------------
/diffusion/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
2 | from torch.optim import Optimizer
3 | from torch.optim.lr_scheduler import LambdaLR
4 | import math
5 |
6 | from diffusion.utils.logger import get_root_logger
7 |
8 |
9 | def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
10 | if not config.get('lr_schedule_args', None):
11 | config.lr_schedule_args = dict()
12 | if config.get('lr_warmup_steps', None):
13 | config['num_warmup_steps'] = config.get('lr_warmup_steps') # for compatibility with old version
14 |
15 | logger = get_root_logger()
16 | logger.info(
17 | f'Lr schedule: {config.lr_schedule}, ' + ",".join(
18 | [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
19 | if config.lr_schedule == 'cosine':
20 | lr_scheduler = get_cosine_schedule_with_warmup(
21 | optimizer=optimizer,
22 | **config.lr_schedule_args,
23 | num_training_steps=(len(train_dataloader) * config.num_epochs),
24 | )
25 | elif config.lr_schedule == 'constant':
26 | lr_scheduler = get_constant_schedule_with_warmup(
27 | optimizer=optimizer,
28 | **config.lr_schedule_args,
29 | )
30 | elif config.lr_schedule == 'cosine_decay_to_constant':
31 | assert lr_scale_ratio >= 1
32 | lr_scheduler = get_cosine_decay_to_constant_with_warmup(
33 | optimizer=optimizer,
34 | **config.lr_schedule_args,
35 | final_lr=1 / lr_scale_ratio,
36 | num_training_steps=(len(train_dataloader) * config.num_epochs),
37 | )
38 | else:
39 | raise 
RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.') 40 | return lr_scheduler 41 | 42 | 43 | def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer, 44 | num_warmup_steps: int, 45 | num_training_steps: int, 46 | final_lr: float = 0.0, 47 | num_decay: float = 0.667, 48 | num_cycles: float = 0.5, 49 | last_epoch: int = -1 50 | ): 51 | """ 52 | Create a schedule with a cosine annealing lr followed by a constant lr. 53 | 54 | Args: 55 | optimizer ([`~torch.optim.Optimizer`]): 56 | The optimizer for which to schedule the learning rate. 57 | num_warmup_steps (`int`): 58 | The number of steps for the warmup phase. 59 | num_training_steps (`int`): 60 | The number of total training steps. 61 | final_lr (`int`): 62 | The final constant lr after cosine decay. 63 | num_decay (`int`): 64 | The 65 | last_epoch (`int`, *optional*, defaults to -1): 66 | The index of the last epoch when resuming training. 67 | 68 | Return: 69 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 70 | """ 71 | 72 | def lr_lambda(current_step): 73 | if current_step < num_warmup_steps: 74 | return float(current_step) / float(max(1, num_warmup_steps)) 75 | 76 | num_decay_steps = int(num_training_steps * num_decay) 77 | if current_step > num_decay_steps: 78 | return final_lr 79 | 80 | progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps)) 81 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * ( 82 | 1 - final_lr) + final_lr 83 | 84 | return LambdaLR(optimizer, lr_lambda, last_epoch) 85 | -------------------------------------------------------------------------------- /diffusion/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from mmcv import Config 4 | from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \ 5 | OPTIMIZERS 6 | from mmcv.utils import _BatchNorm, _InstanceNorm 7 | from torch.nn import GroupNorm, LayerNorm 8 | 9 | from .logger import get_root_logger 10 | 11 | from typing import Tuple, Optional, Callable 12 | 13 | import torch 14 | from torch.optim.optimizer import Optimizer 15 | from came_pytorch import CAME 16 | 17 | 18 | def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256): 19 | assert rule in ['linear', 'sqrt'] 20 | logger = get_root_logger() 21 | # scale by world size 22 | if rule == 'sqrt': 23 | scale_ratio = math.sqrt(effective_bs / base_batch_size) 24 | elif rule == 'linear': 25 | scale_ratio = effective_bs / base_batch_size 26 | optimizer_cfg['lr'] *= scale_ratio 27 | logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).') 28 | return scale_ratio 29 | 30 | 31 | @OPTIMIZER_BUILDERS.register_module() 32 | class MyOptimizerConstructor(DefaultOptimizerConstructor): 33 | 34 | def add_params(self, params, module, prefix='', is_dcn_module=None): 35 | """Add all parameters of module to the params list. 36 | 37 | The parameters of the given module will be added to the list of param 38 | groups, with specific rules defined by paramwise_cfg. 39 | 40 | Args: 41 | params (list[dict]): A list of param groups, it will be modified 42 | in place. 43 | module (nn.Module): The module to be added. 
44 | prefix (str): The prefix of the module 45 | 46 | """ 47 | # get param-wise options 48 | custom_keys = self.paramwise_cfg.get('custom_keys', {}) 49 | # first sort with alphabet order and then sort with reversed len of str 50 | # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) 51 | 52 | bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) 53 | bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) 54 | norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) 55 | bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) 56 | 57 | # special rules for norm layers and depth-wise conv layers 58 | is_norm = isinstance(module, 59 | (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) 60 | 61 | for name, param in module.named_parameters(recurse=False): 62 | base_lr = self.base_lr 63 | if name == 'bias' and not (is_norm or is_dcn_module): 64 | base_lr *= bias_lr_mult 65 | 66 | # apply weight decay policies 67 | base_wd = self.base_wd 68 | if self.base_wd is not None: 69 | # norm decay 70 | if is_norm: 71 | base_wd *= norm_decay_mult 72 | # bias lr and decay 73 | elif name == 'bias' and not is_dcn_module: 74 | # TODO: current bias_decay_mult will have affect on DCN 75 | base_wd *= bias_decay_mult 76 | 77 | param_group = {'params': [param]} 78 | if not param.requires_grad: 79 | param_group['requires_grad'] = False 80 | params.append(param_group) 81 | continue 82 | if bypass_duplicate and self._is_in(param_group, params): 83 | logger = get_root_logger() 84 | logger.warn(f'{prefix} is duplicate. It is skipped since ' 85 | f'bypass_duplicate={bypass_duplicate}') 86 | continue 87 | # if the parameter match one of the custom keys, ignore other rules 88 | is_custom = False 89 | for key in custom_keys: 90 | if isinstance(key, tuple): 91 | scope, key_name = key 92 | else: 93 | scope, key_name = None, key 94 | if scope is not None and scope not in f'{prefix}': 95 | continue 96 | if key_name in f'{prefix}.{name}': 97 | is_custom = True 98 | if 'lr_mult' in custom_keys[key]: 99 | # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}': 100 | # param_group['lr'] = self.base_lr 101 | # else: 102 | param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult'] 103 | elif 'lr' not in param_group: 104 | param_group['lr'] = base_lr 105 | if self.base_wd is not None: 106 | if 'decay_mult' in custom_keys[key]: 107 | param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult'] 108 | elif 'weight_decay' not in param_group: 109 | param_group['weight_decay'] = base_wd 110 | 111 | if not is_custom: 112 | # bias_lr_mult affects all bias parameters 113 | # except for norm.bias dcn.conv_offset.bias 114 | if base_lr != self.base_lr: 115 | param_group['lr'] = base_lr 116 | if base_wd != self.base_wd: 117 | param_group['weight_decay'] = base_wd 118 | params.append(param_group) 119 | 120 | for child_name, child_mod in module.named_children(): 121 | child_prefix = f'{prefix}.{child_name}' if prefix else child_name 122 | self.add_params( 123 | params, 124 | child_mod, 125 | prefix=child_prefix, 126 | is_dcn_module=is_dcn_module) 127 | 128 | 129 | def build_optimizer(model, optimizer_cfg): 130 | # default parameter-wise config 131 | logger = get_root_logger() 132 | 133 | if hasattr(model, 'module'): 134 | model = model.module 135 | # set optimizer constructor 136 | optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor') 137 | # parameter-wise setting: cancel weight decay for some specific modules 138 | custom_keys = dict() 139 | 
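# Modules that expose a `zero_weight_decay` collection of parameter names get decay_mult=0
# for those entries, so weight decay is skipped on them; this is merged with any
# user-provided paramwise_cfg before the optimizer is built.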
for name, module in model.named_modules(): 140 | if hasattr(module, 'zero_weight_decay'): 141 | custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay}) 142 | 143 | paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys))) 144 | given_cfg = optimizer_cfg.get('paramwise_cfg') 145 | if given_cfg: 146 | paramwise_cfg.merge_from_dict(dict(cfg=given_cfg)) 147 | optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg 148 | # build optimizer 149 | optimizer = mm_build_optimizer(model, optimizer_cfg) 150 | 151 | weight_decay_groups = dict() 152 | lr_groups = dict() 153 | for group in optimizer.param_groups: 154 | if not group.get('requires_grad', True): continue 155 | lr_groups.setdefault(group['lr'], []).append(group) 156 | weight_decay_groups.setdefault(group['weight_decay'], []).append(group) 157 | 158 | learnable_count, fix_count = 0, 0 159 | for p in model.parameters(): 160 | if p.requires_grad: 161 | learnable_count += 1 162 | else: 163 | fix_count += 1 164 | fix_info = f"{learnable_count} are learnable, {fix_count} are fix" 165 | lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()]) 166 | wd_info = "Weight decay group: " + ", ".join( 167 | [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()]) 168 | opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}." 169 | logger.info(opt_info) 170 | 171 | return optimizer 172 | 173 | 174 | @OPTIMIZERS.register_module() 175 | class Lion(Optimizer): 176 | def __init__( 177 | self, 178 | params, 179 | lr: float = 1e-4, 180 | betas: Tuple[float, float] = (0.9, 0.99), 181 | weight_decay: float = 0.0, 182 | ): 183 | assert lr > 0. 184 | assert all([0. <= beta <= 1. 
for beta in betas]) 185 | 186 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 187 | 188 | super().__init__(params, defaults) 189 | 190 | @staticmethod 191 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): 192 | # stepweight decay 193 | p.data.mul_(1 - lr * wd) 194 | 195 | # weight update 196 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_() 197 | p.add_(update, alpha=-lr) 198 | 199 | # decay the momentum running average coefficient 200 | exp_avg.lerp_(grad, 1 - beta2) 201 | 202 | @staticmethod 203 | def exists(val): 204 | return val is not None 205 | 206 | @torch.no_grad() 207 | def step( 208 | self, 209 | closure: Optional[Callable] = None 210 | ): 211 | 212 | loss = None 213 | if self.exists(closure): 214 | with torch.enable_grad(): 215 | loss = closure() 216 | 217 | for group in self.param_groups: 218 | for p in filter(lambda p: self.exists(p.grad), group['params']): 219 | 220 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ 221 | self.state[p] 222 | 223 | # init state - exponential moving average of gradient values 224 | if len(state) == 0: 225 | state['exp_avg'] = torch.zeros_like(p) 226 | 227 | exp_avg = state['exp_avg'] 228 | 229 | self.update_fn( 230 | p, 231 | grad, 232 | exp_avg, 233 | lr, 234 | wd, 235 | beta1, 236 | beta2 237 | ) 238 | 239 | return loss 240 | 241 | 242 | @OPTIMIZERS.register_module() 243 | class CAMEWrapper(CAME): 244 | def __init__(self, *args, **kwargs): 245 | 246 | super().__init__(*args, **kwargs) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: PixArt 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pytorch >= 1.13 8 | - torchvision 9 | - pytorch-cuda=11.7 10 | - pip: 11 | - timm==0.6.12 12 | - diffusers 13 | - accelerate 14 | - mmcv==1.7.0 15 | - diffusers 16 | - accelerate==0.15.0 17 | - tensorboard 18 | - transformers==4.26.1 19 | - sentencepiece~=0.1.97 20 | - ftfy~=6.1.1 21 | - beautifulsoup4~=4.11.1 22 | - opencv-python 23 | - bs4 24 | - einops 25 | - xformers -------------------------------------------------------------------------------- /notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py: -------------------------------------------------------------------------------- 1 | _base_ = ['/workspace/PixArt-alpha/configs/PixArt_xl2_internal.py'] 2 | data_root = '/workspace' 3 | 4 | image_list_json = ['data_info.json',] 5 | 6 | data = dict(type='InternalData', root='/workspace/pixart-pokemon', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) 7 | image_size = 512 8 | 9 | # model setting 10 | model = 'PixArt_XL_2' 11 | fp32_attention = True 12 | load_from = "/workspace/PixArt-alpha/output/pretrained_models/PixArt-XL-2-512x512.pth" 13 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" 14 | pe_interpolation = 1.0 15 | 16 | # training setting 17 | use_fsdp=False # if use FSDP mode 18 | num_workers=10 19 | train_batch_size = 38 # 32 20 | num_epochs = 200 # 3 21 | gradient_accumulation_steps = 1 22 | grad_checkpointing = True 23 | gradient_clip = 0.01 24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) 25 | lr_schedule_args = dict(num_warmup_steps=1000) 26 | 27 | eval_sampling_steps = 200 28 | log_interval = 20 29 | save_model_steps=100 30 | work_dir = 'output/debug' 31 | 
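As an aside, the config above is a plain Python file read through mmcv's ``Config`` machinery, and its ``optimizer`` / ``lr_schedule_args`` / ``num_epochs`` fields feed the ``build_optimizer`` and ``build_lr_scheduler`` helpers from ``diffusion/utils`` shown earlier. The snippet below is only a minimal sketch of that wiring, not the repo's ``train_scripts/train.py``: it assumes it is run from the repository root, that the absolute ``_base_`` path inside the config exists on disk, and it substitutes a toy model and a dummy dataloader where the real script builds ``PixArt_XL_2`` and a proper ``DataLoader``.

import torch
from mmcv import Config

from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.optimizer import build_optimizer

cfg = Config.fromfile("notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py")

model = torch.nn.Linear(4, 4)       # placeholder; train.py builds PixArt_XL_2 here
train_dataloader = [None] * 1000    # placeholder; only len() is used for num_training_steps

optimizer = build_optimizer(model, cfg.optimizer)   # AdamW(lr=2e-5, weight_decay=3e-2) from the config
lr_scheduler = build_lr_scheduler(cfg, optimizer, train_dataloader, lr_scale_ratio=1)
print(type(optimizer).__name__, type(lr_scheduler).__name__)

The printed class names only confirm that the config fields are picked up; which warmup/decay branch runs depends on the ``lr_schedule`` value inherited from the ``_base_`` config.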
-------------------------------------------------------------------------------- /notebooks/convert-checkpoint-to-diffusers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "2878bb5d-33a3-4a5b-b15c-c832c700129b", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/workspace/PixArt-alpha\n" 14 | ] 15 | }, 16 | { 17 | "name": "stderr", 18 | "output_type": "stream", 19 | "text": [ 20 | "/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n", 21 | " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%cd PixArt-alpha" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 14, 32 | "id": "7dd2d98c-3f8f-40f1-a9e1-bc916774afb3", 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Total number of transformer parameters: 610856096\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "!python tools/convert_pixart_to_diffusers.py \\\n", 45 | " --orig_ckpt_path \"/workspace/PixArt-alpha/output/trained_model/checkpoints/epoch_5_step_110.pth\" \\\n", 46 | " --dump_path \"/workspace/PixArt-alpha/output/epoch_5_step_110_diffusers\" \\\n", 47 | " --only_transformer=True \\\n", 48 | " --image_size 512 \\\n", 49 | " --version sigma\n" 50 | ] 51 | } 52 | ], 53 | "metadata": { 54 | "kernelspec": { 55 | "display_name": "Python 3 (ipykernel)", 56 | "language": "python", 57 | "name": "python3" 58 | }, 59 | "language_info": { 60 | "codemirror_mode": { 61 | "name": "ipython", 62 | "version": 3 63 | }, 64 | "file_extension": ".py", 65 | "mimetype": "text/x-python", 66 | "name": "python", 67 | "nbconvert_exporter": "python", 68 | "pygments_lexer": "ipython3", 69 | "version": "3.10.12" 70 | } 71 | }, 72 | "nbformat": 4, 73 | "nbformat_minor": 5 74 | } 75 | -------------------------------------------------------------------------------- /notebooks/train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c423d2a1-475e-482e-b759-f16456fd6707", 6 | "metadata": {}, 7 | "source": [ 8 | "# Install" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "0440d6a7-78b9-49e9-98a2-9a5ed75e1a2f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!git clone https://github.com/kopyl/PixArt-alpha.git" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "0abadf51-a7e3-4091-bb02-0bdd8d28fb73", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "%cd PixArt-alpha" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "4df1af24-f439-485d-a946-966dbf16c49b", 35 | "metadata": { 36 | "scrolled": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "!pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117\n", 41 | "!pip install -r requirements.txt\n", 42 | "!pip install wandb" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "d44474fd-0b92-48fc-b4cf-142b59d3917c", 48 | "metadata": {}, 49 | "source": [ 50 | "## Download model" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": 
"06b1c1c9-f8b1-4719-8564-2383eac9ff28", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "!python tools/download.py --model_names \"PixArt-XL-2-512x512.pth\"" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "f298a89c-d2a5-4da7-8304-c1390da0ba58", 66 | "metadata": {}, 67 | "source": [ 68 | "## Make dataset out of Hugginggface dataset" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "e17b8883-0a5c-4fa3-a7d0-e8ee95e42027", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import os\n", 79 | "from tqdm.notebook import tqdm\n", 80 | "from datasets import load_dataset\n", 81 | "import json" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "92957b2c-6765-48ee-9296-d6739066d74d", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "dataset = load_dataset(\"lambdalabs/pokemon-blip-captions\")" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "0095cdda-c31a-48ee-a115-076a5fc393c3", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "root_dir = \"/workspace/pixart-pokemon\"\n", 102 | "images_dir = \"images\"\n", 103 | "captions_dir = \"captions\"\n", 104 | "\n", 105 | "images_dir_absolute = os.path.join(root_dir, images_dir)\n", 106 | "captions_dir_absolute = os.path.join(root_dir, captions_dir)\n", 107 | "\n", 108 | "if not os.path.exists(root_dir):\n", 109 | " os.makedirs(os.path.join(root_dir, images_dir))\n", 110 | "\n", 111 | "if not os.path.exists(os.path.join(root_dir, images_dir)):\n", 112 | " os.makedirs(os.path.join(root_dir, images_dir))\n", 113 | "if not os.path.exists(os.path.join(root_dir, captions_dir)):\n", 114 | " os.makedirs(os.path.join(root_dir, captions_dir))\n", 115 | "\n", 116 | "image_format = \"png\"\n", 117 | "json_name = \"partition/data_info.json\"\n", 118 | "if not os.path.exists(os.path.join(root_dir, \"partition\")):\n", 119 | " os.makedirs(os.path.join(root_dir, \"partition\"))\n", 120 | "\n", 121 | "absolute_json_name = os.path.join(root_dir, json_name)\n", 122 | "data_info = []\n", 123 | "\n", 124 | "order = 0\n", 125 | "for item in tqdm(dataset[\"train\"]): \n", 126 | " image = item[\"image\"]\n", 127 | " image.save(f\"{images_dir_absolute}/{order}.{image_format}\")\n", 128 | " with open(f\"{captions_dir_absolute}/{order}.txt\", \"w\") as text_file:\n", 129 | " text_file.write(item[\"text\"])\n", 130 | " \n", 131 | " width, height = 512, 512\n", 132 | " ratio = 1\n", 133 | " data_info.append({\n", 134 | " \"height\": height,\n", 135 | " \"width\": width,\n", 136 | " \"ratio\": ratio,\n", 137 | " \"path\": f\"images/{order}.{image_format}\",\n", 138 | " \"prompt\": item[\"text\"],\n", 139 | " })\n", 140 | " \n", 141 | " order += 1\n", 142 | "\n", 143 | "with open(absolute_json_name, \"w\") as json_file:\n", 144 | " json.dump(data_info, json_file)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "25be1c03", 150 | "metadata": {}, 151 | "source": [ 152 | "## Extract features" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "9f07a4f5-1873-48bf-86d0-9304942de5d3", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "!python /workspace/PixArt-alpha/tools/extract_features.py \\\n", 163 | " --img_size 512 \\\n", 164 | " --json_path \"/workspace/pixart-pokemon/partition/data_info.json\" \\\n", 165 | " --t5_save_root \"/workspace/pixart-pokemon/caption_feature_wmask\" \\\n", 166 | " --vae_save_root 
\"/workspace/pixart-pokemon/img_vae_features\" \\\n", 167 | " --pretrained_models_dir \"/workspace/PixArt-alpha/output/pretrained_models\" \\\n", 168 | " --dataset_root \"/workspace/pixart-pokemon\"" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "id": "9fc653d0", 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "!wandb login REPLACE_THIS_WITH_YOUR_AUTH_TOKEN_OF_WANDB" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "2cf1fd1a", 184 | "metadata": {}, 185 | "source": [ 186 | "## Train model" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "ea0e9dab-17bc-45ed-9c81-b670bbb8de47", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "!python -m torch.distributed.launch \\\n", 197 | " train_scripts/train.py \\\n", 198 | " /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \\\n", 199 | " --work-dir output/trained_model \\\n", 200 | " --report_to=\"wandb\" \\\n", 201 | " --loss_report_name=\"train_loss\"" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3 (ipykernel)", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.8.13" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 5 226 | } 227 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv==1.7.0 2 | git+https://github.com/huggingface/diffusers 3 | timm==0.6.12 4 | accelerate==0.25.0 5 | tensorboard 6 | tensorboardX 7 | transformers==4.36.1 8 | sentencepiece~=0.1.99 9 | ftfy 10 | beautifulsoup4 11 | protobuf==3.20.2 12 | gradio==4.1.1 13 | yapf==0.40.1 14 | opencv-python 15 | bs4 16 | einops 17 | xformers==0.0.19 18 | optimum 19 | peft 20 | came-pytorch -------------------------------------------------------------------------------- /scripts/DMD/transformer_train/attention_processor.py: -------------------------------------------------------------------------------- 1 | from diffusers.models.attention_processor import AttnProcessor2_0, Attention 2 | from typing import Optional 3 | import torch 4 | from diffusers.utils import USE_PEFT_BACKEND 5 | import torch.nn.functional as F 6 | 7 | class AttentionPorcessorFP32(AttnProcessor2_0): 8 | 9 | def __call__( 10 | self, 11 | attn: Attention, 12 | hidden_states: torch.FloatTensor, 13 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 14 | attention_mask: Optional[torch.FloatTensor] = None, 15 | temb: Optional[torch.FloatTensor] = None, 16 | scale: float = 1.0, 17 | use_fp32_attention = True, 18 | ) -> torch.FloatTensor: 19 | residual = hidden_states 20 | if attn.spatial_norm is not None: 21 | hidden_states = attn.spatial_norm(hidden_states, temb) 22 | 23 | input_ndim = hidden_states.ndim 24 | 25 | if input_ndim == 4: 26 | batch_size, channel, height, width = hidden_states.shape 27 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 28 | 29 | batch_size, sequence_length, _ = ( 30 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 31 | ) 32 | 33 | if 
attention_mask is not None: 34 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 35 | # scaled_dot_product_attention expects attention_mask shape to be 36 | # (batch, heads, source_length, target_length) 37 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 38 | 39 | if attn.group_norm is not None: 40 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 41 | 42 | args = () if USE_PEFT_BACKEND else (scale,) 43 | query = attn.to_q(hidden_states, *args) 44 | 45 | if encoder_hidden_states is None: 46 | encoder_hidden_states = hidden_states 47 | elif attn.norm_cross: 48 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 49 | 50 | key = attn.to_k(encoder_hidden_states, *args) 51 | value = attn.to_v(encoder_hidden_states, *args) 52 | 53 | inner_dim = key.shape[-1] 54 | head_dim = inner_dim // attn.heads 55 | 56 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 57 | 58 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 59 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 60 | 61 | if use_fp32_attention: 62 | query = query.float() 63 | key = key.float() 64 | value = value.float() 65 | if attention_mask is not None: 66 | attention_mask = attention_mask.to(query.dtype) 67 | 68 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 69 | # TODO: add support for attn.scale when we move to Torch 2.1 70 | with torch.cuda.amp.autocast(enabled=not use_fp32_attention): 71 | hidden_states = F.scaled_dot_product_attention( 72 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 73 | ) 74 | 75 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 76 | hidden_states = hidden_states.to(torch.float16) 77 | 78 | # linear proj 79 | hidden_states = attn.to_out[0](hidden_states, *args) 80 | # dropout 81 | hidden_states = attn.to_out[1](hidden_states) 82 | 83 | if input_ndim == 4: 84 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 85 | 86 | if attn.residual_connection: 87 | hidden_states = hidden_states + residual 88 | 89 | hidden_states = hidden_states / attn.rescale_output_factor 90 | 91 | return hidden_states -------------------------------------------------------------------------------- /scripts/DMD/transformer_train/generate.py: -------------------------------------------------------------------------------- 1 | # contains functions of generating samples 2 | import torch 3 | from accelerate.utils.other import extract_model_from_parallel 4 | 5 | 6 | def model_forward(generator, encoder_hidden_states, encoder_attention_mask, added_cond_kwargs, noise, start_ts): 7 | if isinstance(start_ts, int): 8 | # convert int to long 9 | start_ts_net_in = torch.zeros((noise.size()[0],)) + start_ts 10 | start_ts_net_in = start_ts_net_in.long().to(noise.device) 11 | else: 12 | start_ts_net_in = start_ts.to(noise.device) 13 | noise_pred = generator(hidden_states=noise, encoder_hidden_states=encoder_hidden_states, 14 | encoder_attention_mask=encoder_attention_mask, added_cond_kwargs=added_cond_kwargs, 15 | imestep=start_ts_net_in).sample 16 | B, C = noise.shape[:2] 17 | assert noise_pred.shape == (B, C * 2, *noise.shape[2:]) 18 | noise_pred = torch.split(noise_pred, C, dim=1)[0] 19 | return noise_pred 20 | 21 | 22 | def generate_sample_1step(model, scheduler, latents, maxt, prompt_embeds, 
prompt_attention_masks=None): 23 | t = torch.full((1,), maxt, device=latents.device).long() 24 | noise_pred = forward_model( 25 | model, 26 | latents=latents, 27 | timestep=t, 28 | prompt_embeds=prompt_embeds, 29 | prompt_attention_masks=prompt_attention_masks, 30 | ) 31 | latents = eps_to_mu(scheduler, noise_pred, latents, t) 32 | return latents 33 | 34 | def eps_to_mu(scheduler, model_output, sample, timesteps): 35 | alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) 36 | alpha_prod_t = alphas_cumprod[timesteps] 37 | while len(alpha_prod_t.shape) < len(sample.shape): 38 | alpha_prod_t = alpha_prod_t.unsqueeze(-1) 39 | beta_prod_t = 1 - alpha_prod_t 40 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 41 | return pred_original_sample 42 | 43 | 44 | def forward_model(model, latents, timestep, prompt_embeds, prompt_attention_masks=None): 45 | added_cond_kwargs = {"resolution": None, "aspect_ratio": None} 46 | if extract_model_from_parallel(model).config.sample_size == 128: 47 | batch_size, _, height, width = latents.shape 48 | resolution = torch.tensor([height, width]).repeat(batch_size, 1) 49 | aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) 50 | resolution = resolution.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device) 51 | aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device) 52 | added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} 53 | 54 | timestep = timestep.expand(latents.shape[0]) 55 | 56 | noise_pred = model( 57 | latents, 58 | timestep=timestep, 59 | encoder_hidden_states=prompt_embeds, 60 | encoder_attention_mask=prompt_attention_masks, 61 | added_cond_kwargs=added_cond_kwargs, 62 | ).sample 63 | 64 | if extract_model_from_parallel(model).config.out_channels // 2 == latents.shape[1]: 65 | noise_pred = noise_pred.chunk(2, dim=1)[0] 66 | 67 | return noise_pred -------------------------------------------------------------------------------- /scripts/DMD/transformer_train/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import re 5 | import shutil 6 | 7 | from accelerate.checkpointing import save_accelerator_state, save_custom_state 8 | from accelerate.logging import get_logger 9 | from accelerate.utils import ( 10 | MODEL_NAME, 11 | DistributedType, 12 | save_fsdp_model, 13 | save_fsdp_optimizer, 14 | is_deepspeed_available 15 | ) 16 | 17 | from PIL import Image 18 | 19 | if is_deepspeed_available(): 20 | import deepspeed 21 | 22 | from accelerate.utils import ( 23 | DeepSpeedEngineWrapper, 24 | DeepSpeedOptimizerWrapper, 25 | DeepSpeedSchedulerWrapper, 26 | DummyOptim, 27 | DummyScheduler, 28 | ) 29 | 30 | try: 31 | from torch.optim.lr_scheduler import LRScheduler 32 | except ImportError: 33 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler 34 | 35 | logger = get_logger(__name__) 36 | 37 | ############# Saving model utils 38 | 39 | def accelerate_save_state(accelerator, output_dir=None, save_unet_only=False, unet_id=0, **save_model_func_kwargs): 40 | """ 41 | Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder. 42 | 43 | If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled 44 | then checkpoints will be saved to `self.project_dir/checkpoints`. 
If the number of current saves is greater 45 | than `total_limit` then the oldest save is deleted. Each checkpoint is saved in seperate folders named 46 | `checkpoint_`. 47 | 48 | Otherwise they are just saved to `output_dir`. 49 | 50 | 51 | 52 | Should only be used when wanting to save a checkpoint during training and restoring the state in the same 53 | environment. 54 | 55 | 56 | 57 | Args: 58 | output_dir (`str` or `os.PathLike`): 59 | The name of the folder to save all relevant weights and states. 60 | save_model_func_kwargs (`dict`, *optional*): 61 | Additional keyword arguments for saving model which can be passed to the underlying save function, such 62 | as optional arguments for DeepSpeed's `save_checkpoint` function. 63 | 64 | Example: 65 | 66 | ```python 67 | >>> from accelerate import Accelerator 68 | 69 | >>> accelerator = Accelerator() 70 | >>> model, optimizer, lr_scheduler = ... 71 | >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) 72 | >>> accelerator.save_state(output_dir="my_checkpoint") 73 | ``` 74 | """ 75 | if accelerator.project_configuration.automatic_checkpoint_naming: 76 | output_dir = os.path.join(accelerator.project_dir, "checkpoints") 77 | os.makedirs(output_dir, exist_ok=True) 78 | if accelerator.project_configuration.automatic_checkpoint_naming: 79 | folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)] 80 | if accelerator.project_configuration.total_limit is not None and ( 81 | len(folders) + 1 > accelerator.project_configuration.total_limit 82 | ): 83 | 84 | def _inner(folder): 85 | return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0] 86 | 87 | folders.sort(key=_inner) 88 | logger.warning( 89 | f"Deleting {len(folders) + 1 - accelerator.project_configuration.total_limit} checkpoints to make room for new checkpoint." 90 | ) 91 | for folder in folders[: len(folders) + 1 - accelerator.project_configuration.total_limit]: 92 | shutil.rmtree(folder) 93 | output_dir = os.path.join(output_dir, f"checkpoint_{accelerator.save_iteration}") 94 | if os.path.exists(output_dir): 95 | raise ValueError( 96 | f"Checkpoint directory {output_dir} ({accelerator.save_iteration}) already exists. Please manually override `self.save_iteration` with what iteration to start with." 
97 | ) 98 | os.makedirs(output_dir, exist_ok=True) 99 | logger.info(f"Saving current state to {output_dir}") 100 | 101 | 102 | # Save the models taking care of FSDP and DeepSpeed nuances 103 | 104 | weights = [] 105 | for i, model in enumerate(accelerator._models): 106 | if save_unet_only and i != unet_id: 107 | continue 108 | if accelerator.distributed_type == DistributedType.FSDP: 109 | logger.info("Saving FSDP model") 110 | save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, output_dir, i) 111 | logger.info(f"FSDP Model saved to output dir {output_dir}") 112 | elif accelerator.distributed_type == DistributedType.DEEPSPEED: 113 | logger.info("Saving DeepSpeed Model and Optimizer") 114 | ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" 115 | model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs) 116 | logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}") 117 | elif accelerator.distributed_type == DistributedType.MEGATRON_LM: 118 | logger.info("Saving Megatron-LM Model, Optimizer and Scheduler") 119 | model.save_checkpoint(output_dir) 120 | logger.info(f"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}") 121 | else: 122 | weights.append(accelerator.get_state_dict(model, unwrap=False)) 123 | 124 | # Save the optimizers taking care of FSDP and DeepSpeed nuances 125 | optimizers = [] 126 | if not save_unet_only: 127 | if accelerator.distributed_type == DistributedType.FSDP: 128 | for i, opt in enumerate(accelerator._optimizers): 129 | logger.info("Saving FSDP Optimizer") 130 | save_fsdp_optimizer(accelerator.state.fsdp_plugin, accelerator, opt, accelerator._models[i], output_dir, i) 131 | logger.info(f"FSDP Optimizer saved to output dir {output_dir}") 132 | elif accelerator.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: 133 | optimizers = accelerator._optimizers 134 | 135 | # Save the lr schedulers taking care of DeepSpeed nuances 136 | schedulers = [] 137 | if accelerator.distributed_type == DistributedType.DEEPSPEED: 138 | for i, scheduler in enumerate(accelerator._schedulers): 139 | if isinstance(scheduler, DeepSpeedSchedulerWrapper): 140 | continue 141 | schedulers.append(scheduler) 142 | elif accelerator.distributed_type not in [DistributedType.MEGATRON_LM]: 143 | schedulers = accelerator._schedulers 144 | 145 | # Call model loading hooks that might have been registered with 146 | # accelerator.register_model_state_hook 147 | for hook in accelerator._save_model_state_pre_hook.values(): 148 | hook(accelerator._models, weights, output_dir) 149 | 150 | save_location = save_accelerator_state( 151 | output_dir, weights, optimizers, schedulers, accelerator.state.process_index, accelerator.scaler 152 | ) 153 | for i, obj in enumerate(accelerator._custom_objects): 154 | save_custom_state(obj, output_dir, i) 155 | 156 | accelerator.project_configuration.iteration += 1 157 | return save_location 158 | 159 | 160 | #### calculation 161 | def compute_snr(timesteps, noise_scheduler): 162 | """ 163 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 164 | """ 165 | alphas_cumprod = noise_scheduler.alphas_cumprod 166 | sqrt_alphas_cumprod = alphas_cumprod**0.5 167 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 168 | 169 | # Expand the tensors. 
170 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 171 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 172 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 173 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 174 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 175 | 176 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 177 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 178 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 179 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 180 | 181 | # Compute SNR. 182 | snr = (alpha / sigma) ** 2 183 | return snr 184 | 185 | 186 | def save_image(image, path): 187 | image = (image / 2 + 0.5).clamp(0, 1) 188 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 189 | image = (image * 255).round().astype("uint8") 190 | image = [Image.fromarray(im) for im in image] 191 | image[0].save(path) 192 | -------------------------------------------------------------------------------- /scripts/inference_pipeline.py: -------------------------------------------------------------------------------- 1 | from transformers import T5EncoderModel 2 | from diffusers import PixArtAlphaPipeline, Transformer2DModel,DEISMultistepScheduler, DPMSolverMultistepScheduler 3 | import torch 4 | import gc 5 | import argparse 6 | import pathlib 7 | from pathlib import Path 8 | import sys 9 | 10 | current_file_path = Path(__file__).resolve() 11 | sys.path.insert(0, str(current_file_path.parent.parent)) 12 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs 13 | 14 | def main(args): 15 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs) 16 | 17 | gc.collect() 18 | torch.cuda.empty_cache() 19 | 20 | repo_path = args.repo_path 21 | output_image = pathlib.Path(args.output) 22 | positive_prompt = args.positive_prompt 23 | negative_prompt = args.negative_prompt 24 | image_width = args.width 25 | image_height = args.height 26 | num_steps = args.num_steps 27 | guidance_scale = args.guidance_scale 28 | seed = args.seed 29 | low_vram = args.low_vram 30 | num_images = args.num_images 31 | scheduler_type = args.scheduler 32 | karras = args.karras 33 | algorithm_type = args.algorithm 34 | beta_schedule = args.beta_schedule 35 | use_lu_lambdas = args.use_lu_lambdas 36 | 37 | pipe = None 38 | if low_vram: 39 | print('low_vram') 40 | text_encoder = T5EncoderModel.from_pretrained( 41 | repo_path, 42 | subfolder="text_encoder", 43 | load_in_8bit=True, 44 | torch_dtype=torch.float16, 45 | ) 46 | pipe = PixArtAlphaPipeline.from_pretrained( 47 | repo_path, 48 | text_encoder=text_encoder, 49 | transformer=None, 50 | torch_dtype=torch.float16, 51 | ) 52 | 53 | with torch.no_grad(): 54 | prompt = positive_prompt 55 | negative = negative_prompt 56 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt, negative_prompt=negative) 57 | 58 | def flush(): 59 | gc.collect() 60 | torch.cuda.empty_cache() 61 | 62 | pipe.text_encoder = None 63 | del text_encoder 64 | flush() 65 | 66 | pipe.transformer = Transformer2DModel.from_pretrained(repo_path, subfolder='transformer', 67 | load_in_8bit=True, 68 | torch_dtype=torch.float16) 69 | pipe.to('cuda') 70 | else: 71 | print('low_vram=False') 72 | 
pipe = PixArtAlphaPipeline.from_pretrained( 73 | repo_path, 74 | ).to('cuda') 75 | 76 | with torch.no_grad(): 77 | prompt = positive_prompt 78 | negative = negative_prompt 79 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt, negative_prompt=negative) 80 | 81 | generator = torch.Generator() 82 | 83 | if seed != -1: 84 | generator = generator.manual_seed(seed) 85 | else: 86 | generator = None 87 | 88 | prompt_embeds = prompt_embeds.to('cuda') 89 | negative_embeds = negative_embeds.to('cuda') 90 | prompt_attention_mask = prompt_attention_mask.to('cuda') 91 | negative_prompt_attention_mask = negative_prompt_attention_mask.to('cuda') 92 | 93 | if scheduler_type == 'deis': 94 | pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) 95 | else: 96 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 97 | 98 | pipe.scheduler.beta_schedule = beta_schedule 99 | pipe.scheduler.algorithm_type = algorithm_type 100 | pipe.scheduler.use_karras_sigmas = karras 101 | pipe.scheduler.use_lu_lambdas = use_lu_lambdas 102 | latents = pipe( 103 | negative_prompt=None, 104 | num_inference_steps=num_steps, 105 | height=image_height, 106 | width=image_width, 107 | prompt_embeds=prompt_embeds, 108 | guidance_scale=guidance_scale, 109 | negative_prompt_embeds=negative_embeds, 110 | prompt_attention_mask=prompt_attention_mask, 111 | negative_prompt_attention_mask=negative_prompt_attention_mask, 112 | num_images_per_prompt=num_images, 113 | output_type="latent", 114 | generator=generator, 115 | ).images 116 | 117 | # split the output path into stem and extension without breaking on other dots in the path 118 | filename = str(output_image.with_suffix('')) 119 | extension = output_image.suffix.lstrip('.') 120 | 121 | with torch.no_grad(): 122 | images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] 123 | images = pipe.image_processor.postprocess(images, output_type="pil") 124 | 125 | i = 0 126 | for image in images: 127 | image.save(filename + str(i) + '.' + extension) 128 | i = i + 1 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('--repo_path', required=True, type=str, help='Local path or remote path to the pipeline folder') 133 | parser.add_argument('--output', required=False, type=str, default='out.png', help='Path to the generated output image. Supports most image formats, e.g.
.png, .jpg, .jpeg, .webp') 134 | parser.add_argument('--positive_prompt', required=True, type=str, help='Positive prompt to generate') 135 | parser.add_argument('--negative_prompt', required=False, type=str, default='', help='Negative prompt to generate') 136 | parser.add_argument('--seed', required=False, default=-1, type=int, help='Seed for the random generator') 137 | parser.add_argument('--width', required=False, default=512, type=int, help='Image width to generate') 138 | parser.add_argument('--height', required=False, default=512, type=int, help='Image height to generate') 139 | parser.add_argument('--num_steps', required=False, default=20, type=int, help='Number of inference steps') 140 | parser.add_argument('--guidance_scale', required=False, default=7.0, type=float, help='Guidance scale') 141 | parser.add_argument('--low_vram', required=False, action='store_true') 142 | parser.add_argument('--num_images', required=False, default=1, type=int, help='Number of images per prompt') 143 | parser.add_argument('--scheduler', required=False, default='dpm', type=str, choices=['dpm', 'deis']) 144 | parser.add_argument('--karras', required=False, action='store_true') 145 | parser.add_argument('--algorithm', required=False, default='sde-dpmsolver++', type=str, choices=['dpmsolver', 'dpmsolver++', 'sde-dpmsolver', 'sde-dpmsolver++']) 146 | parser.add_argument('--beta_schedule', required=False, default='linear', type=str, choices=['linear', 'scaled_linear', 'squaredcos_cap_v2']) 147 | parser.add_argument('--use_lu_lambdas', required=False, action='store_true') 148 | 149 | args = parser.parse_args() 150 | main(args) -------------------------------------------------------------------------------- /scripts/run_pixart_dmd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from pathlib import Path 5 | 6 | current_file_path = Path(__file__).resolve() 7 | sys.path.insert(0, str(current_file_path.parent.parent)) 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("config", type=str, help="config") 11 | parser.add_argument('--work_dir', help='the dir to save logs and models') 12 | parser.add_argument('--init_method', type=str, default='tcp://127.0.0.1:6666', help='') 13 | parser.add_argument('--rank', type=str, default='0', help='') 14 | parser.add_argument('--world_size', type=str, default='1', help='') 15 | parser.add_argument("--is_debugging", action="store_true") 16 | parser.add_argument("--max_samples", type=int, default=500000, help='') 17 | parser.add_argument("--regression_weight", type=float, default=0.25, help='regression loss weight') 18 | parser.add_argument("--resume_from", type=str, default='', help='path of resumed checkpoint') 19 | parser.add_argument("--mixed_precision", type=str, default='no', help='whether use mixed precision') 20 | parser.add_argument("--batch_size", type=int, default=1, help='batch size per gpu') 21 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help='gradient_accumulation_steps') 22 | parser.add_argument("--use_dm", type=int, default=1, help='use distribution matching loss') 23 | parser.add_argument("--use_regression", type=int, default=1, help='use regression loss') 24 | parser.add_argument("--one_step_maxt", type=int, default=999, help='maximum timestep of one step generator') 25 | parser.add_argument("--learning_rate", type=float, default=1e-5, help='learning rate') 26 | parser.add_argument("--lr_fake_multiplier", type=float, 
default=1.0, help='ratio of the fake (critic) model learning rate to --learning_rate') 27 | parser.add_argument("--max_grad_norm", type=int, default=10, help='max gradient norm for clipping') 28 | parser.add_argument("--save_image_interval", type=int, default=500, help='iteration interval to save image') 29 | parser.add_argument("--checkpointing_steps", type=int, default=500, 30 | help=( 31 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 32 | " training using `--resume_from_checkpoint`." 33 | ), 34 | ) 35 | 36 | 37 | args, args_run = parser.parse_known_args() 38 | 39 | if args.world_size != '1': 40 | os.environ['MASTER_PORT'] = args.init_method.split(':')[-1] 41 | os.environ['MASTER_ADDR'] = args.init_method.split(':')[1][2:] 42 | os.environ['WORLD_SIZE'] = args.world_size 43 | os.environ['RANK'] = args.rank 44 | for k in ['MASTER_PORT', 'MASTER_ADDR', 'WORLD_SIZE', 'RANK']: 45 | print(k, ':', os.environ[k]) 46 | 47 | is_debugging = args.is_debugging 48 | 49 | use_dm = True if args.use_dm == 1 else False 50 | use_regression = True if args.use_regression == 1 else False 51 | laion_subset = 'subset6.25' 52 | 53 | if len(args.resume_from) == 0: 54 | resume_from = None 55 | else: 56 | resume_from = args.resume_from 57 | 58 | # activate environment 59 | print('setting up env') 60 | print('env set') 61 | 62 | if args.mixed_precision == 'no': 63 | acc_common_args = '--use_fsdp --fsdp_offload_params="False" --fsdp_sharding_strategy=2' 64 | else: 65 | acc_common_args = '--mixed_precision="fp16" --use_fsdp --fsdp_offload_params="False" --fsdp_sharding_strategy=2' 66 | 67 | main_args = ( 68 | f'--config="{args.config}" ' 69 | f'--train_batch_size={args.batch_size} ' 70 | f'--one_step_maxt={args.one_step_maxt} ' 71 | f'--output_dir={args.work_dir} ' 72 | f'--learning_rate={args.learning_rate} ' 73 | f'--max_samples={args.max_samples} ' 74 | f'--node_id={int(args.rank)} ' 75 | f'--gradient_accumulation_steps={args.gradient_accumulation_steps} ' 76 | f'--checkpointing_steps={args.checkpointing_steps} ' 77 | f'--lr_fake_multiplier={args.lr_fake_multiplier} ' 78 | f'--max_grad_norm={args.max_grad_norm} ' 79 | f'--save_image_interval={args.save_image_interval} ' 80 | '--max_train_steps=1000000 ' 81 | '--di_steps=1 ' 82 | '--start_ts=999 ' 83 | '--cfg=3 ' 84 | '--dataloader_num_workers=16 ' 85 | '--resolution=512 ' 86 | '--center_crop ' 87 | '--random_flip ' 88 | '--use_ema ' 89 | '--lr_scheduler="constant" ' 90 | '--lr_warmup_steps=0 ' 91 | '--logging_dir="_logs" ' 92 | '--report_to=tensorboard ' 93 | '--adam_epsilon=1e-06 ' 94 | '--seed=0 ' 95 | ) 96 | 97 | if args.mixed_precision == 'fp16': 98 | main_args += '--mixed_precision="fp16" ' 99 | 100 | if use_dm: 101 | main_args += '--use_dm ' 102 | if use_regression: 103 | main_args += f'--use_regression --regression_weight={args.regression_weight} ' 104 | 105 | if resume_from is not None: 106 | main_args += f'--resume_from_checkpoint="{resume_from}" ' 107 | 108 | if is_debugging: 109 | num_gpus_per_node = 2 110 | else: 111 | num_gpus_per_node = 8 112 | 113 | if args.world_size != '1': 114 | num_processes = int(args.world_size) * num_gpus_per_node 115 | print('num_processes', num_processes) 116 | run_cmd = (f'accelerate launch {acc_common_args} ' 117 | f'--num_machines={args.world_size} ' 118 | f'--num_processes={num_processes} ' 119 | f'--machine_rank={os.environ["RANK"]} ' 120 | f'--main_process_ip={os.environ["MASTER_ADDR"]} ' 121 | f'--main_process_port={os.environ["MASTER_PORT"]} ' 122 | f'train_scripts/train_pixart_dmd.py {main_args}' 123
| ) 124 | else: 125 | run_cmd = (f'accelerate launch {acc_common_args} ' 126 | f'--num_machines={args.world_size} ' 127 | f'--num_processes={num_gpus_per_node} ' 128 | 'train_scripts/train_pixart_dmd.py ' 129 | f'{main_args}' 130 | ) 131 | 132 | print('run_cmd', run_cmd) 133 | 134 | print('running') 135 | os.system(run_cmd) 136 | print('done') 137 | -------------------------------------------------------------------------------- /scripts/style.css: -------------------------------------------------------------------------------- 1 | /*.gradio-container{width:680px!important}*/ 2 | /* style.css */ 3 | .gradio_group, .gradio_row, .gradio_column { 4 | display: flex; 5 | flex-direction: row; 6 | justify-content: flex-start; 7 | align-items: flex-start; 8 | flex-wrap: wrap; 9 | } -------------------------------------------------------------------------------- /tools/convert_diffusers_to_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from safetensors import safe_open 3 | from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel 4 | from transformers import T5EncoderModel, T5Tokenizer 5 | import pathlib 6 | import argparse 7 | import gc 8 | import torch 9 | import sys 10 | 11 | from pathlib import Path 12 | current_file_path = Path(__file__).resolve() 13 | sys.path.insert(0, str(current_file_path.parent.parent)) 14 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs 15 | 16 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4} 17 | ckpt_id = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" 18 | 19 | def flush_memory(): 20 | gc.collect() 21 | torch.cuda.empty_cache() 22 | 23 | def main(args): 24 | safetensors_file = pathlib.Path(args.safetensors_path) 25 | image_size = args.image_size 26 | 27 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs) 28 | pathlib.Path(args.output_folder).mkdir(parents=True, exist_ok=True) 29 | transformer = Transformer2DModel( 30 | sample_size=image_size // 8, 31 | num_layers=28, 32 | attention_head_dim=72, 33 | in_channels=4, 34 | out_channels=8, 35 | patch_size=2, 36 | attention_bias=True, 37 | num_attention_heads=16, 38 | cross_attention_dim=1152, 39 | activation_fn="gelu-approximate", 40 | num_embeds_ada_norm=1000, 41 | norm_type="ada_norm_single", 42 | norm_elementwise_affine=False, 43 | norm_eps=1e-6, 44 | caption_channels=4096, 45 | interpolation_scale=interpolation_scale_sigma[image_size], 46 | ).to('cuda') 47 | 48 | state_dict = {} 49 | with safe_open(safetensors_file, framework='pt') as f: 50 | for k in f.keys(): 51 | state_dict[k] = f.get_tensor(k) 52 | transformer.load_state_dict(state_dict, strict=True) 53 | 54 | transformer.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'transformer')) 55 | scheduler = DPMSolverMultistepScheduler() 56 | vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae") 57 | tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="tokenizer") 58 | text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="text_encoder") 59 | 60 | pipe = PixArtAlphaPipeline(transformer=transformer, scheduler=scheduler, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder) 61 | pipe.save_config(pathlib.Path(args.output_folder)) 62 | del pipe 63 | del transformer 64 | del scheduler 65 | del vae 66 | del tokenizer 67 | del text_encoder 68 | flush_memory() 69 | 70 | scheduler = DPMSolverMultistepScheduler() 71 | 
scheduler.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'scheduler')) 72 | del scheduler 73 | flush_memory() 74 | 75 | vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae").to('cuda') 76 | vae.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'vae')) 77 | del vae 78 | flush_memory() 79 | 80 | tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="tokenizer") 81 | tokenizer.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'tokenizer')) 82 | del tokenizer 83 | flush_memory() 84 | 85 | text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="text_encoder") 86 | text_encoder.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'text_encoder')) 87 | del text_encoder 88 | flush_memory() 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser() 92 | 93 | parser.add_argument('--safetensors_path', required=True, type=str, help='Path to the .safetensors file to convert to diffusers folder structure') 94 | parser.add_argument('--image_size', required=False, default=512, type=int, choices=[256, 512, 1024, 2048], help='Image size of pretrained model') 95 | parser.add_argument('--output_folder', required=True, type=str, help='Path to the output folder') 96 | parser.add_argument('--multistep', required=False, type=bool, default=True, help='Multistep option') 97 | args = parser.parse_args() 98 | main(args) -------------------------------------------------------------------------------- /tools/convert_diffusers_to_pixart.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import annotations 3 | 4 | import argparse 5 | import sys 6 | from pathlib import Path 7 | 8 | from safetensors import safe_open 9 | 10 | current_file_path = Path(__file__).resolve() 11 | sys.path.insert(0, str(current_file_path.parent.parent)) 12 | 13 | import torch 14 | 15 | ckpt_id = "PixArt-alpha" 16 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125 17 | interpolation_scale_alpha = {256: 1, 512: 1, 1024: 2} 18 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4} 19 | 20 | def main(args): 21 | # load the pipe, but only the transformer 22 | repo_path = args.safetensor_path 23 | output_path = args.pth_path 24 | 25 | transformer = safe_open(repo_path, framework='pt') 26 | 27 | state_dict = transformer.keys() 28 | 29 | layer_depth = sum([key.endswith("attn1.to_out.0.weight") for key in state_dict]) or 28 30 | 31 | # check for micro condition?? 
32 | converted_state_dict = { 33 | 'x_embedder.proj.weight': transformer.get_tensor('pos_embed.proj.weight'), 34 | 'x_embedder.proj.bias': transformer.get_tensor('pos_embed.proj.bias'), 35 | 'y_embedder.y_proj.fc1.weight': transformer.get_tensor('caption_projection.linear_1.weight'), 36 | 'y_embedder.y_proj.fc1.bias': transformer.get_tensor('caption_projection.linear_1.bias'), 37 | 'y_embedder.y_proj.fc2.weight': transformer.get_tensor('caption_projection.linear_2.weight'), 38 | 'y_embedder.y_proj.fc2.bias': transformer.get_tensor('caption_projection.linear_2.bias'), 39 | 't_embedder.mlp.0.weight': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_1.weight'), 40 | 't_embedder.mlp.0.bias': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_1.bias'), 41 | 't_embedder.mlp.2.weight': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_2.weight'), 42 | 't_embedder.mlp.2.bias': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_2.bias') 43 | } 44 | if 'adaln_single.emb.resolution_embedder.linear_1.weight' in state_dict: 45 | converted_state_dict['csize_embedder.mlp.0.weight'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_1.weight') 46 | converted_state_dict['csize_embedder.mlp.0.bias'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_1.bias') 47 | converted_state_dict['csize_embedder.mlp.2.weight'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_2.weight') 48 | converted_state_dict['csize_embedder.mlp.2.bias'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_2.bias') 49 | converted_state_dict['ar_embedder.mlp.0.weight'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_1.weight') 50 | converted_state_dict['ar_embedder.mlp.0.bias'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_1.bias') 51 | converted_state_dict['ar_embedder.mlp.2.weight'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_2.weight') 52 | converted_state_dict['ar_embedder.mlp.2.bias'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_2.bias') 53 | 54 | # shared norm 55 | converted_state_dict['t_block.1.weight'] = transformer.get_tensor('adaln_single.linear.weight') 56 | converted_state_dict['t_block.1.bias'] = transformer.get_tensor('adaln_single.linear.bias') 57 | 58 | for depth in range(layer_depth): 59 | print(f"Converting layer {depth}") 60 | converted_state_dict[f"blocks.{depth}.scale_shift_table"] = transformer.get_tensor(f"transformer_blocks.{depth}.scale_shift_table") 61 | 62 | # self attention 63 | q = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_q.weight') 64 | q_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_q.bias') 65 | k = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_k.weight') 66 | k_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_k.bias') 67 | v = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_v.weight') 68 | v_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_v.bias') 69 | converted_state_dict[f'blocks.{depth}.attn.qkv.weight'] = torch.cat((q, k, v)) 70 | converted_state_dict[f'blocks.{depth}.attn.qkv.bias'] = torch.cat((q_bias, k_bias, v_bias)) 71 | 72 | # projection 73 | converted_state_dict[f"blocks.{depth}.attn.proj.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.to_out.0.weight") 74 | converted_state_dict[f"blocks.{depth}.attn.proj.bias"] = 
transformer.get_tensor(f"transformer_blocks.{depth}.attn1.to_out.0.bias") 75 | 76 | # check for qk norm 77 | if f'transformer_blocks.{depth}.attn1.q_norm.weight' in state_dict: 78 | converted_state_dict[f"blocks.{depth}.attn.q_norm.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.q_norm.weight") 79 | converted_state_dict[f"blocks.{depth}.attn.q_norm.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.q_norm.bias") 80 | converted_state_dict[f"blocks.{depth}.attn.k_norm.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.k_norm.weight") 81 | converted_state_dict[f"blocks.{depth}.attn.k_norm.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.k_norm.bias") 82 | 83 | # feed-forward 84 | converted_state_dict[f"blocks.{depth}.mlp.fc1.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.0.proj.weight") 85 | converted_state_dict[f"blocks.{depth}.mlp.fc1.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.0.proj.bias") 86 | converted_state_dict[f"blocks.{depth}.mlp.fc2.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.2.weight") 87 | converted_state_dict[f"blocks.{depth}.mlp.fc2.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.2.bias") 88 | 89 | # cross-attention 90 | q = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_q.weight") 91 | q_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_q.bias") 92 | k = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_k.weight") 93 | k_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_k.bias") 94 | v = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_v.weight") 95 | v_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_v.bias") 96 | 97 | converted_state_dict[f"blocks.{depth}.cross_attn.q_linear.weight"] = q 98 | converted_state_dict[f"blocks.{depth}.cross_attn.q_linear.bias"] = q_bias 99 | converted_state_dict[f"blocks.{depth}.cross_attn.kv_linear.weight"] = torch.cat((k, v)) 100 | converted_state_dict[f"blocks.{depth}.cross_attn.kv_linear.bias"] = torch.cat((k_bias, v_bias)) 101 | 102 | converted_state_dict[f"blocks.{depth}.cross_attn.proj.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_out.0.weight") 103 | converted_state_dict[f"blocks.{depth}.cross_attn.proj.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_out.0.bias") 104 | 105 | # final block 106 | converted_state_dict["final_layer.linear.weight"] = transformer.get_tensor("proj_out.weight") 107 | converted_state_dict["final_layer.linear.bias"] = transformer.get_tensor("proj_out.bias") 108 | converted_state_dict["final_layer.scale_shift_table"] = transformer.get_tensor("scale_shift_table") 109 | 110 | # save the state_dict 111 | to_save = {} 112 | to_save['state_dict'] = converted_state_dict 113 | torch.save(to_save, output_path) 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--safetensor_path', type=str, required=True, help='Path and filename of a safetensor file to convert. i.e. output/mymodel.safetensors') 118 | parser.add_argument('--pth_path', type=str, required=True, help='Path and filename to the output file i.e. 
output/mymodel.pth') 119 | 120 | args = parser.parse_args() 121 | main(args) 122 | -------------------------------------------------------------------------------- /tools/convert_images_to_json.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from os import walk 4 | from os import path 5 | from PIL import Image 6 | import json 7 | import tqdm 8 | 9 | def print_usage(): 10 | print('convert_images_to_json [params] images_path output_path') 11 | print('--caption_extension') 12 | 13 | def main(): 14 | args = sys.argv 15 | if len(args) < 3: 16 | print_usage() 17 | return 18 | 19 | input_folder = args[1] 20 | output_folder = args[-1] 21 | 22 | caption_extension = '.txt' 23 | try: 24 | caption_arg = args.index('--caption_extension') 25 | caption_extension = args[caption_arg + 1] 26 | except (ValueError, IndexError): 27 | pass 28 | 29 | # create a folder with the output path 30 | output_folder = Path(output_folder) 31 | output_folder.mkdir(parents=True, exist_ok=True) 32 | 33 | # create a InternData and a InternImgs inside the output path 34 | intern_data_folder = output_folder.joinpath('InternData') 35 | intern_data_folder.mkdir(parents=True, exist_ok=True) 36 | 37 | intern_imgs_folder = output_folder.joinpath('InternImgs') 38 | intern_imgs_folder.mkdir(parents=True, exist_ok=True) 39 | 40 | # create a data_info.json inside InternData 41 | data_info_path = intern_data_folder.joinpath('data_info.json') 42 | 43 | # create a table which will contain all the entries 44 | json_entries = [] 45 | with open(data_info_path, 'w') as json_file: 46 | for (dirpath, dirnames, filenames) in walk(input_folder): 47 | for filename in tqdm.tqdm(filenames): 48 | if not filename.endswith(caption_extension): 49 | continue 50 | 51 | # check if an image exists for this caption 52 | image_filename = filename[:-len(caption_extension)] 53 | 54 | for image_extension in ['.jpg', '.png', '.jpeg', '.webp', '.JPEG', '.JPG']: 55 | image_path = Path(dirpath).joinpath(image_filename + image_extension) 56 | if path.exists(image_path): 57 | write_entry(json_entries, dirpath, image_path, Path(dirpath).joinpath(filename), image_filename + image_extension, intern_imgs_folder) 58 | break 59 | 60 | # use the entries 61 | json_file.write(json.dumps(json_entries)) 62 | 63 | def write_entry(json_entries, folder, image_path, caption_path, image_filename, intern_imgs_path): 64 | # open the file containing the prompt 65 | with open(caption_path) as prompt_file: 66 | prompt = prompt_file.read() 67 | 68 | # read the images info 69 | image = Image.open(image_path) 70 | image_width = image.width 71 | image_height = image.height 72 | ratio = image_height / image_width 73 | 74 | entry = {} 75 | entry['width'] = image_width 76 | entry['height'] = image_height 77 | entry['ratio'] = ratio 78 | entry['path'] = image_filename 79 | entry['prompt'] = prompt 80 | entry['sharegpt4v'] = '' 81 | 82 | json_entries.append(entry) 83 | 84 | # make sure to copy the image to the internimgs folder with the new filename! 85 | image_output_path = intern_imgs_path.joinpath(image_filename) 86 | image.save(image_output_path) 87 | 88 | if __name__ == '__main__': 89 | main() -------------------------------------------------------------------------------- /tools/download.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved.
4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Functions for downloading pre-trained PixArt models 10 | """ 11 | from torchvision.datasets.utils import download_url 12 | import torch 13 | import os 14 | import argparse 15 | 16 | 17 | pretrained_models = { 18 | 'PixArt-Sigma-XL-2-512-MS.pth', 'PixArt-Sigma-XL-2-256x256.pth', 'PixArt-Sigma-XL-2-1024-MS.pth' 19 | } 20 | 21 | 22 | def find_model(model_name): 23 | """ 24 | Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path. 25 | """ 26 | if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints 27 | return download_model(model_name) 28 | else: # Load a custom PixArt checkpoint: 29 | assert os.path.isfile(model_name), f'Could not find PixArt checkpoint at {model_name}' 30 | return torch.load(model_name, map_location=lambda storage, loc: storage) 31 | 32 | 33 | def download_model(model_name): 34 | """ 35 | Downloads a pre-trained PixArt model from the web. 36 | """ 37 | assert model_name in pretrained_models 38 | local_path = f'output/pretrained_models/{model_name}' 39 | if not os.path.isfile(local_path): 40 | hf_endpoint = os.environ.get("HF_ENDPOINT") 41 | if hf_endpoint is None: 42 | hf_endpoint = "https://huggingface.co" 43 | os.makedirs('output/pretrained_models', exist_ok=True) 44 | web_path = f'{hf_endpoint}/PixArt-alpha/PixArt-Sigma/resolve/main/{model_name}' 45 | download_url(web_path, 'output/pretrained_models/') 46 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 47 | return model 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--model_names', nargs='+', type=str, default=pretrained_models) 53 | args = parser.parse_args() 54 | model_names = args.model_names 55 | model_names = set(model_names) 56 | 57 | # Download PixArt checkpoints 58 | for model in model_names: 59 | download_model(model) 60 | print('Done.') 61 | -------------------------------------------------------------------------------- /tools/generate_dmd_data_noise_pairs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import annotations 3 | import argparse 4 | import os 5 | import sys 6 | import json 7 | from pathlib import Path 8 | 9 | current_file_path = Path(__file__).resolve() 10 | sys.path.insert(0, str(current_file_path.parent.parent)) 11 | 12 | import numpy as np 13 | import random 14 | import torch 15 | from diffusers import PixArtAlphaPipeline, Transformer2DModel 16 | from tqdm import tqdm 17 | 18 | MAX_SEED = np.iinfo(np.int32).max 19 | 20 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 21 | if randomize_seed: 22 | seed = random.randint(0, MAX_SEED) 23 | return seed 24 | 25 | @torch.no_grad() 26 | @torch.inference_mode() 27 | def generate(items): 28 | 29 | seed = int(randomize_seed_fn(0, randomize_seed=False)) 30 | generator = torch.Generator().manual_seed(seed) 31 | 32 | for item in tqdm(items, "Generating: "): 33 | prompt = item['prompt'] 34 | save_name = item['path'].split('.')[0] 35 | 36 | # noise 37 | latent_size = pipe.transformer.config.sample_size 38 | noise = torch.randn( 39 | (1, 4, latent_size, latent_size), generator=generator, dtype=torch.float32 40 | ).to(weight_dtype).to(device) 41 | 42 | # image 43 | img_latent = pipe( 44 | prompt=prompt, 45 | latents=noise.to(weight_dtype), 46 | 
generator=generator, 47 | output_type="latent", 48 | max_sequence_length=T5_token_max_length, 49 | ).images[0] 50 | 51 | # save noise-denoised latent features 52 | noise_save_path = os.path.join(noise_save_dir, f"{save_name}.npy") 53 | np.save(noise_save_path, noise[0].cpu().numpy()) 54 | img_latent_save_path = os.path.join(img_latent_save_dir, f"{save_name}.npy") 55 | np.save(img_latent_save_path, img_latent.cpu().numpy()) 56 | 57 | if args.save_img: 58 | image = pipe.vae.decode(img_latent.unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]  # add the batch dim the VAE expects 59 | image = pipe.image_processor.postprocess(image, output_type="pil") 60 | img_save_path = os.path.join(img_save_dir, f"{save_name}.png") 61 | image[0].save(img_save_path)  # postprocess returns a list of PIL images 62 | 63 | 64 | def parse_args(): 65 | parser = argparse.ArgumentParser(description="Generate PixArt-DMD noise-image training pairs.") 66 | parser.add_argument('--work-dir', help='the dir to save logs and models') 67 | parser.add_argument('--save_img', action='store_true', help='also decode and save PNG images alongside the latents') 68 | parser.add_argument('--sample_nums', default=640_000, type=int, help='sample numbers') 69 | parser.add_argument('--T5_token_max_length', default=120, type=int, choices=[120, 300], help='T5 token length') 70 | parser.add_argument( 71 | '--model_path', default="PixArt-alpha/PixArt-XL-2-512x512", help='the dir to load a ckpt for teacher model') 72 | parser.add_argument( 73 | '--pipeline_load_from', default="PixArt-alpha/PixArt-XL-2-1024-MS", type=str, 74 | help="Download for loading text_encoder, " 75 | "tokenizer and vae from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS") 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | # Use PixArt-Alpha to generate PixArt-Alpha-DMD training data (noise-image pairs). 81 | if __name__ == '__main__': 82 | args = parse_args() 83 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 84 | 85 | metadata_json_list = ["pixart-sigma-toy-dataset/InternData/data_info.json",] 86 | # dataset 87 | meta_data_clean = [] 88 | for json_file in metadata_json_list: 89 | meta_data = json.load(open(json_file, 'r')) 90 | meta_data_clean.extend([item for item in meta_data if item['ratio'] <= 4.5]) 91 | 92 | weight_dtype = torch.float16 93 | T5_token_max_length = args.T5_token_max_length 94 | if torch.cuda.is_available(): 95 | 96 | # Teacher Model 97 | pipe = PixArtAlphaPipeline.from_pretrained( 98 | args.pipeline_load_from, 99 | transformer=None, 100 | torch_dtype=weight_dtype, 101 | ) 102 | pipe.transformer = Transformer2DModel.from_pretrained( 103 | args.model_path, subfolder="transformer", torch_dtype=weight_dtype 104 | ) 105 | pipe.to(device) 106 | 107 | print(f"INFO: Select only first {args.sample_nums} samples") 108 | meta_data_clean = meta_data_clean[:args.sample_nums] 109 | 110 | # save path 111 | if args.save_img: 112 | img_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_images') 113 | os.makedirs(img_save_dir, exist_ok=True) 114 | img_latent_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_latents') 115 | noise_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_noises') 116 | os.umask(0o000) # file permission: 666; dir permission: 777 117 | os.makedirs(img_latent_save_dir, exist_ok=True) 118 | os.makedirs(noise_save_dir, exist_ok=True) 119 | 120 | # generate 121 | generate(meta_data_clean) 122 | 123 | 124 | -------------------------------------------------------------------------------- /tools/merge_transformers.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import annotations 3 | import argparse 4 | import os 5 | import sys 6 | from pathlib import Path 7 | import gc 8 | 9 | current_file_path = Path(__file__).resolve() 10 | sys.path.insert(0, str(current_file_path.parent.parent)) 11 | 12 | import torch 13 | from transformers import T5EncoderModel, T5Tokenizer 14 | import pathlib 15 | 16 | from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel 17 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs 18 | 19 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4} 20 | 21 | def main(args): 22 | # load the first checkpoint 23 | repo_path_a = args.repo_path_a 24 | repo_path_b = args.repo_path_b 25 | output_folder = pathlib.Path(args.output_folder) 26 | ratio = args.ratio 27 | 28 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs) 29 | transformer_a = Transformer2DModel.from_pretrained(repo_path_a, subfolder='transformer') 30 | state_dict_a = transformer_a.state_dict() 31 | 32 | # load the second checkpoint 33 | transformer_b = Transformer2DModel.from_pretrained(repo_path_b, subfolder='transformer') 34 | state_dict_b = transformer_b.state_dict() 35 | 36 | new_state_dict = {} 37 | 38 | for key, value in state_dict_a.items(): 39 | value_a = state_dict_a[key] 40 | value_b = state_dict_b[key] 41 | new_val = torch.lerp(value_a, value_b, ratio) 42 | new_state_dict[key] = new_val 43 | 44 | # delete the transformers to reduce RAM requirements 45 | del transformer_a 46 | del transformer_b 47 | del state_dict_a 48 | del state_dict_b 49 | gc.collect() 50 | 51 | # save the new transformer 52 | new_transformer = Transformer2DModel.from_pretrained(repo_path_a, subfolder='transformer') 53 | new_transformer.load_state_dict(new_state_dict) 54 | new_transformer.save_pretrained(output_folder) 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--repo_path_a', required=True, type=str) 59 | parser.add_argument('--repo_path_b', required=True, type=str) 60 | parser.add_argument('--output_folder', required=True, type=str) 61 | parser.add_argument('--ratio', required=True, type=float) 62 | parser.add_argument('--version', required=False, default='sigma', type=str) 63 | 64 | args = parser.parse_args() 65 | main(args) -------------------------------------------------------------------------------- /train_scripts/train_dmd.sh: -------------------------------------------------------------------------------- 1 | config=configs/pixart_app_config/PixArt-DMD_xl2_img512_internalms.py 2 | work_dir=output/debug/ 3 | 4 | # machine 5 | machine_num=1 6 | np=8 7 | 8 | # training settings 9 | max_samples=500000 10 | batchsize=1 11 | mix_precision='no' 12 | use_dm=1 13 | use_regression=1 14 | regression_weight=0.25 15 | one_step_maxt=400 16 | learning_rate=1e-6 17 | lr_fake_multiplier=1.0 18 | max_grad_norm=10 19 | gradient_accumulation_steps=2 20 | save_image_interval=5000 21 | 22 | # resume from checkpoint 23 | resume_from="" 24 | 25 | # train 26 | python_command="python scripts/run_pixart_dmd.py --world_size 1 ${config} " 27 | python_command+="--is_debugging " 28 | python_command+="--work_dir=${work_dir} --machine_num=${machine_num} --np=${np} " 29 | python_command+="--max_samples=${max_samples} " 30 | python_command+="--batch_size=${batchsize} --mixed_precision=${mix_precision} " 31 | 
python_command+="--use_dm=${use_dm} " 32 | python_command+="--use_regression=${use_regression} --regression_weight=${regression_weight} " 33 | python_command+="--one_step_maxt=${one_step_maxt} " 34 | python_command+="--learning_rate=${learning_rate} " 35 | python_command+="--lr_fake_multiplier=${lr_fake_multiplier} " 36 | python_command+="--max_grad_norm=${max_grad_norm} " 37 | python_command+="--gradient_accumulation_steps=${gradient_accumulation_steps} " 38 | python_command+="--save_image_interval=${save_image_interval} " 39 | 40 | python_command+="--resume_from=${resume_from} " 41 | 42 | eval $python_command -------------------------------------------------------------------------------- /train_scripts/train_pixart_lora.sh: -------------------------------------------------------------------------------- 1 | pip install -U peft 2 | 3 | dataset_id=svjack/pokemon-blip-captions-en-zh 4 | model_id=PixArt-alpha/PixArt-XL-2-512x512 5 | 6 | accelerate launch --num_processes=1 --main_process_port=36667 train_scripts/train_pixart_lora_hf.py \ 7 | --mixed_precision="fp16" \ 8 | --pretrained_model_name_or_path=$model_id \ 9 | --dataset_name=$dataset_id \ 10 | --caption_column="text" \ 11 | --resolution=512 \ 12 | --random_flip \ 13 | --train_batch_size=16 \ 14 | --num_train_epochs=80 \ 15 | --checkpointing_steps=1000 \ 16 | --learning_rate=1e-05 \ 17 | --lr_scheduler="constant" \ 18 | --lr_warmup_steps=0 \ 19 | --seed=42 \ 20 | --output_dir="output/pixart-pokemon-model" \ 21 | --validation_prompt="cute dragon creature" \ 22 | --report_to="tensorboard" \ 23 | --gradient_checkpointing \ 24 | --checkpoints_total_limit=10 \ 25 | --validation_epochs=5 \ 26 | --max_token_length=120 \ 27 | --rank=16 --------------------------------------------------------------------------------
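Note: the `compute_snr(timesteps, noise_scheduler)` helper shown earlier in the training utilities is typically used for Min-SNR loss weighting. The sketch below is illustrative only and is not code from this repository; the `gamma` default and the epsilon-prediction objective are assumptions taken from the Min-SNR paper linked in that helper's docstring, and `min_snr_weighted_mse` is a hypothetical name.

```python
# Minimal sketch (assumptions: epsilon-prediction objective, gamma=5.0 as in the Min-SNR paper).
# compute_snr is the helper defined in the training utilities above and is assumed to be in scope.
import torch
import torch.nn.functional as F


def min_snr_weighted_mse(model_pred, noise, timesteps, noise_scheduler, gamma=5.0):
    snr = compute_snr(timesteps, noise_scheduler)     # per-sample SNR, shape (batch,)
    weights = torch.clamp(snr, max=gamma) / snr       # min(SNR, gamma) / SNR
    loss = F.mse_loss(model_pred.float(), noise.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.ndim)))   # reduce spatial dims to a per-sample loss
    return (loss * weights).mean()
```

In practice this replaces the plain MSE between predicted and true noise, clamping the weight of high-SNR (low-noise) timesteps so they do not dominate training.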