├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── app
│   ├── app_pixart_dmd.py
│   └── app_pixart_sigma.py
├── asset
│   ├── PixArt.svg
│   ├── docs
│   │   ├── convert_image2json.md
│   │   ├── data_feature_extraction.md
│   │   ├── pixart.md
│   │   ├── pixart_dmd.md
│   │   └── pixart_lora.md
│   ├── examples.py
│   ├── imgs
│   │   ├── dmd.png
│   │   ├── lora_512.png
│   │   └── noise_snr.png
│   ├── logo-sigma.png
│   ├── logo.png
│   └── samples.txt
├── configs
│   ├── PixArt_xl2_internal.py
│   ├── pixart_alpha_config
│   │   ├── PixArt_xl2_img1024_dreambooth.py
│   │   ├── PixArt_xl2_img1024_internal.py
│   │   ├── PixArt_xl2_img1024_internalms.py
│   │   ├── PixArt_xl2_img256_internal.py
│   │   ├── PixArt_xl2_img512_internal.py
│   │   └── PixArt_xl2_img512_internalms.py
│   ├── pixart_app_config
│   │   └── PixArt-DMD_xl2_img512_internalms.py
│   └── pixart_sigma_config
│       ├── PixArt_sigma_xl2_img1024_internalms.py
│       ├── PixArt_sigma_xl2_img1024_internalms_kvcompress.py
│       ├── PixArt_sigma_xl2_img1024_lcm.py
│       ├── PixArt_sigma_xl2_img256_internal.py
│       ├── PixArt_sigma_xl2_img2K_internalms_kvcompress.py
│       └── PixArt_sigma_xl2_img512_internalms.py
├── diffusion
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── datasets
│   │   │   ├── InternalData.py
│   │   │   ├── InternalData_ms.py
│   │   │   ├── __init__.py
│   │   │   ├── dmd.py
│   │   │   └── utils.py
│   │   └── transforms.py
│   ├── dpm_solver.py
│   ├── iddpm.py
│   ├── lcm_scheduler.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── diffusion_utils.py
│   │   ├── dpm_solver.py
│   │   ├── edm_sample.py
│   │   ├── gaussian_diffusion.py
│   │   ├── llava
│   │   │   ├── __init__.py
│   │   │   ├── llava_mpt.py
│   │   │   └── mpt
│   │   │       ├── attention.py
│   │   │       ├── blocks.py
│   │   │       ├── configuration_mpt.py
│   │   │       ├── modeling_mpt.py
│   │   │       ├── norm.py
│   │   │       └── param_init_fns.py
│   │   ├── nets
│   │   │   ├── PixArt.py
│   │   │   ├── PixArtMS.py
│   │   │   ├── PixArt_blocks.py
│   │   │   └── __init__.py
│   │   ├── respace.py
│   │   ├── sa_solver.py
│   │   ├── t5.py
│   │   ├── timestep_sampler.py
│   │   └── utils.py
│   ├── sa_sampler.py
│   ├── sa_solver_diffusers.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── data_sampler.py
│       ├── dist_utils.py
│       ├── logger.py
│       ├── lr_scheduler.py
│       ├── misc.py
│       └── optimizer.py
├── environment.yml
├── notebooks
│   ├── PixArt_xl2_img512_internal_for_pokemon_sample_training.py
│   ├── convert-checkpoint-to-diffusers.ipynb
│   ├── infer.ipynb
│   └── train.ipynb
├── requirements.txt
├── scripts
│   ├── DMD
│   │   └── transformer_train
│   │       ├── args.py
│   │       ├── attention_processor.py
│   │       ├── generate.py
│   │       └── utils.py
│   ├── diffusers_patches.py
│   ├── inference.py
│   ├── inference_pipeline.py
│   ├── interface.py
│   ├── run_pixart_dmd.py
│   └── style.css
├── tools
│   ├── convert_diffusers_to_pipeline.py
│   ├── convert_diffusers_to_pixart.py
│   ├── convert_images_to_json.py
│   ├── convert_pixart_to_diffusers.py
│   ├── download.py
│   ├── extract_features.py
│   ├── generate_dmd_data_noise_pairs.py
│   └── merge_transformers.py
└── train_scripts
    ├── train.py
    ├── train_dmd.sh
    ├── train_dreambooth_lora.py
    ├── train_pixart_dmd.py
    ├── train_pixart_lcm.py
    ├── train_pixart_lora.sh
    └── train_pixart_lora_hf.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | docker
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug
139 | .idea/
140 | cloud_tools/
141 | output/
142 | output_cv/
143 |
144 | # added by ylw
145 | *.pt
146 | *.pth
147 | *mj*
148 | s3helper/
149 | TODO.md
150 | pretrained_models
151 | work_dir
152 | #demo.py
153 | develop/
154 | tmp.py
155 | output_cv/
156 | output_all/
157 | output_demo/
158 | output_debug/
159 | output/
160 |
161 | #cache for docker
162 | docker/cache/gradio
163 | docker/cache/huggingface
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # This is a sample Dockerfile that builds a runtime container and runs the sample Gradio app.
2 | # Note: you must provide the pretrained models to the container when you run it.
3 |
4 | FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
5 |
6 | WORKDIR /workspace
7 |
8 | RUN apt-get update && \
9 | apt-get install -y \
10 | git \
11 | python3 \
12 | python-is-python3 \
13 | python3-pip \
14 | python3.10-venv \
15 | libgl1 \
16 | libgl1-mesa-glx \
17 | libglib2.0-0 \
18 | && rm -rf /var/lib/apt/lists/*
19 |
20 | ADD requirements.txt .
21 |
22 | RUN pip install -r requirements.txt
23 |
24 | ADD . .
25 |
26 | RUN chmod a+x docker-entrypoint.sh
27 |
28 | ENV DEMO_PORT=12345
29 | ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ]
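30 |
31 | # Example (illustrative, not part of the original Dockerfile): one way to build and run the image.
32 | # The image tag and the host path to the pretrained models below are placeholders.
33 | #   docker build -t pixart-sigma .
34 | #   docker run --gpus all -p 12345:12345 \
35 | #     -v /path/to/pretrained_models:/workspace/output/pretrained_models \
36 | #     pixart-sigma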
--------------------------------------------------------------------------------
/asset/PixArt.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
97 |
--------------------------------------------------------------------------------
/asset/docs/convert_image2json.md:
--------------------------------------------------------------------------------
1 | ## Tools
2 |
3 | ### tools/convert_images_to_json.py
4 |
5 | This script converts a folder containing images and caption files into the dataset folder structure required by the PixArt-Sigma training scripts.
6 |
7 | Before:
8 | - root
9 | - image1.png
10 | - image1.txt
11 | - image2.jpg
12 | - image2.txt
13 | - image3.webp
14 | - image3.txt
15 |
16 | After:
17 | - out
18 | - InternData
19 | - data_info.json
20 | - InternImgs
21 | - image1.png
22 | - image2.jpg
23 | - image3.webp
24 |
25 | The script detects all images that have a paired caption, copies them into the InternImgs folder, and writes their prompts into data_info.json.
26 |
27 | Usage is as follows:
28 |
29 | `python tools/convert_images_to_json.py [params] images_path output_path`
30 |
31 | The caption file extension defaults to `.txt`, but it can be changed with an argument such as `--caption_extension .caption`.
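32 |
33 | For illustration, each entry of the generated data_info.json is a small record describing one image. The exact fields the script writes may differ; the keys shown below (`path`, `prompt`, `ratio`) are the ones the training datasets read, and the values are made-up examples:
34 |
35 | ```json
36 | [
37 |   {
38 |     "path": "image1.png",
39 |     "prompt": "a photo of a red bicycle leaning against a brick wall",
40 |     "ratio": 1.0
41 |   }
42 | ]
43 | ```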
--------------------------------------------------------------------------------
/asset/docs/data_feature_extraction.md:
--------------------------------------------------------------------------------
1 | # 📕 Data Preparation
2 |
3 | ### 1. Downloading the toy dataset
4 |
5 | Download the [toy dataset](https://huggingface.co/datasets/PixArt-alpha/pixart-sigma-toy-dataset) first.
6 | The dataset structure for training is:
7 |
8 | ```
9 | cd your_project_path/pixart-sigma-toy-dataset
10 |
11 | Dataset Structure
12 | ├──InternImgs/ (images are saved here)
13 | │ ├──000000000000.png
14 | │ ├──000000000001.png
15 | │ ├──......
16 | ├──InternData/
17 | │ ├──data_info.json (meta data)
18 | Optional(👇)
19 | │ ├──img_sdxl_vae_features_512resolution_ms_new (run tools/extract_features.py to generate VAE features; same name as images but with .npy extension)
20 | │ │ ├──000000000000.npy
21 | │ │ ├──000000000001.npy
22 | │ │ ├──......
23 | │ ├──caption_features_new (run tools/extract_features.py to generate caption T5 features; same name as images but with .npz extension)
24 | │ │ ├──000000000000.npz
25 | │ │ ├──000000000001.npz
26 | │ │ ├──......
27 | │ ├──sharegpt4v_caption_features_new (run tools/extract_features.py to generate sharegpt4v caption T5 features; same name as images but with .npz extension)
28 | │ │ ├──000000000000.npz
29 | │ │ ├──000000000001.npz
30 | │ │ ├──......
31 | ```
32 | ### You are now ready to run the [training code](https://github.com/PixArt-alpha/PixArt-sigma#12-download-pretrained-checkpoint)
33 |
34 | ---
35 | ## Optional(👇)
36 | > [!IMPORTANT]
37 | > You don't have to extract the following features to run the training, BUT
38 | >
39 | > if you want **faster training** and **lower GPU memory usage**, you can pre-extract all the VAE & T5 features
40 |
41 | ### 2. Extract VAE features
42 |
43 | ```bash
44 | python tools/extract_features.py --run_vae_feature_extract \
45 | --multi_scale \
46 | --img_size=512 \
47 | --dataset_root=pixart-sigma-toy-dataset/InternData \
48 | --vae_json_file=data_info.json \
49 | --vae_models_dir=madebyollin/sdxl-vae-fp16-fix \
50 | --vae_save_root=pixart-sigma-toy-dataset/InternData
51 | ```
52 | **SDXL-VAE** features will be saved at: `pixart-sigma-toy-dataset/InternData/img_sdxl_vae_features_512resolution_ms_new`
53 | as shown in the [DataTree](#1downloading-the-toy-dataset).
54 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L242)
55 |
56 | ### 3. Extract T5 features (prompt)
57 |
58 | ```bash
59 | python tools/extract_features.py --run_t5_feature_extract \
60 | --max_length=300 \
61 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json \
62 | --t5_models_dir=PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers \
63 | --caption_label=prompt \
64 | --t5_save_root=pixart-sigma-toy-dataset/InternData
65 | ```
66 | **T5 features** will be saved at: `pixart-sigma-toy-dataset/InternData/caption_features_new`
67 | as shown in the [DataTree](#1downloading-the-toy-dataset).
68 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L227)
69 |
70 | ---
71 | > [!TIP]
72 | > Skip this step if your data_info.json does not contain `sharegpt4v` captions
73 |
74 | ### 3.1. Extract T5 features (sharegpt4v)
75 |
76 | ```bash
77 | python tools/extract_features.py --run_t5_feature_extract \
78 | --max_length=300 \
79 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json \
80 | --t5_models_dir=PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers \
81 | --caption_label=sharegpt4v \
82 | --t5_save_root=pixart-sigma-toy-dataset/InternData
83 | ```
84 | **T5 features** will be saved at: `pixart-sigma-toy-dataset/InternData/caption_features_new`
85 | as shown in the [DataTree](#1downloading-the-toy-dataset).
86 | They will be later used in [InternalData_ms.py](https://github.com/PixArt-alpha/PixArt-sigma/blob/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0/diffusion/data/datasets/InternalData_ms.py#L234)
87 |
88 |
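89 | ---
90 | ### 4. (Optional) Inspect the extracted features
91 |
92 | As a quick sanity check, the snippet below shows one way to open the pre-extracted files; it mirrors how the dataset classes read them (a `.npy` file storing the VAE latent mean/std, and a `.npz` file storing `caption_feature` and, optionally, `attention_mask`). The file names are examples.
93 |
94 | ```python
95 | import numpy as np
96 | import torch
97 |
98 | # VAE feature: stored as [mean, std]; sample a latent the same way the dataloader does
99 | vae_stats = torch.from_numpy(np.load("pixart-sigma-toy-dataset/InternData/img_sdxl_vae_features_512resolution_ms_new/000000000000.npy"))
100 | mean, std = vae_stats.chunk(2)
101 | latent = mean + std * torch.randn_like(mean)
102 |
103 | # T5 caption feature: .npz with the text embedding (1 x T x 4096) and its attention mask
104 | txt_info = np.load("pixart-sigma-toy-dataset/InternData/caption_features_new/000000000000.npz")
105 | caption_feature = torch.from_numpy(txt_info["caption_feature"])
106 | attention_mask = torch.from_numpy(txt_info["attention_mask"]) if "attention_mask" in txt_info.files else None
107 | print(latent.shape, caption_feature.shape)
108 | ```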
--------------------------------------------------------------------------------
/asset/docs/pixart.md:
--------------------------------------------------------------------------------
1 |
12 |
13 | [//]: # ((reference from [Hugging Face](https://github.com/huggingface/diffusers/blob/docs/8bit-inference-pixart/docs/source/en/api/pipelines/pixart.md)))
14 |
15 | ## Running the `PixArtAlphaPipeline` in under 8GB GPU VRAM
16 |
17 | It is possible to run the [`PixArtAlphaPipeline`] under 8GB GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a full-fledged example.
18 |
19 | First, install the `bitsandbytes` library:
20 |
21 | ```bash
22 | pip install -U bitsandbytes
23 | ```
24 |
25 | Then load the text encoder in 8-bit:
26 |
27 | ```python
28 | from transformers import T5EncoderModel
29 | from diffusers import PixArtAlphaPipeline
30 |
31 | text_encoder = T5EncoderModel.from_pretrained(
32 | "PixArt-alpha/PixArt-XL-2-1024-MS",
33 | subfolder="text_encoder",
34 | load_in_8bit=True,
35 | device_map="auto",
36 |
37 | )
38 | pipe = PixArtAlphaPipeline.from_pretrained(
39 | "PixArt-alpha/PixArt-XL-2-1024-MS",
40 | text_encoder=text_encoder,
41 | transformer=None,
42 | device_map="auto"
43 | )
44 | ```
45 |
46 | Now, use the `pipe` to encode a prompt:
47 |
48 | ```python
49 | with torch.no_grad():
50 | prompt = "cute cat"
51 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
52 |
53 | del text_encoder
54 | del pipe
55 | flush()
56 | ```
57 |
58 | `flush()` is just a utility function to clear the GPU VRAM and is implemented like so:
59 |
60 | ```python
61 | import gc
62 |
63 | def flush():
64 | gc.collect()
65 | torch.cuda.empty_cache()
66 | ```
67 |
68 | Then compute the latents providing the prompt embeddings as inputs:
69 |
70 | ```python
71 | pipe = PixArtAlphaPipeline.from_pretrained(
72 | "PixArt-alpha/PixArt-XL-2-1024-MS",
73 | text_encoder=None,
74 | torch_dtype=torch.float16,
75 | ).to("cuda")
76 |
77 | latents = pipe(
78 | negative_prompt=None,
79 | prompt_embeds=prompt_embeds,
80 | negative_prompt_embeds=negative_embeds,
81 | prompt_attention_mask=prompt_attention_mask,
82 | negative_prompt_attention_mask=negative_prompt_attention_mask,
83 | num_images_per_prompt=1,
84 | output_type="latent",
85 | ).images
86 |
87 | del pipe.transformer
88 | flush()
89 | ```
90 |
91 | Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
92 |
93 | Once the latents are computed, pass them to the VAE to decode into a real image:
94 |
95 | ```python
96 | with torch.no_grad():
97 | image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
98 | image = pipe.image_processor.postprocess(image, output_type="pil")[0]
99 | image.save("cat.png")
100 | ```
101 |
102 | All of this, put together, should allow you to run [`PixArtAlphaPipeline`] under 8GB GPU VRAM.
103 |
104 | 
105 |
106 | Find the script [here](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e) that can be run end-to-end to report the memory being used.
107 |
108 |
109 |
110 | Text embeddings computed in 8-bit can have an impact on the quality of the generated images because of the information loss in the representation space induced by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
111 |
112 |
--------------------------------------------------------------------------------
/asset/docs/pixart_dmd.md:
--------------------------------------------------------------------------------
1 | ## Summary
2 |
3 | **We combine [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) and [DMD](https://arxiv.org/abs/2311.18828)
4 | to achieve one-step image generation. This document will guide you through training and testing.**
5 |
6 | 
7 |
8 | > [!IMPORTANT]
9 | > Due to the differences between DiT and Stable Diffusion,
10 | >
11 | > we find that the choice of the student model's `start timestep` is very important for DMD training.
12 |
13 | Refer to the PixArt-Sigma paper's supplementary material for more details: https://arxiv.org/abs/2403.04692
14 |
15 | ---
16 | ## How to Train
17 |
18 | ### 1. Environment
19 | Refer to the PixArt-Sigma Environment [Here](https://github.com/PixArt-alpha/PixArt-sigma/tree/d5adc756dd6a8b64f1f0aaa1d266e90949e873c0?tab=readme-ov-file#-dependencies-and-installation).
20 | To use FSDP, you need to make sure:
21 | - [PyTorch >= 2.0.1+cu11.7](https://pytorch.org/)
22 |
23 | ### 2. Data preparation (Image-Noise pairs)
24 |
25 | #### Generate image-noise pairs
26 | ```bash
27 | python tools/generate_dmd_data_noise_pairs.py \
28 | --pipeline_load_from=PixArt-alpha/PixArt-XL-2-512x512 \
29 | --model_path=PixArt-alpha/PixArt-XL-2-512x512 \
30 | --save_img # (optional)
31 | ```
32 |
33 | #### Extract features in advance (skip if you already have them)
34 | ```bash
35 | python tools/extract_features.py --run_t5_feature_extract \
36 | --t5_models_dir=PixArt-alpha/PixArt-XL-2-512x512 \
37 | --t5_save_root=pixart-sigma-toy-dataset/InternData \
38 | --caption_label=prompt \
39 | --t5_json_path=pixart-sigma-toy-dataset/InternData/data_info.json
40 | ```
41 |
42 | ### 3. 🔥 Run
43 | ```bash
44 | bash train_scripts/train_dmd.sh
45 | ```
46 |
47 | ---
48 | ## How to Test
49 | ### PixArt-DMD Demo
50 | ```bash
51 | pip install git+https://github.com/huggingface/diffusers
52 |
53 | # PixArt-Sigma One step Sampler(DMD)
54 | DEMO_PORT=12345 python app/app_pixart_dmd.py
55 | ```
56 | Then open `http://your-server-ip:12345` in your browser to try the demo.
57 |
58 | ---
59 | ## Samples
60 | 
61 |
62 |
--------------------------------------------------------------------------------
/asset/docs/pixart_lora.md:
--------------------------------------------------------------------------------
1 | ## Summary
2 |
3 | **We adapt the LoRA training code from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha)
4 | to achieve Transformer LoRA fine-tuning. This document will guide you through training and testing.**
5 |
6 | > [!IMPORTANT]
7 | > Due to the current implementation of `diffusers` and `transformers`,
8 | > LoRA training for `transformers` can only be done in FP32.
9 | >
10 | > We welcome contributions that help resolve this issue.
11 |
12 | ## How to Train
13 | ### 🔥 Run
14 | ```bash
15 | bash train_scripts/train_pixart_lora.sh
16 | ```
17 |
18 | Details👇:
19 |
20 | ```bash
21 | pip install -U peft
22 |
23 | dataset_id=svjack/pokemon-blip-captions-en-zh
24 | model_id=PixArt-alpha/PixArt-XL-2-512x512
25 |
26 | accelerate launch --num_processes=1 --main_process_port=36667 train_scripts/train_pixart_lora_hf.py \
27 | --mixed_precision="fp16" \
28 | --pretrained_model_name_or_path=$model_id \
29 | --dataset_name=$dataset_id \
30 | --caption_column="text" \
31 | --resolution=512 \
32 | --random_flip \
33 | --train_batch_size=16 \
34 | --num_train_epochs=80 \
35 | --checkpointing_steps=1000 \
36 | --learning_rate=1e-05 \
37 | --lr_scheduler="constant" \
38 | --lr_warmup_steps=0 \
39 | --seed=42 \
40 | --output_dir="output/pixart-pokemon-model" \
41 | --validation_prompt="cute dragon creature" \
42 | --report_to="tensorboard" \
43 | --gradient_checkpointing \
44 | --checkpoints_total_limit=10 \
45 | --validation_epochs=5 \
46 | --max_token_length=120 \
47 | --rank=16  # use --max_token_length=300 for PixArt-Sigma
48 | ```
49 |
50 | ## How to Test
51 |
52 | ```python
53 | import torch
54 | from diffusers import PixArtAlphaPipeline, Transformer2DModel
55 | from peft import PeftModel
56 |
57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58 |
59 | # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-1024-MS" too.
60 | MODEL_ID = "PixArt-alpha/PixArt-XL-2-512x512"
61 |
62 | # LoRA model
63 | transformer = Transformer2DModel.from_pretrained(MODEL_ID, subfolder="transformer", torch_dtype=torch.float16)
64 | transformer = PeftModel.from_pretrained(transformer, "Your-LoRA-Model-Path")
65 |
66 | # Pipeline
67 | pipe = PixArtAlphaPipeline.from_pretrained(MODEL_ID, transformer=transformer, torch_dtype=torch.float16)
68 | del transformer
69 |
70 | pipe.to(device)
71 |
72 | prompt = "a drawing of a green pokemon with red eyes"
73 | image = pipe(prompt).images[0]
74 | image.save("./pokemon_with_red_eyes.png")
75 | ```
76 |
77 | ---
78 | ## Samples
79 | 
--------------------------------------------------------------------------------
/asset/examples.py:
--------------------------------------------------------------------------------
1 |
2 | examples = [
3 | [
4 | "A small cactus with a happy face in the Sahara desert.",
5 | "dpm-solver", 20, 4.5,
6 | ],
7 | [
8 | "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
9 | "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
10 | "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
11 | "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
12 | "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
13 | "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
14 | "dpm-solver", 20, 4.5,
15 | ],
16 | [
17 | "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
18 | "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
19 | "The quote 'Find the universe within you' is etched in bold letters across the horizon."
20 | "blue and pink, brilliantly illuminated in the background.",
21 | "dpm-solver", 20, 4.5,
22 | ],
23 | [
24 | "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
25 | "dpm-solver", 20, 4.5,
26 | ],
27 | [
28 | "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
29 | "dpm-solver", 20, 4.5,
30 | ],
31 | [
32 | "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
33 | "national geographic photo, 8k resolution, crayon art, interactive artwork",
34 | "dpm-solver", 20, 4.5,
35 | ]
36 | ]
37 |
--------------------------------------------------------------------------------
/asset/imgs/dmd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/dmd.png
--------------------------------------------------------------------------------
/asset/imgs/lora_512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/lora_512.png
--------------------------------------------------------------------------------
/asset/imgs/noise_snr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/imgs/noise_snr.png
--------------------------------------------------------------------------------
/asset/logo-sigma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/logo-sigma.png
--------------------------------------------------------------------------------
/asset/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/asset/logo.png
--------------------------------------------------------------------------------
/configs/PixArt_xl2_internal.py:
--------------------------------------------------------------------------------
1 | data_root = '/data/data'
2 | data = dict(type='InternalData', root='images', image_list_json=['data_info.json'], transform='default_train', load_vae_feat=True, load_t5_feat=True)
3 | image_size = 256 # the generated image resolution
4 | train_batch_size = 32
5 | eval_batch_size = 16
6 | use_fsdp=False # if use FSDP mode
7 | valid_num=0 # take as valid aspect-ratio when sample number >= valid_num
8 | fp32_attention = True
9 | # model setting
10 | model = 'PixArt_XL_2'
11 | aspect_ratio_type = None # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
12 | multi_scale = False # if use multiscale dataset model training
13 | pe_interpolation = 1.0 # positional embedding interpolation
14 | # qk norm
15 | qk_norm = False
16 | # kv token compression
17 | kv_compress = False
18 | kv_compress_config = {
19 | 'sampling': None,
20 | 'scale_factor': 1,
21 | 'kv_compress_layer': [],
22 | }
23 |
24 | # training setting
25 | num_workers=4
26 | train_sampling_steps = 1000
27 | visualize=False
28 | # Keep the same seed during validation
29 | deterministic_validation = False
30 | eval_sampling_steps = 250
31 | model_max_length = 120
32 | lora_rank = 4
33 | num_epochs = 80
34 | gradient_accumulation_steps = 1
35 | grad_checkpointing = False
36 | gradient_clip = 1.0
37 | gc_step = 1
38 | auto_lr = dict(rule='sqrt')
39 | validation_prompts = [
40 | "dog",
41 | "portrait photo of a girl, photograph, highly detailed face, depth of field",
42 | "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
43 | "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44 | "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
45 | ]
46 |
47 | # we use a different weight decay from the official implementation since it gives better results
48 | optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10)
49 | lr_schedule = 'constant'
50 | lr_schedule_args = dict(num_warmup_steps=500)
51 |
52 | save_image_epochs = 1
53 | save_model_epochs = 1
54 | save_model_steps=1000000
55 |
56 | sample_posterior = True
57 | mixed_precision = 'fp16'
58 | scale_factor = 0.18215 # ldm vae: 0.18215; sdxl vae: 0.13025
59 | ema_rate = 0.9999
60 | tensorboard_mox_interval = 50
61 | log_interval = 50
62 | cfg_scale = 4
63 | mask_type='null'
64 | num_group_tokens=0
65 | mask_loss_coef=0.
66 | load_mask_index=False # load prepared mask_type index
67 | # load model settings
68 | vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema"
69 | load_from = None
70 | resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True)
71 | snr_loss=False
72 | real_prompt_ratio = 1.0
73 | # classifier free guidance
74 | class_dropout_prob = 0.1
75 | # work dir settings
76 | work_dir = '/cache/exps/'
77 | s3_work_dir = None
78 | micro_condition = False
79 | seed = 43
80 | skip_step=0
81 |
82 | # LCM
83 | loss_type = 'huber'
84 | huber_c = 0.001
85 | num_ddim_timesteps=50
86 | w_max = 15.0
87 | w_min = 3.0
88 | ema_decay = 0.95
89 |
90 |
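91 | # Note (added for clarity): this file is the shared base config. The experiment configs under
92 | # configs/pixart_alpha_config/, configs/pixart_app_config/ and configs/pixart_sigma_config/ inherit
93 | # it via `_base_ = ['../PixArt_xl2_internal.py']` and override individual fields
94 | # (e.g. image_size, model, optimizer, scale_factor, model_max_length).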
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data/dreambooth/dataset'
3 |
4 | data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True)
5 | image_size = 1024
6 |
7 | # model setting
8 | model = 'PixArtMS_XL_2' # model for multi-scale training
9 | fp32_attention = True
10 | load_from = 'Path/to/PixArt-XL-2-1024-MS.pth'
11 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
12 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
13 | multi_scale = True # if use multiscale dataset model training
14 | pe_interpolation = 2.0
15 |
16 | # training setting
17 | num_workers=1
18 | train_batch_size = 1
19 | num_epochs = 200
20 | gradient_accumulation_steps = 1
21 | grad_checkpointing = True
22 | gradient_clip = 0.01
23 | optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10)
24 | lr_schedule_args = dict(num_warmup_steps=0)
25 | auto_lr = None
26 |
27 | log_interval = 1
28 | save_model_epochs=10000
29 | save_model_steps=100
30 | work_dir = 'output/debug'
31 |
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json',]
4 |
5 | data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6 | image_size = 1024
7 |
8 | # model setting
9 | model = 'PixArt_XL_2'
10 | fp32_attention = True
11 | load_from = None
12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13 | pe_interpolation = 2.0
14 |
15 | # training setting
16 | num_workers=10
17 | train_batch_size = 2 # 32
18 | num_epochs = 200 # 3
19 | gradient_accumulation_steps = 1
20 | grad_checkpointing = True
21 | gradient_clip = 0.01
22 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
23 | lr_schedule_args = dict(num_warmup_steps=1000)
24 |
25 | eval_sampling_steps = 200
26 | log_interval = 20
27 | save_model_epochs=1
28 | save_model_steps=2000
29 | work_dir = 'output/debug'
30 |
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json',]
4 |
5 | data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6 | image_size = 1024
7 |
8 | # model setting
9 | model = 'PixArtMS_XL_2' # model for multi-scale training
10 | fp32_attention = True
11 | load_from = None
12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
14 | multi_scale = True # if use multiscale dataset model training
15 | pe_interpolation = 2.0
16 |
17 | # training setting
18 | num_workers=10
19 | train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint
20 | num_epochs = 10 # 3
21 | gradient_accumulation_steps = 1
22 | grad_checkpointing = True
23 | gradient_clip = 0.01
24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
25 | lr_schedule_args = dict(num_warmup_steps=1000)
26 | save_model_epochs=1
27 | save_model_steps=2000
28 |
29 | log_interval = 20
30 | eval_sampling_steps = 200
31 | work_dir = 'output/debug'
32 | micro_condition = True
33 |
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img256_internal.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json',]
4 |
5 | data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6 | image_size = 256
7 |
8 | # model setting
9 | model = 'PixArt_XL_2'
10 | fp32_attention = True
11 | load_from = None
12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13 | # training setting
14 | eval_sampling_steps = 200
15 |
16 | num_workers=10
17 | train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint
18 | num_epochs = 200 # 3
19 | gradient_accumulation_steps = 1
20 | grad_checkpointing = True
21 | gradient_clip = 0.01
22 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
23 | lr_schedule_args = dict(num_warmup_steps=1000)
24 |
25 | log_interval = 20
26 | save_model_epochs=5
27 | work_dir = 'output/debug'
28 |
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img512_internal.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json',]
4 |
5 | data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6 | image_size = 512
7 |
8 | # model setting
9 | model = 'PixArt_XL_2'
10 | fp32_attention = True
11 | load_from = None
12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13 | pe_interpolation = 1.0
14 |
15 | # training setting
16 | use_fsdp=False # if use FSDP mode
17 | num_workers=10
18 | train_batch_size = 38 # 32
19 | num_epochs = 200 # 3
20 | gradient_accumulation_steps = 1
21 | grad_checkpointing = True
22 | gradient_clip = 0.01
23 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
24 | lr_schedule_args = dict(num_warmup_steps=1000)
25 |
26 | eval_sampling_steps = 200
27 | log_interval = 20
28 | save_model_epochs=1
29 | work_dir = 'output/debug'
30 |
--------------------------------------------------------------------------------
/configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json',]
4 |
5 | data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6 | image_size = 512
7 |
8 | # model setting
9 | model = 'PixArtMS_XL_2' # model for multi-scale training
10 | fp32_attention = True
11 | load_from = None
12 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13 | aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
14 | multi_scale = True # if use multiscale dataset model training
15 | pe_interpolation = 1.0
16 |
17 | # training setting
18 | num_workers=10
19 | train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint
20 | num_epochs = 20 # 3
21 | gradient_accumulation_steps = 1
22 | grad_checkpointing = True
23 | gradient_clip = 0.01
24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
25 | lr_schedule_args = dict(num_warmup_steps=1000)
26 | save_model_epochs=1
27 | save_model_steps=2000
28 |
29 | log_interval = 20
30 | eval_sampling_steps = 200
31 | work_dir = 'output/debug'
32 |
--------------------------------------------------------------------------------
/configs/pixart_app_config/PixArt-DMD_xl2_img512_internalms.py:
--------------------------------------------------------------------------------
1 | # Config for PixArt-DMD
2 | _base_ = ['../PixArt_xl2_internal.py']
3 | data_root = 'pixart-sigma-toy-dataset'
4 |
5 | image_list_json = ['data_info.json']
6 |
7 | data = dict(
8 | type='DMD', root='InternData', image_list_json=image_list_json, transform='default_train',
9 | load_vae_feat=True, load_t5_feat=True
10 | )
11 | image_size = 512
12 |
13 | # model setting
14 | model = 'PixArtMS_XL_2' # model for multi-scale training
15 | fp32_attention = True
16 | load_from = "PixArt-alpha/PixArt-XL-2-512x512"
17 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
18 | tiny_vae_pretrained = "output/pretrained_models/tinyvae"
19 | aspect_ratio_type = 'ASPECT_RATIO_512'
20 | multi_scale = True # if use multiscale dataset model training
21 | pe_interpolation = 1.0
22 |
23 | # training setting
24 | num_workers = 10
25 | train_batch_size = 1 # max 40 for PixArt-xL/2 when grad_checkpoint
26 | num_epochs = 10 # 3
27 | gradient_accumulation_steps = 1
28 | grad_checkpointing = True
29 | gradient_clip = 0.01
30 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
31 | lr_schedule_args = dict(num_warmup_steps=1000)
32 |
33 | log_interval = 20
34 | save_model_epochs=1
35 | save_model_steps=2000
36 | work_dir = 'output/debug'
37 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'pixart-sigma-toy-dataset'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=False, load_t5_feat=False
8 | )
9 | image_size = 1024
10 |
11 | # model setting
12 | model = 'PixArtMS_XL_2'
13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14 | fp32_attention = True
15 | load_from = None
16 | resume_from = None
17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18 | aspect_ratio_type = 'ASPECT_RATIO_1024'
19 | multi_scale = True # if use multiscale dataset model training
20 | pe_interpolation = 2.0
21 |
22 | # training setting
23 | num_workers = 10
24 | train_batch_size = 2 # 3 without pre-extracted features; 12 with pre-extracted features
25 | num_epochs = 2 # 3
26 | gradient_accumulation_steps = 1
27 | grad_checkpointing = True
28 | gradient_clip = 0.01
29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30 | lr_schedule_args = dict(num_warmup_steps=1000)
31 |
32 | eval_sampling_steps = 500
33 | visualize = True
34 | log_interval = 20
35 | save_model_epochs = 1
36 | save_model_steps = 1000
37 | work_dir = 'output/debug'
38 |
39 | # pixart-sigma
40 | scale_factor = 0.13025
41 | real_prompt_ratio = 0.5
42 | model_max_length = 300
43 | class_dropout_prob = 0.1
44 |
45 | qk_norm = False
46 | skip_step = 0 # skip steps during data loading
47 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=False, load_t5_feat=False
8 | )
9 | image_size = 1024
10 |
11 | # model setting
12 | model = 'PixArtMS_XL_2'
13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14 | fp32_attention = True
15 | load_from = None
16 | resume_from = None
17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18 | aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
19 | multi_scale = True # if use multiscale dataset model training
20 | pe_interpolation = 2.0
21 |
22 | # training setting
23 | num_workers = 10
24 | train_batch_size = 4 # 16
25 | num_epochs = 2 # 3
26 | gradient_accumulation_steps = 1
27 | grad_checkpointing = True
28 | gradient_clip = 0.01
29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30 | lr_schedule_args = dict(num_warmup_steps=500)
31 |
32 | eval_sampling_steps = 250
33 | visualize = True
34 | log_interval = 10
35 | save_model_epochs = 1
36 | save_model_steps = 1000
37 | work_dir = 'output/debug'
38 |
39 | # pixart-sigma
40 | scale_factor = 0.13025
41 | real_prompt_ratio = 0.5
42 | model_max_length = 300
43 | class_dropout_prob = 0.1
44 | kv_compress = True
45 | kv_compress_config = {
46 | 'sampling': 'conv', # ['conv', 'uniform', 'ave']
47 | 'scale_factor': 2,
48 | 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
49 | }
50 | qk_norm = False
51 | skip_step = 0 # skip steps during data loading
52 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'pixart-sigma-toy-dataset'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=True, load_t5_feat=True,
8 | )
9 | image_size = 1024
10 |
11 | # model setting
12 | model = 'PixArtMS_XL_2' # model for multi-scale training
13 | fp32_attention = False
14 | load_from = None
15 | resume_from = None
16 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
17 | aspect_ratio_type = 'ASPECT_RATIO_1024'
18 | multi_scale = True # if use multiscale dataset model training
19 | pe_interpolation = 2.0
20 |
21 | # training setting
22 | num_workers = 4
23 | train_batch_size = 12 # max 12 for PixArt-xL/2 when grad_checkpoint
24 | num_epochs = 10 # 3
25 | gradient_accumulation_steps = 1
26 | grad_checkpointing = True
27 | gradient_clip = 0.01
28 | optimizer = dict(type='CAMEWrapper', lr=1e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
29 | lr_schedule_args = dict(num_warmup_steps=100)
30 | save_model_epochs = 10
31 | save_model_steps = 1000
32 | valid_num = 0 # take as valid aspect-ratio when sample number >= valid_num
33 |
34 | log_interval = 10
35 | eval_sampling_steps = 5
36 | visualize = True
37 | work_dir = 'output/debug'
38 |
39 | # pixart-sigma
40 | scale_factor = 0.13025
41 | real_prompt_ratio = 0.5
42 | model_max_length = 300
43 | class_dropout_prob = 0.1
44 |
45 | # LCM
46 | loss_type = 'huber'
47 | huber_c = 0.001
48 | num_ddim_timesteps = 50
49 | w_max = 15.0
50 | w_min = 3.0
51 | ema_decay = 0.95
52 | cfg_scale = 4.5
53 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'pixart-sigma-toy-dataset'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=False, load_t5_feat=False,
8 | )
9 | image_size = 256
10 |
11 | # model setting
12 | model = 'PixArt_XL_2'
13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14 | fp32_attention = True
15 | load_from = "output/pretrained_models/PixArt-Sigma-XL-2-256x256.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
16 | resume_from = None
17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18 | multi_scale = False # if use multiscale dataset model training
19 | pe_interpolation = 0.5
20 |
21 | # training setting
22 | num_workers = 10
23 | train_batch_size = 64 # 64 as default
24 | num_epochs = 200 # 3
25 | gradient_accumulation_steps = 1
26 | grad_checkpointing = True
27 | gradient_clip = 0.01
28 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
29 | lr_schedule_args = dict(num_warmup_steps=1000)
30 |
31 | eval_sampling_steps = 500
32 | log_interval = 20
33 | save_model_epochs = 5
34 | save_model_steps = 2500
35 | work_dir = 'output/debug'
36 |
37 | # pixart-sigma
38 | scale_factor = 0.13025
39 | real_prompt_ratio = 0.5
40 | model_max_length = 300
41 | class_dropout_prob = 0.1
42 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'data'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=False, load_t5_feat=False
8 | )
9 | image_size = 2048
10 |
11 | # model setting
12 | model = 'PixArtMS_XL_2'
13 | mixed_precision = 'fp16'
14 | fp32_attention = True
15 | load_from = None
16 | resume_from = None
17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18 | aspect_ratio_type = 'ASPECT_RATIO_2048' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
19 | multi_scale = True # if use multiscale dataset model training
20 | pe_interpolation = 4.0
21 |
22 | # training setting
23 | num_workers = 10
24 | train_batch_size = 4 # 48
25 | num_epochs = 10 # 3
26 | gradient_accumulation_steps = 1
27 | grad_checkpointing = True
28 | gradient_clip = 0.01
29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30 | lr_schedule_args = dict(num_warmup_steps=100)
31 |
32 | eval_sampling_steps = 100
33 | visualize = True
34 | log_interval = 10
35 | save_model_epochs = 10
36 | save_model_steps = 100
37 | work_dir = 'output/debug'
38 |
39 | # pixart-sigma
40 | scale_factor = 0.13025
41 | real_prompt_ratio = 0.5
42 | model_max_length = 300
43 | class_dropout_prob = 0.1
44 | kv_compress = False
45 | kv_compress_config = {
46 | 'sampling': 'conv',
47 | 'scale_factor': 2,
48 | 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
49 | }
50 |
--------------------------------------------------------------------------------
/configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../PixArt_xl2_internal.py']
2 | data_root = 'pixart-sigma-toy-dataset'
3 | image_list_json = ['data_info.json']
4 |
5 | data = dict(
6 | type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7 | load_vae_feat=False, load_t5_feat=False,
8 | )
9 | image_size = 512
10 |
11 | # model setting
12 | model = 'PixArtMS_XL_2'
13 | mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14 | fp32_attention = True
15 | load_from = "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
16 | resume_from = None
17 | vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18 | aspect_ratio_type = 'ASPECT_RATIO_512'
19 | multi_scale = True # if use multiscale dataset model training
20 | pe_interpolation = 1.0
21 |
22 | # training setting
23 | num_workers = 10
24 | train_batch_size = 2 # 48 as default
25 | num_epochs = 10 # 3
26 | gradient_accumulation_steps = 1
27 | grad_checkpointing = True
28 | gradient_clip = 0.01
29 | optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30 | lr_schedule_args = dict(num_warmup_steps=1000)
31 |
32 | eval_sampling_steps = 500
33 | visualize = True
34 | log_interval = 20
35 | save_model_epochs = 5
36 | save_model_steps = 2500
37 | work_dir = 'output/debug'
38 |
39 | # pixart-sigma
40 | scale_factor = 0.13025
41 | real_prompt_ratio = 0.5
42 | model_max_length = 300
43 | class_dropout_prob = 0.1
44 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from .iddpm import IDDPM
7 | from .dpm_solver import DPMS
8 | from .sa_sampler import SASolverSampler
9 |
--------------------------------------------------------------------------------
/diffusion/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .transforms import get_transform
3 |
--------------------------------------------------------------------------------
/diffusion/data/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | from mmcv import Registry, build_from_cfg
5 | from torch.utils.data import DataLoader
6 |
7 | from diffusion.data.transforms import get_transform
8 | from diffusion.utils.logger import get_root_logger
9 |
10 | DATASETS = Registry('datasets')
11 |
12 | DATA_ROOT = '/cache/data'
13 |
14 |
15 | def set_data_root(data_root):
16 | global DATA_ROOT
17 | DATA_ROOT = data_root
18 |
19 |
20 | def get_data_path(data_dir):
21 | if os.path.isabs(data_dir):
22 | return data_dir
23 | global DATA_ROOT
24 | return os.path.join(DATA_ROOT, data_dir)
25 |
26 |
27 | def get_data_root_and_path(data_dir):
28 | if os.path.isabs(data_dir):
29 | return DATA_ROOT, data_dir  # keep the (root, path) return shape consistent for absolute paths
30 | global DATA_ROOT
31 | return DATA_ROOT, os.path.join(DATA_ROOT, data_dir)
32 |
33 |
34 | def build_dataset(cfg, resolution=224, **kwargs):
35 | logger = get_root_logger()
36 |
37 | dataset_type = cfg.get('type')
38 | logger.info(f"Constructing dataset {dataset_type}...")
39 | t = time.time()
40 | transform = cfg.pop('transform', 'default_train')
41 | transform = get_transform(transform, resolution)
42 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
43 | logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
44 | return dataset
45 |
46 |
47 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
48 | if 'batch_sampler' in kwargs:
49 | dataloader = DataLoader(dataset, batch_sampler=kwargs['batch_sampler'], num_workers=num_workers, pin_memory=True)
50 | else:
51 | dataloader = DataLoader(dataset,
52 | batch_size=batch_size,
53 | shuffle=shuffle,
54 | num_workers=num_workers,
55 | pin_memory=True,
56 | **kwargs)
57 | return dataloader
58 |
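59 | # Example usage (illustrative sketch only; mirrors the `data` dicts found in configs/):
60 | #   cfg = dict(type='InternalDataSigma', root='InternData', image_list_json=['data_info.json'],
61 | #              transform='default_train', load_vae_feat=False, load_t5_feat=False)
62 | #   dataset = build_dataset(cfg, resolution=512)
63 | #   dataloader = build_dataloader(dataset, batch_size=32, num_workers=4)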
--------------------------------------------------------------------------------
/diffusion/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .InternalData import InternalData, InternalDataSigma
2 | from .InternalData_ms import InternalDataMS, InternalDataMSSigma
3 | from .dmd import DMD
4 | from .utils import *
5 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/dmd.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | import numpy as np
4 | import glob, torch
5 | from diffusion.data.builder import get_data_root_and_path
6 | from PIL import Image
7 | import json
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from torchvision import transforms as T
10 | from torchvision.datasets.folder import default_loader
11 |
12 |
13 | def read_prompt(prompt_path):
14 | with open(prompt_path, 'r') as f:
15 | prompt_list = f.readlines()
16 | return prompt_list
17 |
18 |
19 | # Dataloader for pixart-dmd model
20 | class DMD(Dataset):
21 | ## rewrite the dataloader to avoid data loading bugs
22 | def __init__(self,
23 | root,
24 | transform=None,
25 | image_list_json='data_info.json',
26 | resolution=512,
27 | load_vae_feat=False,
28 | load_t5_feat=False,
29 | max_samples=None,
30 | max_length=120,
31 | ):
32 | '''
33 | :param root: the root of saving txt features ./data/data/
34 | :param latent_root: the root of saving latent image pairs
35 | :param image_list_json:
36 | :param resolution:
37 | :param max_samples:
38 | :param offset:
39 | :param kwargs:
40 | '''
41 | super().__init__()
42 | DATA_ROOT, root = get_data_root_and_path(root)
43 | self.root = root
44 | self.transform = transform
45 | self.load_vae_feat = load_vae_feat
46 | self.load_t5_feat = load_t5_feat
47 | self.DATA_ROOT = DATA_ROOT
48 | self.ori_imgs_nums = 0
49 | self.resolution = resolution
50 | self.max_samples = max_samples
51 | self.max_lenth = max_length
52 | self.meta_data_clean = []
53 | self.img_samples = []
54 | self.txt_feat_samples = []
55 | self.noise_samples = []
56 | self.vae_feat_samples = []
57 | self.gen_image_samples = []
58 | self.txt_samples = []
59 |
60 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
61 | for json_file in image_list_json:
62 | meta_data = self.load_json(os.path.join(self.root, json_file))
63 | self.ori_imgs_nums += len(meta_data)
64 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
65 | self.meta_data_clean.extend(meta_data_clean)
66 | self.img_samples.extend([
67 | os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean
68 | ])
69 | self.gen_image_samples.extend([
70 | os.path.join(self.root, 'InternImgs_DMD_images', item['path']) for item in meta_data_clean
71 | ])
72 | self.txt_samples.extend([item['prompt'] for item in meta_data_clean])
73 | self.txt_feat_samples.extend([
74 | os.path.join(
75 | self.root,
76 | 'caption_features_new',
77 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npz')
78 | ) for item in meta_data_clean
79 | ])
80 | self.noise_samples.extend([
81 | os.path.join(
82 | self.root,
83 | 'InternImgs_DMD_noises',
84 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npy')
85 | ) for item in meta_data_clean
86 | ])
87 | self.vae_feat_samples.extend(
88 | [
89 | os.path.join(
90 | self.root,
91 | 'InternImgs_DMD_latents',
92 | item['path'].rsplit('/', 1)[-1].replace('.png', '.npy')
93 | ) for item in meta_data_clean
94 | ])
95 |
96 | # Set loader and extensions
97 | if load_vae_feat:
98 | self.transform = None
99 | self.loader = self.latent_feat_loader
100 | else:
101 | self.loader = default_loader
102 |
103 | def __len__(self):
104 | return len(self.img_samples) if self.max_samples is None else min(self.max_samples, len(self.img_samples))
105 |
106 | @staticmethod
107 | def vae_feat_loader(path):
108 | # [mean, std]
109 | mean, std = torch.from_numpy(np.load(path)).chunk(2)
110 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
111 | return mean + std * sample
112 |
113 | @staticmethod
114 | def latent_feat_loader(path):
115 | return torch.from_numpy(np.load(path))
116 |
117 | def load_ori_img(self, img_path):
118 | # Load the image and convert it to a tensor
119 | transform = T.Compose([
120 | T.Resize(512), # Image.BICUBIC
121 | T.CenterCrop(512),
122 | T.ToTensor(),
123 | ])
124 | img = transform(Image.open(img_path))
125 | img = img * 2.0 - 1.0
126 | return img
127 |
128 | def load_json(self, file_path):
129 | with open(file_path, 'r') as f:
130 | meta_data = json.load(f)
131 |
132 | return meta_data
133 |
134 | def getdata(self, index):
135 | img_gt_path = self.img_samples[index]
136 | gen_img_path = self.gen_image_samples[index]
137 | npz_path = self.txt_feat_samples[index]
138 | txt = self.txt_samples[index]
139 | npy_path = self.vae_feat_samples[index]
140 | data_info = {
141 | 'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32),
142 | 'aspect_ratio': torch.tensor(1.)
143 | }
144 |
145 | if self.load_vae_feat:
146 | gen_img = self.loader(npy_path)
147 | else:
148 | gen_img = self.loader(gen_img_path)
149 |
150 | attention_mask = torch.ones(1, 1, self.max_lenth) # 1x1xT
151 | if self.load_t5_feat:
152 | txt_info = np.load(npz_path)
153 | txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096
154 | if 'attention_mask' in txt_info.keys():
155 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
156 | if txt_fea.shape[1] < self.max_lenth:
157 | txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
158 | attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
159 | elif txt_fea.shape[1] > self.max_lenth:
160 | txt_fea = txt_fea[:, :self.max_lenth]
161 | attention_mask = attention_mask[:, :, :self.max_lenth]
162 | else:
163 | txt_fea = txt
164 |
165 | noise = torch.from_numpy(np.load(self.noise_samples[index]))
166 | img_gt = self.load_ori_img(img_gt_path)
167 |
168 | return {'noise': noise,
169 | 'base_latent': gen_img,
170 | 'latent_path': self.vae_feat_samples[index],
171 | 'text': txt,
172 | 'data_info': data_info,
173 | 'txt_fea': txt_fea,
174 | 'attention_mask': attention_mask,
175 | 'img_gt': img_gt}
176 |
177 | def __getitem__(self, idx):
178 | data = self.getdata(idx)
179 | return data
180 | # for _ in range(20):
181 | # try:
182 | # data = self.getdata(idx)
183 | # return data
184 | # except Exception as e:
185 | # print(f"Error details: {str(e)}")
186 | # idx = np.random.randint(len(self))
187 | # raise RuntimeError('Too many bad data.')
188 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/utils.py:
--------------------------------------------------------------------------------
1 |
2 | ASPECT_RATIO_2880 = {
3 | '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
4 | '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
5 | '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
6 | '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
7 | '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
8 | '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
9 | '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
10 | '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
11 | '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
12 | '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
13 | }
14 |
15 | ASPECT_RATIO_2048 = {
16 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
17 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
18 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
19 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
20 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
21 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
22 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
23 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
24 | '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
25 | '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
26 | }
27 |
28 | ASPECT_RATIO_1024 = {
29 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
30 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
31 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
32 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
33 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
34 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
35 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
36 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
37 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
38 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
39 | }
40 |
41 | ASPECT_RATIO_512 = {
42 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
43 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
44 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
45 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
46 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
47 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
48 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
49 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
50 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
51 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
52 | }
53 |
54 | ASPECT_RATIO_256 = {
55 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
56 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
57 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
58 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
59 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
60 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
61 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
62 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
63 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
64 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
65 | }
66 |
67 | ASPECT_RATIO_256_TEST = {
68 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
69 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
70 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
71 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
72 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
73 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
74 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
75 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
76 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
77 | '4.0': [512.0, 128.0]
78 | }
79 |
80 | ASPECT_RATIO_512_TEST = {
81 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
82 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
83 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
84 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
85 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
86 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
87 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
88 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
89 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
90 | '4.0': [1024.0, 256.0]
91 | }
92 |
93 | ASPECT_RATIO_1024_TEST = {
94 | '0.25': [512., 2048.], '0.28': [512., 1856.],
95 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
96 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
97 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
98 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
99 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
100 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
101 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
102 | '2.5': [1600., 640.], '3.0': [1728., 576.],
103 | '4.0': [2048., 512.],
104 | }
105 |
106 | ASPECT_RATIO_2048_TEST = {
107 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
108 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
109 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
110 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
111 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
112 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
113 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
114 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
115 | '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
116 | '4.0': [4096.0, 1024.0]
117 | }
118 |
119 | ASPECT_RATIO_2880_TEST = {
120 | '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
121 | '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
122 | '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
123 | '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
124 | '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
125 | '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
126 | '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
127 | '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
128 | '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
129 | '4.0': [8192.0, 2048.0],
130 | }
131 |
132 | def get_chunks(lst, n):
133 | for i in range(0, len(lst), n):
134 | yield lst[i:i + n]
135 |
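136 |
137 | # A minimal sketch of one way the tables above can be used: pick the bucket whose
138 | # key (height/width ratio) is closest to an image's ratio and read off the target
139 | # [height, width]; get_chunks splits a list of work items into fixed-size batches.
140 | if __name__ == '__main__':
141 |     ratio = 640 / 480  # height / width of a hypothetical image
142 |     closest = min(ASPECT_RATIO_512.keys(), key=lambda k: abs(float(k) - ratio))
143 |     print(closest, ASPECT_RATIO_512[closest])  # 1.29 [576.0, 448.0]
144 |     print(list(get_chunks(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]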
--------------------------------------------------------------------------------
/diffusion/data/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 |
3 | TRANSFORMS = dict()
4 |
5 |
6 | def register_transform(transform):
7 | name = transform.__name__
8 | if name in TRANSFORMS:
9 | raise RuntimeError(f'Transform {name} has already been registered.')
10 | TRANSFORMS.update({name: transform})
11 |
12 |
13 | def get_transform(type, resolution):
14 | transform = TRANSFORMS[type](resolution)
15 | transform = T.Compose(transform)
16 | transform.image_size = resolution
17 | return transform
18 |
19 |
20 | @register_transform
21 | def default_train(n_px):
22 | transform = [
23 | T.Lambda(lambda img: img.convert('RGB')),
24 | T.Resize(n_px), # Image.BICUBIC
25 | T.CenterCrop(n_px),
26 | # T.RandomHorizontalFlip(),
27 | T.ToTensor(),
28 | T.Normalize([.5], [.5]),
29 | ]
30 | return transform
31 |
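32 |
33 | # A minimal sketch of resolving and applying a registered transform; the dummy
34 | # PIL image below is only for illustration.
35 | if __name__ == '__main__':
36 |     from PIL import Image
37 |     transform = get_transform('default_train', 256)
38 |     dummy = Image.new('RGB', (320, 480))
39 |     print(transform(dummy).shape)  # torch.Size([3, 256, 256])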
--------------------------------------------------------------------------------
/diffusion/dpm_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .model import gaussian_diffusion as gd
3 | from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
4 |
5 |
6 | def DPMS(
7 | model,
8 | condition,
9 | uncondition,
10 | cfg_scale,
11 | model_type='noise', # or "x_start" or "v" or "score"
12 | noise_schedule="linear",
13 | guidance_type='classifier-free',
14 | model_kwargs={},
15 | diffusion_steps=1000
16 | ):
17 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
18 |
19 | ## 1. Define the noise schedule.
20 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
21 |
22 | ## 2. Convert your discrete-time `model` to the continuous-time
23 | ## noise prediction model. Here is an example for a diffusion model
24 | ## `model` with the noise prediction type ("noise") .
25 | model_fn = model_wrapper(
26 | model,
27 | noise_schedule,
28 | model_type=model_type,
29 | model_kwargs=model_kwargs,
30 | guidance_type=guidance_type,
31 | condition=condition,
32 | unconditional_condition=uncondition,
33 | guidance_scale=cfg_scale,
34 | )
35 | ## 3. Define dpm-solver and sample by multistep DPM-Solver.
36 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
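37 |
38 |
39 | # Minimal usage sketch (variable names are placeholders; the call pattern assumes a
40 | # noise-prediction forward and the standard multistep DPM-Solver sampling API):
41 | #   dpms = DPMS(model_forward, condition=caption_embs, uncondition=null_embs, cfg_scale=4.5)
42 | #   samples = dpms.sample(z, steps=20, order=2, skip_type="time_uniform", method="multistep")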
--------------------------------------------------------------------------------
/diffusion/iddpm.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | from diffusion.model.respace import SpacedDiffusion, space_timesteps
6 | from .model import gaussian_diffusion as gd
7 |
8 |
9 | def IDDPM(
10 | timestep_respacing,
11 | noise_schedule="linear",
12 | use_kl=False,
13 | sigma_small=False,
14 | predict_xstart=False,
15 | learn_sigma=True,
16 | pred_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000,
19 | snr=False,
20 | return_startx=False,
21 | ):
22 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
23 | if use_kl:
24 | loss_type = gd.LossType.RESCALED_KL
25 | elif rescale_learned_sigmas:
26 | loss_type = gd.LossType.RESCALED_MSE
27 | else:
28 | loss_type = gd.LossType.MSE
29 | if timestep_respacing is None or timestep_respacing == "":
30 | timestep_respacing = [diffusion_steps]
31 | return SpacedDiffusion(
32 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
33 | betas=betas,
34 | model_mean_type=(
35 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
36 | ),
37 | model_var_type=(
38 | ((
39 | gd.ModelVarType.FIXED_LARGE
40 | if not sigma_small
41 | else gd.ModelVarType.FIXED_SMALL
42 | )
43 | if not learn_sigma
44 | else gd.ModelVarType.LEARNED_RANGE
45 | )
46 | if pred_sigma
47 | else None
48 | ),
49 | loss_type=loss_type,
50 | snr=snr,
51 | return_startx=return_startx,
52 | # rescale_timesteps=rescale_timesteps,
53 | )
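54 |
55 |
56 | # Minimal sketch: a full 1000-step diffusion for training and a respaced one for
57 | # sampling (argument names follow the signature above; model calls are omitted).
58 | #   train_diffusion = IDDPM(timestep_respacing="")    # all 1000 steps
59 | #   eval_diffusion = IDDPM(timestep_respacing="50")   # 50 evenly spaced steps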
--------------------------------------------------------------------------------
/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .nets import *
2 |
--------------------------------------------------------------------------------
/diffusion/model/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv import Registry
2 |
3 | from diffusion.model.utils import set_grad_checkpoint
4 |
5 | MODELS = Registry('models')
6 |
7 |
8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
9 | if isinstance(cfg, str):
10 | cfg = dict(type=cfg)
11 | model = MODELS.build(cfg, default_args=kwargs)
12 | if use_grad_checkpoint:
13 | set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
14 | return model
15 |
--------------------------------------------------------------------------------
/diffusion/model/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
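90 |
91 | # Illustrative sanity check: KL(N(0,1) || N(0,1)) = 0 and KL(N(1,1) || N(0,1)) = 0.5.
92 | if __name__ == '__main__':
93 |     zero, one = th.tensor(0.0), th.tensor(1.0)
94 |     print(normal_kl(zero, zero, zero, zero))  # tensor(0.)
95 |     print(normal_kl(one, zero, zero, zero))   # tensor(0.5000)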
--------------------------------------------------------------------------------
/diffusion/model/edm_sample.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from diffusion.model.utils import *
6 |
7 |
8 | # ----------------------------------------------------------------------------
9 | # Proposed EDM sampler (Algorithm 2).
10 |
11 | def edm_sampler(
12 | net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
13 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
14 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
15 | ):
16 | # Adjust noise levels based on what's supported by the network.
17 | sigma_min = max(sigma_min, net.sigma_min)
18 | sigma_max = min(sigma_max, net.sigma_max)
19 |
20 | # Time step discretization.
21 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
22 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
23 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
24 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
25 |
26 | # Main sampling loop.
27 | x_next = latents.to(torch.float64) * t_steps[0]
28 | for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
29 | x_cur = x_next
30 |
31 | # Increase noise temporarily.
32 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
33 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
34 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
35 |
36 | # Euler step.
37 | denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
38 | d_cur = (x_hat - denoised) / t_hat
39 | x_next = x_hat + (t_next - t_hat) * d_cur
40 |
41 | # Apply 2nd order correction.
42 | if i < num_steps - 1:
43 | denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
44 | d_prime = (x_next - denoised) / t_next
45 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
46 |
47 | return x_next
48 |
49 |
50 | # ----------------------------------------------------------------------------
51 | # Generalized ablation sampler, representing the superset of all sampling
52 | # methods discussed in the paper.
53 |
54 | def ablation_sampler(
55 | net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
56 | num_steps=18, sigma_min=None, sigma_max=None, rho=7,
57 | solver='heun', discretization='edm', schedule='linear', scaling='none',
58 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
59 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
60 | ):
61 | assert solver in ['euler', 'heun']
62 | assert discretization in ['vp', 've', 'iddpm', 'edm']
63 | assert schedule in ['vp', 've', 'linear']
64 | assert scaling in ['vp', 'none']
65 |
66 | # Helper functions for VP & VE noise level schedules.
67 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
68 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
69 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
70 | sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
71 | ve_sigma = lambda t: t.sqrt()
72 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
73 | ve_sigma_inv = lambda sigma: sigma ** 2
74 |
75 | # Select default noise level range based on the specified time step discretization.
76 | if sigma_min is None:
77 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
78 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
79 | if sigma_max is None:
80 | vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
81 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
82 |
83 | # Adjust noise levels based on what's supported by the network.
84 | sigma_min = max(sigma_min, net.sigma_min)
85 | sigma_max = min(sigma_max, net.sigma_max)
86 |
87 | # Compute corresponding betas for VP.
88 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
89 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
90 |
91 | # Define time steps in terms of noise level.
92 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
93 | if discretization == 'vp':
94 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
95 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
96 | elif discretization == 've':
97 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
98 | sigma_steps = ve_sigma(orig_t_steps)
99 | elif discretization == 'iddpm':
100 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
101 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
102 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
103 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
104 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
105 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
106 | else:
107 | assert discretization == 'edm'
108 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
109 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
110 |
111 | # Define noise level schedule.
112 | if schedule == 'vp':
113 | sigma = vp_sigma(vp_beta_d, vp_beta_min)
114 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
115 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
116 | elif schedule == 've':
117 | sigma = ve_sigma
118 | sigma_deriv = ve_sigma_deriv
119 | sigma_inv = ve_sigma_inv
120 | else:
121 | assert schedule == 'linear'
122 | sigma = lambda t: t
123 | sigma_deriv = lambda t: 1
124 | sigma_inv = lambda sigma: sigma
125 |
126 | # Define scaling schedule.
127 | if scaling == 'vp':
128 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
129 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
130 | else:
131 | assert scaling == 'none'
132 | s = lambda t: 1
133 | s_deriv = lambda t: 0
134 |
135 | # Compute final time steps based on the corresponding noise levels.
136 | t_steps = sigma_inv(net.round_sigma(sigma_steps))
137 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
138 |
139 | # Main sampling loop.
140 | t_next = t_steps[0]
141 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
142 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
143 | x_cur = x_next
144 |
145 | # Increase noise temporarily.
146 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
147 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
148 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
149 | t_hat) * S_noise * randn_like(x_cur)
150 |
151 | # Euler step.
152 | h = t_next - t_hat
153 | denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
154 | torch.float64)
155 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
156 | t_hat) / sigma(t_hat) * denoised
157 | x_prime = x_hat + alpha * h * d_cur
158 | t_prime = t_hat + alpha * h
159 |
160 | # Apply 2nd order correction.
161 | if solver == 'euler' or i == num_steps - 1:
162 | x_next = x_hat + h * d_cur
163 | else:
164 | assert solver == 'heun'
165 | denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
166 | torch.float64)
167 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
168 | t_prime) * s(t_prime) / sigma(t_prime) * denoised
169 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
170 |
171 | return x_next
172 |
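173 |
174 | # Usage note (sketch): both samplers expect an EDM-style wrapper `net` exposing
175 | # `sigma_min`, `sigma_max` and `round_sigma(t)`, whose forward returns a dict with
176 | # the denoised prediction under key 'x', e.g.
177 | #   x0 = edm_sampler(net, torch.randn(n, c, h, w, device=device), class_labels=y, cfg_scale=4.5)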
--------------------------------------------------------------------------------
/diffusion/model/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
--------------------------------------------------------------------------------
/diffusion/model/llava/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, past_key_value)
--------------------------------------------------------------------------------
/diffusion/model/llava/mpt/configuration_mpt.py:
--------------------------------------------------------------------------------
1 | """A HuggingFace-style model configuration."""
2 | from typing import Dict, Optional, Union
3 | from transformers import PretrainedConfig
4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6 |
7 | class MPTConfig(PretrainedConfig):
8 | model_type = 'mpt'
9 |
10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
11 | """The MPT configuration class.
12 |
13 | Args:
14 | d_model (int): The size of the embedding dimension of the model.
15 | n_heads (int): The number of attention heads.
16 | n_layers (int): The number of layers in the model.
17 | expansion_ratio (int): The ratio of the up/down scale in the MLP.
18 | max_seq_len (int): The maximum sequence length of the model.
19 | vocab_size (int): The size of the vocabulary.
20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
21 | emb_pdrop (float): The dropout probability for the embedding layer.
22 | learned_pos_emb (bool): Whether to use learned positional embeddings
23 | attn_config (Dict): A dictionary used to configure the model's attention module:
24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
25 | attn_pdrop (float): The dropout probability for the attention layers.
26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
29 | this value.
30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
31 | use the default scale of ``1/sqrt(d_keys)``.
32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
37 | which sub-sequence each token belongs to.
38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
39 | alibi (bool): Whether to use the alibi bias instead of position embeddings.
40 | alibi_bias_max (int): The maximum value of the alibi bias.
41 | init_device (str): The device to use for parameter initialization.
42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
43 | no_bias (bool): Whether to use bias in all layers.
44 | verbose (int): The verbosity level. 0 is silent.
45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
46 | norm_type (str): choose type of norm to use
47 | multiquery_attention (bool): Whether to use multiquery attention implementation.
48 | use_cache (bool): Whether or not the model should return the last key/values attentions
49 | init_config (Dict): A dictionary used to configure the model initialization:
50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
57 | init_std (float): The standard deviation of the normal distribution used to initialize the model,
58 | if using the baseline_ parameter initialization scheme.
59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
62 | ---
63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
64 | """
65 | self.d_model = d_model
66 | self.n_heads = n_heads
67 | self.n_layers = n_layers
68 | self.expansion_ratio = expansion_ratio
69 | self.max_seq_len = max_seq_len
70 | self.vocab_size = vocab_size
71 | self.resid_pdrop = resid_pdrop
72 | self.emb_pdrop = emb_pdrop
73 | self.learned_pos_emb = learned_pos_emb
74 | self.attn_config = attn_config
75 | self.init_device = init_device
76 | self.logit_scale = logit_scale
77 | self.no_bias = no_bias
78 | self.verbose = verbose
79 | self.embedding_fraction = embedding_fraction
80 | self.norm_type = norm_type
81 | self.use_cache = use_cache
82 | self.init_config = init_config
83 | if 'name' in kwargs:
84 | del kwargs['name']
85 | if 'loss_fn' in kwargs:
86 | del kwargs['loss_fn']
87 | super().__init__(**kwargs)
88 | self._validate_config()
89 |
90 | def _set_config_defaults(self, config, config_defaults):
91 | for (k, v) in config_defaults.items():
92 | if k not in config:
93 | config[k] = v
94 | return config
95 |
96 | def _validate_config(self):
97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
99 | if self.d_model % self.n_heads != 0:
100 | raise ValueError('d_model must be divisible by n_heads')
101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
108 | raise NotImplementedError('alibi only implemented with torch and triton attention.')
109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
115 | if self.init_config.get('name', None) is None:
116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117 | if not self.learned_pos_emb and (not self.attn_config['alibi']):
118 | raise ValueError('Positional information must be provided to the model using either learned_pos_emb or alibi.')
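119 |
120 |
121 | # Minimal sketch: build and validate a small configuration.
122 | if __name__ == '__main__':
123 |     cfg = MPTConfig(d_model=1024, n_heads=8, n_layers=12, max_seq_len=512)
124 |     print(cfg.d_model // cfg.n_heads)  # 128 (per-head dimension)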
--------------------------------------------------------------------------------
/diffusion/model/llava/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
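57 |
58 |
59 | # Illustrative sanity check: RMSNorm rescales every row to (approximately) unit RMS.
60 | if __name__ == '__main__':
61 |     with torch.no_grad():
62 |         x = torch.randn(2, 8)
63 |         y = RMSNorm(8)(x)
64 |         print(y.pow(2).mean(-1))  # both values close to 1.0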
--------------------------------------------------------------------------------
/diffusion/model/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .PixArt import PixArt, PixArt_XL_2
2 | from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock
--------------------------------------------------------------------------------
/diffusion/model/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def training_losses_diffusers(
100 | self, model, *args, **kwargs
101 | ): # pylint: disable=signature-differs
102 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
103 |
104 | def condition_mean(self, cond_fn, *args, **kwargs):
105 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
106 |
107 | def condition_score(self, cond_fn, *args, **kwargs):
108 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
109 |
110 | def _wrap_model(self, model):
111 | if isinstance(model, _WrappedModel):
112 | return model
113 | return _WrappedModel(
114 | model, self.timestep_map, self.original_num_steps
115 | )
116 |
117 | def _scale_timesteps(self, t):
118 | # Scaling is done by the wrapped model.
119 | return t
120 |
121 |
122 | class _WrappedModel:
123 | def __init__(self, model, timestep_map, original_num_steps):
124 | self.model = model
125 | self.timestep_map = timestep_map
126 | # self.rescale_timesteps = rescale_timesteps
127 | self.original_num_steps = original_num_steps
128 |
129 | def __call__(self, x, timestep, **kwargs):
130 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
131 | new_ts = map_tensor[timestep]
132 | # if self.rescale_timesteps:
133 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
134 | return self.model(x, timestep=new_ts, **kwargs)
135 |
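136 |
137 | # Illustrative examples of the respacing rules documented in space_timesteps above.
138 | if __name__ == '__main__':
139 |     # Three equal 100-step sections strided down to 10, 15 and 20 steps (45 in total).
140 |     print(len(space_timesteps(300, [10, 15, 20])))  # 45
141 |     # DDIM-style fixed stride: 25 evenly strided steps out of 1000.
142 |     print(sorted(space_timesteps(1000, "ddim25"))[:5])  # [0, 40, 80, 120, 160]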
--------------------------------------------------------------------------------
/diffusion/model/t5.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import re
4 | import html
5 | import urllib.parse as ul
6 |
7 | import ftfy
8 | import torch
9 | from bs4 import BeautifulSoup
10 | from transformers import T5EncoderModel, AutoTokenizer
11 | from huggingface_hub import hf_hub_download
12 |
13 | class T5Embedder:
14 |
15 | available_models = ['t5-v1_1-xxl']
16 | bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
17 |
18 | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
19 | t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
20 | self.device = torch.device(device)
21 | self.torch_dtype = torch_dtype or torch.bfloat16
22 | if t5_model_kwargs is None:
23 | t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
24 | if use_offload_folder is not None:
25 | t5_model_kwargs['offload_folder'] = use_offload_folder
26 | t5_model_kwargs['device_map'] = {
27 | 'shared': self.device,
28 | 'encoder.embed_tokens': self.device,
29 | 'encoder.block.0': self.device,
30 | 'encoder.block.1': self.device,
31 | 'encoder.block.2': self.device,
32 | 'encoder.block.3': self.device,
33 | 'encoder.block.4': self.device,
34 | 'encoder.block.5': self.device,
35 | 'encoder.block.6': self.device,
36 | 'encoder.block.7': self.device,
37 | 'encoder.block.8': self.device,
38 | 'encoder.block.9': self.device,
39 | 'encoder.block.10': self.device,
40 | 'encoder.block.11': self.device,
41 | 'encoder.block.12': 'disk',
42 | 'encoder.block.13': 'disk',
43 | 'encoder.block.14': 'disk',
44 | 'encoder.block.15': 'disk',
45 | 'encoder.block.16': 'disk',
46 | 'encoder.block.17': 'disk',
47 | 'encoder.block.18': 'disk',
48 | 'encoder.block.19': 'disk',
49 | 'encoder.block.20': 'disk',
50 | 'encoder.block.21': 'disk',
51 | 'encoder.block.22': 'disk',
52 | 'encoder.block.23': 'disk',
53 | 'encoder.final_layer_norm': 'disk',
54 | 'encoder.dropout': 'disk',
55 | }
56 | else:
57 | t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
58 |
59 | self.use_text_preprocessing = use_text_preprocessing
60 | self.hf_token = hf_token
61 | self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
62 | self.dir_or_name = dir_or_name
63 | tokenizer_path, path = dir_or_name, dir_or_name
64 | if local_cache:
65 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
66 | tokenizer_path, path = cache_dir, cache_dir
67 | elif dir_or_name in self.available_models:
68 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
69 | for filename in [
70 | 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
71 | 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
72 | ]:
73 | hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
74 | force_filename=filename, token=self.hf_token)
75 | tokenizer_path, path = cache_dir, cache_dir
76 | else:
77 | cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
78 | for filename in [
79 | 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
80 | ]:
81 | hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
82 | force_filename=filename, token=self.hf_token)
83 | tokenizer_path = cache_dir
84 |
85 | print(tokenizer_path)
86 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
87 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
88 | self.model_max_length = model_max_length
89 |
90 | def get_text_embeddings(self, texts):
91 | texts = [self.text_preprocessing(text) for text in texts]
92 |
93 | text_tokens_and_mask = self.tokenizer(
94 | texts,
95 | max_length=self.model_max_length,
96 | padding='max_length',
97 | truncation=True,
98 | return_attention_mask=True,
99 | add_special_tokens=True,
100 | return_tensors='pt'
101 | )
102 |
103 | text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
104 | text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
105 |
106 | with torch.no_grad():
107 | text_encoder_embs = self.model(
108 | input_ids=text_tokens_and_mask['input_ids'].to(self.device),
109 | attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
110 | )['last_hidden_state'].detach()
111 | return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
112 |
113 | def text_preprocessing(self, text):
114 | if self.use_text_preprocessing:
115 | # The exact text cleaning used in the training stage:
116 | text = self.clean_caption(text)
117 | text = self.clean_caption(text)
118 | return text
119 | else:
120 | return text.lower().strip()
121 |
122 | @staticmethod
123 | def basic_clean(text):
124 | text = ftfy.fix_text(text)
125 | text = html.unescape(html.unescape(text))
126 | return text.strip()
127 |
128 | def clean_caption(self, caption):
129 | caption = str(caption)
130 | caption = ul.unquote_plus(caption)
131 | caption = caption.strip().lower()
132 | caption = re.sub('<person>', 'person', caption)
133 | # urls:
134 | caption = re.sub(
135 | r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
136 | '', caption) # regex for urls
137 | caption = re.sub(
138 | r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
139 | '', caption) # regex for urls
140 | # html:
141 | caption = BeautifulSoup(caption, features='html.parser').text
142 |
143 | # @
144 | caption = re.sub(r'@[\w\d]+\b', '', caption)
145 |
146 | # 31C0—31EF CJK Strokes
147 | # 31F0—31FF Katakana Phonetic Extensions
148 | # 3200—32FF Enclosed CJK Letters and Months
149 | # 3300—33FF CJK Compatibility
150 | # 3400—4DBF CJK Unified Ideographs Extension A
151 | # 4DC0—4DFF Yijing Hexagram Symbols
152 | # 4E00—9FFF CJK Unified Ideographs
153 | caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
154 | caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
155 | caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
156 | caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
157 | caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
158 | caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
159 | caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
160 | #######################################################
161 |
162 | # all types of dash --> "-"
163 | caption = re.sub(
164 | r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
165 | '-', caption)
166 |
167 | # normalize all quotation marks to one standard
168 | caption = re.sub(r'[`´«»“”¨]', '"', caption)
169 | caption = re.sub(r'[‘’]', "'", caption)
170 |
171 | # "
172 | caption = re.sub(r'"?', '', caption)
173 | # &
174 | caption = re.sub(r'&', '', caption)
175 |
176 | # ip addresses:
177 | caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
178 |
179 | # article ids:
180 | caption = re.sub(r'\d:\d\d\s+$', '', caption)
181 |
182 | # \n
183 | caption = re.sub(r'\\n', ' ', caption)
184 |
185 | # "#123"
186 | caption = re.sub(r'#\d{1,3}\b', '', caption)
187 | # "#12345.."
188 | caption = re.sub(r'#\d{5,}\b', '', caption)
189 | # "123456.."
190 | caption = re.sub(r'\b\d{6,}\b', '', caption)
191 | # filenames:
192 | caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
193 |
194 | #
195 | caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
196 | caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
197 |
198 | caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
199 | caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
200 |
201 | # this-is-my-cute-cat / this_is_my_cute_cat
202 | regex2 = re.compile(r'(?:\-|\_)')
203 | if len(re.findall(regex2, caption)) > 3:
204 | caption = re.sub(regex2, ' ', caption)
205 |
206 | caption = self.basic_clean(caption)
207 |
208 | caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
209 | caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
210 | caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
211 |
212 | caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
213 | caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
214 | caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
215 | caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
216 | caption = re.sub(r'\bpage\s+\d+\b', '', caption)
217 |
218 | caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
219 |
220 | caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
221 |
222 | caption = re.sub(r'\b\s+\:\s+', r': ', caption)
223 | caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
224 | caption = re.sub(r'\s+', ' ', caption)
225 |
226 |         caption = caption.strip()
227 |
228 | caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
229 | caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
230 | caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
231 | caption = re.sub(r'^\.\S+$', '', caption)
232 |
233 | return caption.strip()
234 |
--------------------------------------------------------------------------------
/diffusion/model/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs, device=local_ts.device) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs, device=local_losses.device) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
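
A minimal sketch of how these samplers are typically driven from a training loop. Everything except the sampler calls is an assumption for illustration: `diffusion`, `model`, and `x_start` are placeholders, and the `training_losses` API follows the OpenAI IDDPM convention this file is adapted from.

    sampler = create_named_schedule_sampler("loss-second-moment", diffusion)

    # importance-sample timesteps; the returned weights keep the objective unbiased
    t, weights = sampler.sample(batch_size=x_start.shape[0], device=x_start.device)
    losses = diffusion.training_losses(model, x_start, t)["loss"]  # per-example losses (assumed API)
    if isinstance(sampler, LossAwareSampler):
        # in a distributed run, feed the losses back so the resampler can re-weight future draws
        sampler.update_with_local_losses(t.detach(), losses.detach())
    loss = (losses * weights).mean()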
--------------------------------------------------------------------------------
/diffusion/sa_sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver
7 | from .model import gaussian_diffusion as gd
8 |
9 |
10 | class SASolverSampler(object):
11 | def __init__(self, model,
12 | noise_schedule="linear",
13 | diffusion_steps=1000,
14 | device='cpu',
15 | ):
16 | super().__init__()
17 | self.model = model
18 | self.device = device
19 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
20 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
21 | alphas = 1.0 - betas
22 | self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0)))
23 |
24 | def register_buffer(self, name, attr):
25 | if type(attr) == torch.Tensor:
26 | if attr.device != torch.device("cuda"):
27 | attr = attr.to(torch.device("cuda"))
28 | setattr(self, name, attr)
29 |
30 | @torch.no_grad()
31 | def sample(self,
32 | S,
33 | batch_size,
34 | shape,
35 | conditioning=None,
36 | callback=None,
37 | normals_sequence=None,
38 | img_callback=None,
39 | quantize_x0=False,
40 | eta=0.,
41 | mask=None,
42 | x0=None,
43 | temperature=1.,
44 | noise_dropout=0.,
45 | score_corrector=None,
46 | corrector_kwargs=None,
47 | verbose=True,
48 | x_T=None,
49 | log_every_t=100,
50 | unconditional_guidance_scale=1.,
51 | unconditional_conditioning=None,
52 | model_kwargs={},
53 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
54 | **kwargs
55 | ):
56 | if conditioning is not None:
57 | if isinstance(conditioning, dict):
58 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
59 | if cbs != batch_size:
60 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
61 | else:
62 | if conditioning.shape[0] != batch_size:
63 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
64 |
65 | # sampling
66 | C, H, W = shape
67 | size = (batch_size, C, H, W)
68 |
69 | device = self.device
70 | if x_T is None:
71 | img = torch.randn(size, device=device)
72 | else:
73 | img = x_T
74 |
75 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
76 |
77 | model_fn = model_wrapper(
78 | self.model,
79 | ns,
80 | model_type="noise",
81 | guidance_type="classifier-free",
82 | condition=conditioning,
83 | unconditional_condition=unconditional_conditioning,
84 | guidance_scale=unconditional_guidance_scale,
85 | model_kwargs=model_kwargs,
86 | )
87 |
88 | sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")
89 |
90 | tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0
91 |
92 | x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False)
93 |
94 | return x.to(device), None
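
A usage sketch for the SA-Solver wrapper above. The `model`, `cond`, and `uncond` objects are assumptions: `cond`/`uncond` stand for the caption embeddings and their null counterpart, and the latent shape corresponds to a 512px image with an 8x VAE.

    sampler = SASolverSampler(model, noise_schedule="linear", diffusion_steps=1000, device="cuda")
    latents, _ = sampler.sample(
        S=25,                       # number of SA-Solver steps
        batch_size=4,
        shape=(4, 64, 64),          # latent (C, H, W)
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=4.5,
        model_kwargs={},            # extra kwargs forwarded to the wrapped model
    )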
--------------------------------------------------------------------------------
/diffusion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma/1ce521afddcc2fab329b35b7374aa86d654e12f7/diffusion/utils/__init__.py
--------------------------------------------------------------------------------
/diffusion/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 |
5 | from diffusion.utils.logger import get_root_logger
6 |
7 |
8 | def save_checkpoint(work_dir,
9 | epoch,
10 | model,
11 | model_ema=None,
12 | optimizer=None,
13 | lr_scheduler=None,
14 | keep_last=False,
15 | step=None,
16 | ):
17 | os.makedirs(work_dir, exist_ok=True)
18 | state_dict = dict(state_dict=model.state_dict())
19 | if model_ema is not None:
20 | state_dict['state_dict_ema'] = model_ema.state_dict()
21 | if optimizer is not None:
22 | state_dict['optimizer'] = optimizer.state_dict()
23 | if lr_scheduler is not None:
24 | state_dict['scheduler'] = lr_scheduler.state_dict()
25 | if epoch is not None:
26 | state_dict['epoch'] = epoch
27 | file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
28 | if step is not None:
29 | file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
30 | logger = get_root_logger()
31 | torch.save(state_dict, file_path)
32 |     logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
33 |     if keep_last:
34 |         for i in range(epoch):  # remove checkpoints saved for earlier epochs
35 |             previous_ckpt = file_path.replace(f"epoch_{epoch}", f"epoch_{i}")
36 |             if os.path.exists(previous_ckpt):
37 |                 os.remove(previous_ckpt)
38 |
39 |
40 | def load_checkpoint(checkpoint,
41 | model,
42 | model_ema=None,
43 | optimizer=None,
44 | lr_scheduler=None,
45 | load_ema=False,
46 | resume_optimizer=True,
47 | resume_lr_scheduler=True,
48 | max_length=120,
49 | ):
50 | assert isinstance(checkpoint, str)
51 | ckpt_file = checkpoint
52 | checkpoint = torch.load(ckpt_file, map_location="cpu")
53 |
54 | state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
55 | for key in state_dict_keys:
56 | if key in checkpoint['state_dict']:
57 | del checkpoint['state_dict'][key]
58 | if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
59 | del checkpoint['state_dict_ema'][key]
60 | break
61 |
62 | if load_ema:
63 | state_dict = checkpoint['state_dict_ema']
64 | else:
65 | state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint
66 |
67 | null_embed = torch.load(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', map_location='cpu')
68 | state_dict['y_embedder.y_embedding'] = null_embed['uncond_prompt_embeds'][0]
69 |
70 | missing, unexpect = model.load_state_dict(state_dict, strict=False)
71 | if model_ema is not None:
72 | model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
73 | if optimizer is not None and resume_optimizer:
74 | optimizer.load_state_dict(checkpoint['optimizer'])
75 | if lr_scheduler is not None and resume_lr_scheduler:
76 | lr_scheduler.load_state_dict(checkpoint['scheduler'])
77 | logger = get_root_logger()
78 | if optimizer is not None:
79 |         epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*\.pth', ckpt_file).group(1))
80 | logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
81 | f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
82 | return epoch, missing, unexpect
83 | logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
84 | return missing, unexpect
85 |
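
A sketch of a save/load round trip with the helpers above. Paths and the `model` / `optimizer` / `lr_scheduler` objects are placeholders; note that load_checkpoint additionally expects the null-embedding file under output/pretrained_models/ referenced in the code.

    save_checkpoint('output/debug/checkpoints', epoch=1, model=model,
                    optimizer=optimizer, lr_scheduler=lr_scheduler, step=100)

    # returns (epoch, missing, unexpected) because an optimizer is passed in
    epoch, missing, unexpected = load_checkpoint(
        'output/debug/checkpoints/epoch_1_step_100.pth',
        model, optimizer=optimizer, lr_scheduler=lr_scheduler,
        load_ema=False, max_length=120,
    )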
--------------------------------------------------------------------------------
/diffusion/utils/data_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | from typing import Sequence
4 | from torch.utils.data import BatchSampler, Sampler, Dataset
5 | from random import shuffle, choice
6 | from copy import deepcopy
7 | from diffusion.utils.logger import get_root_logger
8 |
9 |
10 | class AspectRatioBatchSampler(BatchSampler):
11 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
12 |
13 | Args:
14 | sampler (Sampler): Base sampler.
15 | dataset (Dataset): Dataset providing data information.
16 | batch_size (int): Size of mini-batch.
17 | drop_last (bool): If ``True``, the sampler will drop the last batch if
18 | its size would be less than ``batch_size``.
19 | aspect_ratios (dict): The predefined aspect ratios.
20 | """
21 |
22 | def __init__(self,
23 | sampler: Sampler,
24 | dataset: Dataset,
25 | batch_size: int,
26 | aspect_ratios: dict,
27 | drop_last: bool = False,
28 | config=None,
29 | valid_num=0, # take as valid aspect-ratio when sample number >= valid_num
30 | **kwargs) -> None:
31 | if not isinstance(sampler, Sampler):
32 | raise TypeError('sampler should be an instance of ``Sampler``, '
33 | f'but got {sampler}')
34 | if not isinstance(batch_size, int) or batch_size <= 0:
35 | raise ValueError('batch_size should be a positive integer value, '
36 | f'but got batch_size={batch_size}')
37 | self.sampler = sampler
38 | self.dataset = dataset
39 | self.batch_size = batch_size
40 | self.aspect_ratios = aspect_ratios
41 | self.drop_last = drop_last
42 | self.ratio_nums_gt = kwargs.get('ratio_nums', None)
43 | self.config = config
44 | assert self.ratio_nums_gt
45 | # buckets for each aspect ratio
46 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios.keys()}
47 | self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
48 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
49 | logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
50 |
51 | def __iter__(self) -> Sequence[int]:
52 | for idx in self.sampler:
53 | data_info = self.dataset.get_data_info(idx)
54 | height, width = data_info['height'], data_info['width']
55 | ratio = height / width
56 | # find the closest aspect ratio
57 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
58 | if closest_ratio not in self.current_available_bucket_keys:
59 | continue
60 | bucket = self._aspect_ratio_buckets[closest_ratio]
61 | bucket.append(idx)
62 | # yield a batch of indices in the same aspect ratio group
63 | if len(bucket) == self.batch_size:
64 | yield bucket[:]
65 | del bucket[:]
66 |
67 | # yield the rest data and reset the buckets
68 | for bucket in self._aspect_ratio_buckets.values():
69 | while len(bucket) > 0:
70 | if len(bucket) <= self.batch_size:
71 | if not self.drop_last:
72 | yield bucket[:]
73 | bucket = []
74 | else:
75 | yield bucket[:self.batch_size]
76 | bucket = bucket[self.batch_size:]
77 |
78 |
79 | class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
80 | def __init__(self, *args, **kwargs):
81 | super().__init__(*args, **kwargs)
82 | # Assign samples to each bucket
83 | self.ratio_nums_gt = kwargs.get('ratio_nums', None)
84 | assert self.ratio_nums_gt
85 | self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
86 | self.original_buckets = {}
87 | self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
88 | self.all_available_keys = deepcopy(self.current_available_bucket_keys)
89 | self.exhausted_bucket_keys = []
90 | self.total_batches = len(self.sampler) // self.batch_size
91 | self._aspect_ratio_count = {}
92 | for k in self.all_available_keys:
93 | self._aspect_ratio_count[float(k)] = 0
94 | self.original_buckets[float(k)] = []
95 | logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
96 | logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
97 |
98 | def __iter__(self) -> Sequence[int]:
99 | i = 0
100 | for idx in self.sampler:
101 | data_info = self.dataset.get_data_info(idx)
102 | height, width = data_info['height'], data_info['width']
103 | ratio = height / width
104 | closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
105 | if closest_ratio not in self.all_available_keys:
106 | continue
107 | if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
108 | self._aspect_ratio_count[closest_ratio] += 1
109 | self._aspect_ratio_buckets[closest_ratio].append(idx)
110 | self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket
111 | if not self.current_available_bucket_keys:
112 | self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
113 |
114 | if closest_ratio not in self.current_available_bucket_keys:
115 | continue
116 | key = closest_ratio
117 | bucket = self._aspect_ratio_buckets[key]
118 | if len(bucket) == self.batch_size:
119 | yield bucket[:self.batch_size]
120 | del bucket[:self.batch_size]
121 | i += 1
122 | self.exhausted_bucket_keys.append(key)
123 | self.current_available_bucket_keys.remove(key)
124 |
125 | for _ in range(self.total_batches - i):
126 | key = choice(self.all_available_keys)
127 | bucket = self._aspect_ratio_buckets[key]
128 | if len(bucket) >= self.batch_size:
129 | yield bucket[:self.batch_size]
130 | del bucket[:self.batch_size]
131 |
132 | # If a bucket is exhausted
133 | if not bucket:
134 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
135 | shuffle(self._aspect_ratio_buckets[key])
136 | else:
137 | self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
138 | shuffle(self._aspect_ratio_buckets[key])
139 |
--------------------------------------------------------------------------------
/diffusion/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch.distributed as dist
4 | from datetime import datetime
5 | from .dist_utils import is_local_master
6 | from mmcv.utils.logging import logger_initialized
7 |
8 |
9 | def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
10 | """Get root logger.
11 |
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str): logger name
17 | Returns:
18 | :obj:`logging.Logger`: The obtained logger
19 | """
20 | if log_file is None:
21 | log_file = '/dev/null'
22 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
23 | return logger
24 |
25 |
26 | def get_logger(name, log_file=None, log_level=logging.INFO):
27 | """Initialize and get a logger by name.
28 |
29 | If the logger has not been initialized, this method will initialize the
30 | logger by adding one or two handlers, otherwise the initialized logger will
31 | be directly returned. During initialization, a StreamHandler will always be
32 | added. If `log_file` is specified and the process rank is 0, a FileHandler
33 | will also be added.
34 |
35 | Args:
36 | name (str): Logger name.
37 | log_file (str | None): The log filename. If specified, a FileHandler
38 | will be added to the logger.
39 | log_level (int): The logger level. Note that only the process of
40 | rank 0 is affected, and other processes will set the level to
41 | "Error" thus be silent most of the time.
42 |
43 | Returns:
44 | logging.Logger: The expected logger.
45 | """
46 | logger = logging.getLogger(name)
47 | logger.propagate = False # disable root logger to avoid duplicate logging
48 |
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | stream_handler = logging.StreamHandler()
59 | handlers = [stream_handler]
60 |
61 | if dist.is_available() and dist.is_initialized():
62 | rank = dist.get_rank()
63 | else:
64 | rank = 0
65 |
66 | # only rank 0 will add a FileHandler
67 | if rank == 0 and log_file is not None:
68 | file_handler = logging.FileHandler(log_file, 'w')
69 | handlers.append(file_handler)
70 |
71 | formatter = logging.Formatter(
72 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
73 | for handler in handlers:
74 | handler.setFormatter(formatter)
75 | handler.setLevel(log_level)
76 | logger.addHandler(handler)
77 |
78 | # only rank0 for each node will print logs
79 | log_level = log_level if is_local_master() else logging.ERROR
80 | logger.setLevel(log_level)
81 |
82 | logger_initialized[name] = True
83 |
84 | return logger
85 |
86 | def rename_file_with_creation_time(file_path):
87 |     # get the file's creation time
88 | creation_time = os.path.getctime(file_path)
89 | creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
90 |
91 |     # build the new file name
92 | dir_name, file_name = os.path.split(file_path)
93 | name, ext = os.path.splitext(file_name)
94 | new_file_name = f"{name}_{creation_time_str}{ext}"
95 | new_file_path = os.path.join(dir_name, new_file_name)
96 |
97 |     # rename the file
98 | os.rename(file_path, new_file_path)
99 | print(f"File renamed to: {new_file_path}")
100 |
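
For reference, a minimal call (the log path is illustrative). Only rank 0 attaches the FileHandler, and non-master ranks are raised to ERROR level, so this is safe to call from every process:

    logger = get_root_logger('output/debug/train_log.log')
    logger.info('training started')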
--------------------------------------------------------------------------------
/diffusion/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
2 | from torch.optim import Optimizer
3 | from torch.optim.lr_scheduler import LambdaLR
4 | import math
5 |
6 | from diffusion.utils.logger import get_root_logger
7 |
8 |
9 | def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
10 | if not config.get('lr_schedule_args', None):
11 | config.lr_schedule_args = dict()
12 | if config.get('lr_warmup_steps', None):
13 |         config.lr_schedule_args['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version
14 |
15 | logger = get_root_logger()
16 | logger.info(
17 | f'Lr schedule: {config.lr_schedule}, ' + ",".join(
18 | [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
19 | if config.lr_schedule == 'cosine':
20 | lr_scheduler = get_cosine_schedule_with_warmup(
21 | optimizer=optimizer,
22 | **config.lr_schedule_args,
23 | num_training_steps=(len(train_dataloader) * config.num_epochs),
24 | )
25 | elif config.lr_schedule == 'constant':
26 | lr_scheduler = get_constant_schedule_with_warmup(
27 | optimizer=optimizer,
28 | **config.lr_schedule_args,
29 | )
30 | elif config.lr_schedule == 'cosine_decay_to_constant':
31 | assert lr_scale_ratio >= 1
32 | lr_scheduler = get_cosine_decay_to_constant_with_warmup(
33 | optimizer=optimizer,
34 | **config.lr_schedule_args,
35 | final_lr=1 / lr_scale_ratio,
36 | num_training_steps=(len(train_dataloader) * config.num_epochs),
37 | )
38 | else:
39 | raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
40 | return lr_scheduler
41 |
42 |
43 | def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
44 | num_warmup_steps: int,
45 | num_training_steps: int,
46 | final_lr: float = 0.0,
47 | num_decay: float = 0.667,
48 | num_cycles: float = 0.5,
49 | last_epoch: int = -1
50 | ):
51 | """
52 | Create a schedule with a cosine annealing lr followed by a constant lr.
53 |
54 | Args:
55 | optimizer ([`~torch.optim.Optimizer`]):
56 | The optimizer for which to schedule the learning rate.
57 | num_warmup_steps (`int`):
58 | The number of steps for the warmup phase.
59 | num_training_steps (`int`):
60 | The number of total training steps.
61 |         final_lr (`float`):
62 |             The final constant lr after the cosine decay, expressed as a multiplier of the initial lr.
63 |         num_decay (`float`, *optional*, defaults to 0.667):
64 |             The fraction of `num_training_steps` over which the cosine decay runs before the lr is held constant at `final_lr`.
65 | last_epoch (`int`, *optional*, defaults to -1):
66 | The index of the last epoch when resuming training.
67 |
68 | Return:
69 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70 | """
71 |
72 | def lr_lambda(current_step):
73 | if current_step < num_warmup_steps:
74 | return float(current_step) / float(max(1, num_warmup_steps))
75 |
76 | num_decay_steps = int(num_training_steps * num_decay)
77 | if current_step > num_decay_steps:
78 | return final_lr
79 |
80 | progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
81 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * (
82 | 1 - final_lr) + final_lr
83 |
84 | return LambdaLR(optimizer, lr_lambda, last_epoch)
85 |
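
A small sketch of the resulting schedule shape (the SGD optimizer here is only a stand-in):

    import torch

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    sched = get_cosine_decay_to_constant_with_warmup(
        opt, num_warmup_steps=100, num_training_steps=1000, final_lr=0.1, num_decay=0.667)
    # The lr multiplier ramps linearly from 0 to 1 over the first 100 steps, follows a cosine
    # from 1 down to final_lr until step int(1000 * 0.667) = 667, and stays at final_lr
    # (i.e. 0.1 * the base lr of 1e-4) afterwards.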
--------------------------------------------------------------------------------
/diffusion/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from mmcv import Config
4 | from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
5 | OPTIMIZERS
6 | from mmcv.utils import _BatchNorm, _InstanceNorm
7 | from torch.nn import GroupNorm, LayerNorm
8 |
9 | from .logger import get_root_logger
10 |
11 | from typing import Tuple, Optional, Callable
12 |
13 | import torch
14 | from torch.optim.optimizer import Optimizer
15 | from came_pytorch import CAME
16 |
17 |
18 | def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
19 | assert rule in ['linear', 'sqrt']
20 | logger = get_root_logger()
21 | # scale by world size
22 | if rule == 'sqrt':
23 | scale_ratio = math.sqrt(effective_bs / base_batch_size)
24 | elif rule == 'linear':
25 | scale_ratio = effective_bs / base_batch_size
26 | optimizer_cfg['lr'] *= scale_ratio
27 | logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
28 | return scale_ratio
29 |
30 |
31 | @OPTIMIZER_BUILDERS.register_module()
32 | class MyOptimizerConstructor(DefaultOptimizerConstructor):
33 |
34 | def add_params(self, params, module, prefix='', is_dcn_module=None):
35 | """Add all parameters of module to the params list.
36 |
37 | The parameters of the given module will be added to the list of param
38 | groups, with specific rules defined by paramwise_cfg.
39 |
40 | Args:
41 | params (list[dict]): A list of param groups, it will be modified
42 | in place.
43 | module (nn.Module): The module to be added.
44 | prefix (str): The prefix of the module
45 |
46 | """
47 | # get param-wise options
48 | custom_keys = self.paramwise_cfg.get('custom_keys', {})
49 | # first sort with alphabet order and then sort with reversed len of str
50 | # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
51 |
52 | bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
53 | bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
54 | norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
55 | bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
56 |
57 | # special rules for norm layers and depth-wise conv layers
58 | is_norm = isinstance(module,
59 | (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
60 |
61 | for name, param in module.named_parameters(recurse=False):
62 | base_lr = self.base_lr
63 | if name == 'bias' and not (is_norm or is_dcn_module):
64 | base_lr *= bias_lr_mult
65 |
66 | # apply weight decay policies
67 | base_wd = self.base_wd
68 | if self.base_wd is not None:
69 | # norm decay
70 | if is_norm:
71 | base_wd *= norm_decay_mult
72 | # bias lr and decay
73 | elif name == 'bias' and not is_dcn_module:
74 |                     # TODO: current bias_decay_mult will have an effect on DCN
75 | base_wd *= bias_decay_mult
76 |
77 | param_group = {'params': [param]}
78 | if not param.requires_grad:
79 | param_group['requires_grad'] = False
80 | params.append(param_group)
81 | continue
82 | if bypass_duplicate and self._is_in(param_group, params):
83 | logger = get_root_logger()
84 |                 logger.warning(f'{prefix} is duplicate. It is skipped since '
85 |                                f'bypass_duplicate={bypass_duplicate}')
86 | continue
87 | # if the parameter match one of the custom keys, ignore other rules
88 | is_custom = False
89 | for key in custom_keys:
90 | if isinstance(key, tuple):
91 | scope, key_name = key
92 | else:
93 | scope, key_name = None, key
94 | if scope is not None and scope not in f'{prefix}':
95 | continue
96 | if key_name in f'{prefix}.{name}':
97 | is_custom = True
98 | if 'lr_mult' in custom_keys[key]:
99 | # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
100 | # param_group['lr'] = self.base_lr
101 | # else:
102 | param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
103 | elif 'lr' not in param_group:
104 | param_group['lr'] = base_lr
105 | if self.base_wd is not None:
106 | if 'decay_mult' in custom_keys[key]:
107 | param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
108 | elif 'weight_decay' not in param_group:
109 | param_group['weight_decay'] = base_wd
110 |
111 | if not is_custom:
112 | # bias_lr_mult affects all bias parameters
113 | # except for norm.bias dcn.conv_offset.bias
114 | if base_lr != self.base_lr:
115 | param_group['lr'] = base_lr
116 | if base_wd != self.base_wd:
117 | param_group['weight_decay'] = base_wd
118 | params.append(param_group)
119 |
120 | for child_name, child_mod in module.named_children():
121 | child_prefix = f'{prefix}.{child_name}' if prefix else child_name
122 | self.add_params(
123 | params,
124 | child_mod,
125 | prefix=child_prefix,
126 | is_dcn_module=is_dcn_module)
127 |
128 |
129 | def build_optimizer(model, optimizer_cfg):
130 | # default parameter-wise config
131 | logger = get_root_logger()
132 |
133 | if hasattr(model, 'module'):
134 | model = model.module
135 | # set optimizer constructor
136 | optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
137 | # parameter-wise setting: cancel weight decay for some specific modules
138 | custom_keys = dict()
139 | for name, module in model.named_modules():
140 | if hasattr(module, 'zero_weight_decay'):
141 | custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})
142 |
143 | paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
144 | given_cfg = optimizer_cfg.get('paramwise_cfg')
145 | if given_cfg:
146 | paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
147 | optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
148 | # build optimizer
149 | optimizer = mm_build_optimizer(model, optimizer_cfg)
150 |
151 | weight_decay_groups = dict()
152 | lr_groups = dict()
153 | for group in optimizer.param_groups:
154 | if not group.get('requires_grad', True): continue
155 | lr_groups.setdefault(group['lr'], []).append(group)
156 | weight_decay_groups.setdefault(group['weight_decay'], []).append(group)
157 |
158 | learnable_count, fix_count = 0, 0
159 | for p in model.parameters():
160 | if p.requires_grad:
161 | learnable_count += 1
162 | else:
163 | fix_count += 1
164 | fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
165 | lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
166 | wd_info = "Weight decay group: " + ", ".join(
167 | [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
168 | opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
169 | logger.info(opt_info)
170 |
171 | return optimizer
172 |
173 |
174 | @OPTIMIZERS.register_module()
175 | class Lion(Optimizer):
176 | def __init__(
177 | self,
178 | params,
179 | lr: float = 1e-4,
180 | betas: Tuple[float, float] = (0.9, 0.99),
181 | weight_decay: float = 0.0,
182 | ):
183 | assert lr > 0.
184 | assert all([0. <= beta <= 1. for beta in betas])
185 |
186 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
187 |
188 | super().__init__(params, defaults)
189 |
190 | @staticmethod
191 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
192 | # stepweight decay
193 | p.data.mul_(1 - lr * wd)
194 |
195 | # weight update
196 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
197 | p.add_(update, alpha=-lr)
198 |
199 | # decay the momentum running average coefficient
200 | exp_avg.lerp_(grad, 1 - beta2)
201 |
202 | @staticmethod
203 | def exists(val):
204 | return val is not None
205 |
206 | @torch.no_grad()
207 | def step(
208 | self,
209 | closure: Optional[Callable] = None
210 | ):
211 |
212 | loss = None
213 | if self.exists(closure):
214 | with torch.enable_grad():
215 | loss = closure()
216 |
217 | for group in self.param_groups:
218 | for p in filter(lambda p: self.exists(p.grad), group['params']):
219 |
220 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
221 | self.state[p]
222 |
223 | # init state - exponential moving average of gradient values
224 | if len(state) == 0:
225 | state['exp_avg'] = torch.zeros_like(p)
226 |
227 | exp_avg = state['exp_avg']
228 |
229 | self.update_fn(
230 | p,
231 | grad,
232 | exp_avg,
233 | lr,
234 | wd,
235 | beta1,
236 | beta2
237 | )
238 |
239 | return loss
240 |
241 |
242 | @OPTIMIZERS.register_module()
243 | class CAMEWrapper(CAME):
244 | def __init__(self, *args, **kwargs):
245 |
246 | super().__init__(*args, **kwargs)
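
A usage sketch tying the helpers above together. The config values are illustrative rather than taken from a real config file, and `model` is assumed to exist:

    optimizer_cfg = dict(type='CAMEWrapper', lr=2e-5, weight_decay=3e-2,
                         betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
    # With base_batch_size=256, an effective batch size of 512 doubles the lr under the
    # 'linear' rule (and would scale it by sqrt(2) under 'sqrt').
    lr_scale_ratio = auto_scale_lr(effective_bs=512, optimizer_cfg=optimizer_cfg, rule='linear')
    optimizer = build_optimizer(model, optimizer_cfg)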
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: PixArt
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python >= 3.8
7 | - pytorch >= 1.13
8 | - torchvision
9 | - pytorch-cuda=11.7
10 | - pip:
11 | - timm==0.6.12
12 |     - diffusers
13 |     - accelerate==0.15.0
14 |     - mmcv==1.7.0
15 |     - tensorboard
16 |     - transformers==4.26.1
17 |     - sentencepiece~=0.1.97
18 |     - ftfy~=6.1.1
19 |     - beautifulsoup4~=4.11.1
20 |     - opencv-python
21 |     - bs4
22 |     - einops
23 |     - xformers
--------------------------------------------------------------------------------
/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py:
--------------------------------------------------------------------------------
1 | _base_ = ['/workspace/PixArt-alpha/configs/PixArt_xl2_internal.py']
2 | data_root = '/workspace'
3 |
4 | image_list_json = ['data_info.json',]
5 |
6 | data = dict(type='InternalData', root='/workspace/pixart-pokemon', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
7 | image_size = 512
8 |
9 | # model setting
10 | model = 'PixArt_XL_2'
11 | fp32_attention = True
12 | load_from = "/workspace/PixArt-alpha/output/pretrained_models/PixArt-XL-2-512x512.pth"
13 | vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
14 | pe_interpolation = 1.0
15 |
16 | # training setting
17 | use_fsdp=False  # whether to use FSDP mode
18 | num_workers=10
19 | train_batch_size = 38 # 32
20 | num_epochs = 200 # 3
21 | gradient_accumulation_steps = 1
22 | grad_checkpointing = True
23 | gradient_clip = 0.01
24 | optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
25 | lr_schedule_args = dict(num_warmup_steps=1000)
26 |
27 | eval_sampling_steps = 200
28 | log_interval = 20
29 | save_model_steps=100
30 | work_dir = 'output/debug'
31 |
--------------------------------------------------------------------------------
/notebooks/convert-checkpoint-to-diffusers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "2878bb5d-33a3-4a5b-b15c-c832c700129b",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "/workspace/PixArt-alpha\n"
14 | ]
15 | },
16 | {
17 | "name": "stderr",
18 | "output_type": "stream",
19 | "text": [
20 | "/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n",
21 | " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "%cd PixArt-alpha"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 14,
32 | "id": "7dd2d98c-3f8f-40f1-a9e1-bc916774afb3",
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "Total number of transformer parameters: 610856096\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "!python tools/convert_pixart_to_diffusers.py \\\n",
45 | " --orig_ckpt_path \"/workspace/PixArt-alpha/output/trained_model/checkpoints/epoch_5_step_110.pth\" \\\n",
46 | " --dump_path \"/workspace/PixArt-alpha/output/epoch_5_step_110_diffusers\" \\\n",
47 | " --only_transformer=True \\\n",
48 | " --image_size 512 \\\n",
49 | " --version sigma\n"
50 | ]
51 | }
52 | ],
53 | "metadata": {
54 | "kernelspec": {
55 | "display_name": "Python 3 (ipykernel)",
56 | "language": "python",
57 | "name": "python3"
58 | },
59 | "language_info": {
60 | "codemirror_mode": {
61 | "name": "ipython",
62 | "version": 3
63 | },
64 | "file_extension": ".py",
65 | "mimetype": "text/x-python",
66 | "name": "python",
67 | "nbconvert_exporter": "python",
68 | "pygments_lexer": "ipython3",
69 | "version": "3.10.12"
70 | }
71 | },
72 | "nbformat": 4,
73 | "nbformat_minor": 5
74 | }
75 |
--------------------------------------------------------------------------------
/notebooks/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c423d2a1-475e-482e-b759-f16456fd6707",
6 | "metadata": {},
7 | "source": [
8 | "# Install"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "0440d6a7-78b9-49e9-98a2-9a5ed75e1a2f",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!git clone https://github.com/kopyl/PixArt-alpha.git"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "0abadf51-a7e3-4091-bb02-0bdd8d28fb73",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "%cd PixArt-alpha"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "4df1af24-f439-485d-a946-966dbf16c49b",
35 | "metadata": {
36 | "scrolled": true
37 | },
38 | "outputs": [],
39 | "source": [
40 | "!pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117\n",
41 | "!pip install -r requirements.txt\n",
42 | "!pip install wandb"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "id": "d44474fd-0b92-48fc-b4cf-142b59d3917c",
48 | "metadata": {},
49 | "source": [
50 | "## Download model"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "id": "06b1c1c9-f8b1-4719-8564-2383eac9ff28",
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "!python tools/download.py --model_names \"PixArt-XL-2-512x512.pth\""
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "id": "f298a89c-d2a5-4da7-8304-c1390da0ba58",
66 | "metadata": {},
67 | "source": [
68 |     "## Make dataset out of Hugging Face dataset"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "id": "e17b8883-0a5c-4fa3-a7d0-e8ee95e42027",
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "import os\n",
79 | "from tqdm.notebook import tqdm\n",
80 | "from datasets import load_dataset\n",
81 | "import json"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "id": "92957b2c-6765-48ee-9296-d6739066d74d",
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "dataset = load_dataset(\"lambdalabs/pokemon-blip-captions\")"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "0095cdda-c31a-48ee-a115-076a5fc393c3",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "root_dir = \"/workspace/pixart-pokemon\"\n",
102 | "images_dir = \"images\"\n",
103 | "captions_dir = \"captions\"\n",
104 | "\n",
105 | "images_dir_absolute = os.path.join(root_dir, images_dir)\n",
106 | "captions_dir_absolute = os.path.join(root_dir, captions_dir)\n",
107 | "\n",
108 | "if not os.path.exists(root_dir):\n",
109 | " os.makedirs(os.path.join(root_dir, images_dir))\n",
110 | "\n",
111 | "if not os.path.exists(os.path.join(root_dir, images_dir)):\n",
112 | " os.makedirs(os.path.join(root_dir, images_dir))\n",
113 | "if not os.path.exists(os.path.join(root_dir, captions_dir)):\n",
114 | " os.makedirs(os.path.join(root_dir, captions_dir))\n",
115 | "\n",
116 | "image_format = \"png\"\n",
117 | "json_name = \"partition/data_info.json\"\n",
118 | "if not os.path.exists(os.path.join(root_dir, \"partition\")):\n",
119 | " os.makedirs(os.path.join(root_dir, \"partition\"))\n",
120 | "\n",
121 | "absolute_json_name = os.path.join(root_dir, json_name)\n",
122 | "data_info = []\n",
123 | "\n",
124 | "order = 0\n",
125 | "for item in tqdm(dataset[\"train\"]): \n",
126 | " image = item[\"image\"]\n",
127 | " image.save(f\"{images_dir_absolute}/{order}.{image_format}\")\n",
128 | " with open(f\"{captions_dir_absolute}/{order}.txt\", \"w\") as text_file:\n",
129 | " text_file.write(item[\"text\"])\n",
130 | " \n",
131 | " width, height = 512, 512\n",
132 | " ratio = 1\n",
133 | " data_info.append({\n",
134 | " \"height\": height,\n",
135 | " \"width\": width,\n",
136 | " \"ratio\": ratio,\n",
137 | " \"path\": f\"images/{order}.{image_format}\",\n",
138 | " \"prompt\": item[\"text\"],\n",
139 | " })\n",
140 | " \n",
141 | " order += 1\n",
142 | "\n",
143 | "with open(absolute_json_name, \"w\") as json_file:\n",
144 | " json.dump(data_info, json_file)"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "id": "25be1c03",
150 | "metadata": {},
151 | "source": [
152 | "## Extract features"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "id": "9f07a4f5-1873-48bf-86d0-9304942de5d3",
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "!python /workspace/PixArt-alpha/tools/extract_features.py \\\n",
163 | " --img_size 512 \\\n",
164 | " --json_path \"/workspace/pixart-pokemon/partition/data_info.json\" \\\n",
165 | " --t5_save_root \"/workspace/pixart-pokemon/caption_feature_wmask\" \\\n",
166 | " --vae_save_root \"/workspace/pixart-pokemon/img_vae_features\" \\\n",
167 | " --pretrained_models_dir \"/workspace/PixArt-alpha/output/pretrained_models\" \\\n",
168 | " --dataset_root \"/workspace/pixart-pokemon\""
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "id": "9fc653d0",
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "!wandb login REPLACE_THIS_WITH_YOUR_AUTH_TOKEN_OF_WANDB"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "id": "2cf1fd1a",
184 | "metadata": {},
185 | "source": [
186 | "## Train model"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "id": "ea0e9dab-17bc-45ed-9c81-b670bbb8de47",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "!python -m torch.distributed.launch \\\n",
197 | " train_scripts/train.py \\\n",
198 | " /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \\\n",
199 | " --work-dir output/trained_model \\\n",
200 | " --report_to=\"wandb\" \\\n",
201 | " --loss_report_name=\"train_loss\""
202 | ]
203 | }
204 | ],
205 | "metadata": {
206 | "kernelspec": {
207 | "display_name": "Python 3 (ipykernel)",
208 | "language": "python",
209 | "name": "python3"
210 | },
211 | "language_info": {
212 | "codemirror_mode": {
213 | "name": "ipython",
214 | "version": 3
215 | },
216 | "file_extension": ".py",
217 | "mimetype": "text/x-python",
218 | "name": "python",
219 | "nbconvert_exporter": "python",
220 | "pygments_lexer": "ipython3",
221 | "version": "3.8.13"
222 | }
223 | },
224 | "nbformat": 4,
225 | "nbformat_minor": 5
226 | }
227 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mmcv==1.7.0
2 | git+https://github.com/huggingface/diffusers
3 | timm==0.6.12
4 | accelerate==0.25.0
5 | tensorboard
6 | tensorboardX
7 | transformers==4.36.1
8 | sentencepiece~=0.1.99
9 | ftfy
10 | beautifulsoup4
11 | protobuf==3.20.2
12 | gradio==4.1.1
13 | yapf==0.40.1
14 | opencv-python
15 | bs4
16 | einops
17 | xformers==0.0.19
18 | optimum
19 | peft
20 | came-pytorch
--------------------------------------------------------------------------------
/scripts/DMD/transformer_train/attention_processor.py:
--------------------------------------------------------------------------------
1 | from diffusers.models.attention_processor import AttnProcessor2_0, Attention
2 | from typing import Optional
3 | import torch
4 | from diffusers.utils import USE_PEFT_BACKEND
5 | import torch.nn.functional as F
6 |
7 | class AttentionPorcessorFP32(AttnProcessor2_0):
8 |
9 | def __call__(
10 | self,
11 | attn: Attention,
12 | hidden_states: torch.FloatTensor,
13 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
14 | attention_mask: Optional[torch.FloatTensor] = None,
15 | temb: Optional[torch.FloatTensor] = None,
16 | scale: float = 1.0,
17 | use_fp32_attention = True,
18 | ) -> torch.FloatTensor:
19 | residual = hidden_states
20 | if attn.spatial_norm is not None:
21 | hidden_states = attn.spatial_norm(hidden_states, temb)
22 |
23 | input_ndim = hidden_states.ndim
24 |
25 | if input_ndim == 4:
26 | batch_size, channel, height, width = hidden_states.shape
27 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
28 |
29 | batch_size, sequence_length, _ = (
30 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
31 | )
32 |
33 | if attention_mask is not None:
34 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
35 | # scaled_dot_product_attention expects attention_mask shape to be
36 | # (batch, heads, source_length, target_length)
37 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
38 |
39 | if attn.group_norm is not None:
40 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
41 |
42 | args = () if USE_PEFT_BACKEND else (scale,)
43 | query = attn.to_q(hidden_states, *args)
44 |
45 | if encoder_hidden_states is None:
46 | encoder_hidden_states = hidden_states
47 | elif attn.norm_cross:
48 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
49 |
50 | key = attn.to_k(encoder_hidden_states, *args)
51 | value = attn.to_v(encoder_hidden_states, *args)
52 |
53 | inner_dim = key.shape[-1]
54 | head_dim = inner_dim // attn.heads
55 |
56 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
57 |
58 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
59 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
60 |
61 | if use_fp32_attention:
62 | query = query.float()
63 | key = key.float()
64 | value = value.float()
65 | if attention_mask is not None:
66 | attention_mask = attention_mask.to(query.dtype)
67 |
68 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
69 | # TODO: add support for attn.scale when we move to Torch 2.1
70 | with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
71 | hidden_states = F.scaled_dot_product_attention(
72 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
73 | )
74 |
75 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
76 | hidden_states = hidden_states.to(torch.float16)
77 |
78 | # linear proj
79 | hidden_states = attn.to_out[0](hidden_states, *args)
80 | # dropout
81 | hidden_states = attn.to_out[1](hidden_states)
82 |
83 | if input_ndim == 4:
84 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
85 |
86 | if attn.residual_connection:
87 | hidden_states = hidden_states + residual
88 |
89 | hidden_states = hidden_states / attn.rescale_output_factor
90 |
91 | return hidden_states
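
A sketch of installing this processor (illustrative; `transformer` is assumed to be a diffusers Transformer2DModel-style module, which exposes set_attn_processor). Note that the processor casts the attention output back to float16 before the output projection, so it assumes the surrounding model runs in half precision:

    transformer.set_attn_processor(AttentionPorcessorFP32())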
--------------------------------------------------------------------------------
/scripts/DMD/transformer_train/generate.py:
--------------------------------------------------------------------------------
1 | # contains functions of generating samples
2 | import torch
3 | from accelerate.utils.other import extract_model_from_parallel
4 |
5 |
6 | def model_forward(generator, encoder_hidden_states, encoder_attention_mask, added_cond_kwargs, noise, start_ts):
7 | if isinstance(start_ts, int):
8 | # convert int to long
9 | start_ts_net_in = torch.zeros((noise.size()[0],)) + start_ts
10 | start_ts_net_in = start_ts_net_in.long().to(noise.device)
11 | else:
12 | start_ts_net_in = start_ts.to(noise.device)
13 | noise_pred = generator(hidden_states=noise, encoder_hidden_states=encoder_hidden_states,
14 | encoder_attention_mask=encoder_attention_mask, added_cond_kwargs=added_cond_kwargs,
15 |                            timestep=start_ts_net_in).sample
16 | B, C = noise.shape[:2]
17 | assert noise_pred.shape == (B, C * 2, *noise.shape[2:])
18 | noise_pred = torch.split(noise_pred, C, dim=1)[0]
19 | return noise_pred
20 |
21 |
22 | def generate_sample_1step(model, scheduler, latents, maxt, prompt_embeds, prompt_attention_masks=None):
23 | t = torch.full((1,), maxt, device=latents.device).long()
24 | noise_pred = forward_model(
25 | model,
26 | latents=latents,
27 | timestep=t,
28 | prompt_embeds=prompt_embeds,
29 | prompt_attention_masks=prompt_attention_masks,
30 | )
31 | latents = eps_to_mu(scheduler, noise_pred, latents, t)
32 | return latents
33 |
34 | def eps_to_mu(scheduler, model_output, sample, timesteps):
35 | alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
36 | alpha_prod_t = alphas_cumprod[timesteps]
37 | while len(alpha_prod_t.shape) < len(sample.shape):
38 | alpha_prod_t = alpha_prod_t.unsqueeze(-1)
39 | beta_prod_t = 1 - alpha_prod_t
40 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
41 | return pred_original_sample
42 |
43 |
44 | def forward_model(model, latents, timestep, prompt_embeds, prompt_attention_masks=None):
45 | added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
46 | if extract_model_from_parallel(model).config.sample_size == 128:
47 | batch_size, _, height, width = latents.shape
48 | resolution = torch.tensor([height, width]).repeat(batch_size, 1)
49 | aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
50 | resolution = resolution.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
51 | aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
52 | added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
53 |
54 | timestep = timestep.expand(latents.shape[0])
55 |
56 | noise_pred = model(
57 | latents,
58 | timestep=timestep,
59 | encoder_hidden_states=prompt_embeds,
60 | encoder_attention_mask=prompt_attention_masks,
61 | added_cond_kwargs=added_cond_kwargs,
62 | ).sample
63 |
64 | if extract_model_from_parallel(model).config.out_channels // 2 == latents.shape[1]:
65 | noise_pred = noise_pred.chunk(2, dim=1)[0]
66 |
67 | return noise_pred
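
For reference, eps_to_mu above applies the standard DDPM conversion from a noise prediction to a clean-sample estimate,

    x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_bar_t),

which is how generate_sample_1step turns the one-step noise prediction at timestep maxt directly into latents.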
--------------------------------------------------------------------------------
/scripts/DMD/transformer_train/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import re
5 | import shutil
6 |
7 | from accelerate.checkpointing import save_accelerator_state, save_custom_state
8 | from accelerate.logging import get_logger
9 | from accelerate.utils import (
10 | MODEL_NAME,
11 | DistributedType,
12 | save_fsdp_model,
13 | save_fsdp_optimizer,
14 | is_deepspeed_available
15 | )
16 |
17 | from PIL import Image
18 |
19 | if is_deepspeed_available():
20 | import deepspeed
21 |
22 | from accelerate.utils import (
23 | DeepSpeedEngineWrapper,
24 | DeepSpeedOptimizerWrapper,
25 | DeepSpeedSchedulerWrapper,
26 | DummyOptim,
27 | DummyScheduler,
28 | )
29 |
30 | try:
31 | from torch.optim.lr_scheduler import LRScheduler
32 | except ImportError:
33 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
34 |
35 | logger = get_logger(__name__)
36 |
37 | ############# Saving model utils
38 |
39 | def accelerate_save_state(accelerator, output_dir=None, save_unet_only=False, unet_id=0, **save_model_func_kwargs):
40 | """
41 | Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
42 |
43 | If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled
44 | then checkpoints will be saved to `self.project_dir/checkpoints`. If the number of current saves is greater
45 |     than `total_limit` then the oldest save is deleted. Each checkpoint is saved in separate folders named
46 | `checkpoint_`.
47 |
48 | Otherwise they are just saved to `output_dir`.
49 |
50 |
51 |
52 | Should only be used when wanting to save a checkpoint during training and restoring the state in the same
53 | environment.
54 |
55 |
56 |
57 | Args:
58 | output_dir (`str` or `os.PathLike`):
59 | The name of the folder to save all relevant weights and states.
60 | save_model_func_kwargs (`dict`, *optional*):
61 | Additional keyword arguments for saving model which can be passed to the underlying save function, such
62 | as optional arguments for DeepSpeed's `save_checkpoint` function.
63 |
64 | Example:
65 |
66 | ```python
67 | >>> from accelerate import Accelerator
68 |
69 | >>> accelerator = Accelerator()
70 | >>> model, optimizer, lr_scheduler = ...
71 | >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
72 | >>> accelerator.save_state(output_dir="my_checkpoint")
73 | ```
74 | """
75 | if accelerator.project_configuration.automatic_checkpoint_naming:
76 | output_dir = os.path.join(accelerator.project_dir, "checkpoints")
77 | os.makedirs(output_dir, exist_ok=True)
78 | if accelerator.project_configuration.automatic_checkpoint_naming:
79 | folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]
80 | if accelerator.project_configuration.total_limit is not None and (
81 | len(folders) + 1 > accelerator.project_configuration.total_limit
82 | ):
83 |
84 | def _inner(folder):
85 | return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
86 |
87 | folders.sort(key=_inner)
88 | logger.warning(
89 | f"Deleting {len(folders) + 1 - accelerator.project_configuration.total_limit} checkpoints to make room for new checkpoint."
90 | )
91 | for folder in folders[: len(folders) + 1 - accelerator.project_configuration.total_limit]:
92 | shutil.rmtree(folder)
93 | output_dir = os.path.join(output_dir, f"checkpoint_{accelerator.save_iteration}")
94 | if os.path.exists(output_dir):
95 | raise ValueError(
96 |             f"Checkpoint directory {output_dir} ({accelerator.save_iteration}) already exists. Please manually override `accelerator.save_iteration` with the iteration to start from."
97 | )
98 | os.makedirs(output_dir, exist_ok=True)
99 | logger.info(f"Saving current state to {output_dir}")
100 |
101 |
102 | # Save the models taking care of FSDP and DeepSpeed nuances
103 |
104 | weights = []
105 | for i, model in enumerate(accelerator._models):
106 | if save_unet_only and i != unet_id:
107 | continue
108 | if accelerator.distributed_type == DistributedType.FSDP:
109 | logger.info("Saving FSDP model")
110 | save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, output_dir, i)
111 | logger.info(f"FSDP Model saved to output dir {output_dir}")
112 | elif accelerator.distributed_type == DistributedType.DEEPSPEED:
113 | logger.info("Saving DeepSpeed Model and Optimizer")
114 | ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
115 | model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
116 | logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}")
117 | elif accelerator.distributed_type == DistributedType.MEGATRON_LM:
118 | logger.info("Saving Megatron-LM Model, Optimizer and Scheduler")
119 | model.save_checkpoint(output_dir)
120 | logger.info(f"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}")
121 | else:
122 | weights.append(accelerator.get_state_dict(model, unwrap=False))
123 |
124 | # Save the optimizers taking care of FSDP and DeepSpeed nuances
125 | optimizers = []
126 | if not save_unet_only:
127 | if accelerator.distributed_type == DistributedType.FSDP:
128 | for i, opt in enumerate(accelerator._optimizers):
129 | logger.info("Saving FSDP Optimizer")
130 | save_fsdp_optimizer(accelerator.state.fsdp_plugin, accelerator, opt, accelerator._models[i], output_dir, i)
131 | logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
132 | elif accelerator.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
133 | optimizers = accelerator._optimizers
134 |
135 | # Save the lr schedulers taking care of DeepSpeed nuances
136 | schedulers = []
137 | if accelerator.distributed_type == DistributedType.DEEPSPEED:
138 | for i, scheduler in enumerate(accelerator._schedulers):
139 | if isinstance(scheduler, DeepSpeedSchedulerWrapper):
140 | continue
141 | schedulers.append(scheduler)
142 | elif accelerator.distributed_type not in [DistributedType.MEGATRON_LM]:
143 | schedulers = accelerator._schedulers
144 |
145 |     # Call any save-state hooks that might have been registered with
146 |     # accelerator.register_save_state_pre_hook
147 | for hook in accelerator._save_model_state_pre_hook.values():
148 | hook(accelerator._models, weights, output_dir)
149 |
150 | save_location = save_accelerator_state(
151 | output_dir, weights, optimizers, schedulers, accelerator.state.process_index, accelerator.scaler
152 | )
153 | for i, obj in enumerate(accelerator._custom_objects):
154 | save_custom_state(obj, output_dir, i)
155 |
156 | accelerator.project_configuration.iteration += 1
157 | return save_location
158 |
159 |
160 | #### calculation
161 | def compute_snr(timesteps, noise_scheduler):
162 | """
163 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
164 | """
165 | alphas_cumprod = noise_scheduler.alphas_cumprod
166 | sqrt_alphas_cumprod = alphas_cumprod**0.5
167 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
168 |
169 | # Expand the tensors.
170 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
171 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
172 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
173 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
174 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
175 |
176 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
177 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
178 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
179 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
180 |
181 | # Compute SNR.
182 | snr = (alpha / sigma) ** 2
183 | return snr
184 |
185 |
186 | def save_image(image, path):
187 | image = (image / 2 + 0.5).clamp(0, 1)
188 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
189 | image = (image * 255).round().astype("uint8")
190 | image = [Image.fromarray(im) for im in image]
191 | image[0].save(path)
192 |
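193 | # --- Hypothetical usage sketch (not part of the original file) ---
194 | # compute_snr() is typically paired with Min-SNR-gamma loss weighting; the names
195 | # `noise_scheduler`, `timesteps`, `model_pred` and `target` are assumed to come from
196 | # a diffusers DDPMScheduler and the surrounding training loop (with torch / F = torch.nn.functional imported).
197 | #
198 | #   snr = compute_snr(timesteps, noise_scheduler)
199 | #   snr_gamma = 5.0
200 | #   weights = torch.minimum(snr, snr_gamma * torch.ones_like(snr)) / snr
201 | #   loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
202 | #   loss = (loss.mean(dim=list(range(1, len(loss.shape)))) * weights).mean()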
--------------------------------------------------------------------------------
/scripts/inference_pipeline.py:
--------------------------------------------------------------------------------
1 | from transformers import T5EncoderModel
2 | from diffusers import PixArtAlphaPipeline, Transformer2DModel, DEISMultistepScheduler, DPMSolverMultistepScheduler
3 | import torch
4 | import gc
5 | import argparse
6 | import pathlib
7 | from pathlib import Path
8 | import sys
9 |
10 | current_file_path = Path(__file__).resolve()
11 | sys.path.insert(0, str(current_file_path.parent.parent))
12 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs
13 |
14 | def main(args):
15 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
16 |
17 | gc.collect()
18 | torch.cuda.empty_cache()
19 |
20 | repo_path = args.repo_path
21 | output_image = pathlib.Path(args.output)
22 | positive_prompt = args.positive_prompt
23 | negative_prompt = args.negative_prompt
24 | image_width = args.width
25 | image_height = args.height
26 | num_steps = args.num_steps
27 | guidance_scale = args.guidance_scale
28 | seed = args.seed
29 | low_vram = args.low_vram
30 | num_images = args.num_images
31 | scheduler_type = args.scheduler
32 | karras = args.karras
33 | algorithm_type = args.algorithm
34 | beta_schedule = args.beta_schedule
35 | use_lu_lambdas = args.use_lu_lambdas
36 |
37 | pipe = None
38 | if low_vram:
39 | print('low_vram')
40 | text_encoder = T5EncoderModel.from_pretrained(
41 | repo_path,
42 | subfolder="text_encoder",
43 | load_in_8bit=True,
44 | torch_dtype=torch.float16,
45 | )
46 | pipe = PixArtAlphaPipeline.from_pretrained(
47 | repo_path,
48 | text_encoder=text_encoder,
49 | transformer=None,
50 | torch_dtype=torch.float16,
51 | )
52 |
53 | with torch.no_grad():
54 | prompt = positive_prompt
55 | negative = negative_prompt
56 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt, negative_prompt=negative)
57 |
58 | def flush():
59 | gc.collect()
60 | torch.cuda.empty_cache()
61 |
62 | pipe.text_encoder = None
63 | del text_encoder
64 | flush()
65 |
66 | pipe.transformer = Transformer2DModel.from_pretrained(repo_path, subfolder='transformer',
67 | load_in_8bit=True,
68 | torch_dtype=torch.float16)
69 | pipe.to('cuda')
70 | else:
71 | print('low_vram=False')
72 | pipe = PixArtAlphaPipeline.from_pretrained(
73 | repo_path,
74 | ).to('cuda')
75 |
76 | with torch.no_grad():
77 | prompt = positive_prompt
78 | negative = negative_prompt
79 | prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt, negative_prompt=negative)
80 |
81 | generator = torch.Generator()
82 |
83 | if seed != -1:
84 | generator = generator.manual_seed(seed)
85 | else:
86 | generator = None
87 |
88 | prompt_embeds = prompt_embeds.to('cuda')
89 | negative_embeds = negative_embeds.to('cuda')
90 | prompt_attention_mask = prompt_attention_mask.to('cuda')
91 | negative_prompt_attention_mask = negative_prompt_attention_mask.to('cuda')
92 |
93 | if scheduler_type == 'deis':
94 | pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
95 | else:
96 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
97 |
98 | pipe.scheduler.beta_schedule = beta_schedule
99 | pipe.scheduler.algorithm_type = algorithm_type
100 | pipe.scheduler.use_karras_sigmas = karras
101 | pipe.scheduler.use_lu_lambdas = use_lu_lambdas
102 | latents = pipe(
103 | negative_prompt=None,
104 | num_inference_steps=num_steps,
105 | height=image_height,
106 | width=image_width,
107 | prompt_embeds=prompt_embeds,
108 | guidance_scale=guidance_scale,
109 | negative_prompt_embeds=negative_embeds,
110 | prompt_attention_mask=prompt_attention_mask,
111 | negative_prompt_attention_mask=negative_prompt_attention_mask,
112 | num_images_per_prompt=num_images,
113 | output_type="latent",
114 | generator=generator,
115 | ).images
116 |
117 |     # derive the output stem and extension with pathlib so paths such as ./out.png are handled correctly
118 |     filename = str(output_image.with_suffix(''))
119 |     extension = output_image.suffix.lstrip('.')
120 |
121 | with torch.no_grad():
122 | images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
123 | images = pipe.image_processor.postprocess(images, output_type="pil")
124 |
125 | i = 0
126 | for image in images:
127 | image.save(filename + str(i) + '.' + extension)
128 | i = i + 1
129 |
130 | if __name__ == '__main__':
131 | parser = argparse.ArgumentParser()
132 | parser.add_argument('--repo_path', required=True, type=str, help='Local path or remote path to the pipeline folder')
133 | parser.add_argument('--output', required=False, type=str, default='out.png', help='Path to the generated output image. Supports most image formats i.e. .png, .jpg, .jpeg, .webp')
134 | parser.add_argument('--positive_prompt', required=True, type=str, help='Positive prompt to generate')
135 | parser.add_argument('--negative_prompt', required=False, type=str, default='', help='Negative prompt to generate')
136 | parser.add_argument('--seed', required=False, default=-1, type=int, help='Seed for the random generator')
137 | parser.add_argument('--width', required=False, default=512, type=int, help='Image width to generate')
138 | parser.add_argument('--height', required=False, default=512, type=int, help='Image height to generate')
139 | parser.add_argument('--num_steps', required=False, default=20, type=int, help='Number of inference steps')
140 | parser.add_argument('--guidance_scale', required=False, default=7.0, type=float, help='Guidance scale')
141 |     parser.add_argument('--low_vram', required=False, action='store_true', help='Load the text encoder and transformer in 8-bit to reduce VRAM usage')
142 | parser.add_argument('--num_images', required=False, default=1, type=int, help='Number of images per prompt')
143 | parser.add_argument('--scheduler', required=False, default='dpm', type=str, choices=['dpm', 'deis'])
144 |     parser.add_argument('--karras', required=False, action='store_true', help='Use Karras sigmas in the scheduler')
145 | parser.add_argument('--algorithm', required=False, default='sde-dpmsolver++', type=str, choices=['dpmsolver', 'dpmsolver++', 'sde-dpmsolver', 'sde-dpmsolver++'])
146 | parser.add_argument('--beta_schedule', required=False, default='linear', type=str, choices=['linear', 'scaled_linear', 'squaredcos_cap_v2'])
147 |     parser.add_argument('--use_lu_lambdas', required=False, action='store_true', help='Use the Lu lambdas option of the DPM-Solver scheduler')
148 |
149 | args = parser.parse_args()
150 | main(args)
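151 | 
152 | # Hypothetical invocation sketch (the repo path and prompt are placeholders, not from this repo):
153 | #   python scripts/inference_pipeline.py \
154 | #       --repo_path output/my_pixart_pipeline \
155 | #       --positive_prompt "a red panda wearing a tiny hat" \
156 | #       --num_steps 20 --guidance_scale 4.5 --scheduler dpm --karras \
157 | #       --num_images 2 --output out.png
158 | # Images are written as out0.png, out1.png, ... next to the given --output path.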
--------------------------------------------------------------------------------
/scripts/run_pixart_dmd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | from pathlib import Path
5 |
6 | current_file_path = Path(__file__).resolve()
7 | sys.path.insert(0, str(current_file_path.parent.parent))
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("config", type=str, help="config")
11 | parser.add_argument('--work_dir', help='the dir to save logs and models')
12 | parser.add_argument('--init_method', type=str, default='tcp://127.0.0.1:6666', help='')
13 | parser.add_argument('--rank', type=str, default='0', help='')
14 | parser.add_argument('--world_size', type=str, default='1', help='')
15 | parser.add_argument("--is_debugging", action="store_true")
16 | parser.add_argument("--max_samples", type=int, default=500000, help='')
17 | parser.add_argument("--regression_weight", type=float, default=0.25, help='regression loss weight')
18 | parser.add_argument("--resume_from", type=str, default='', help='path of resumed checkpoint')
19 | parser.add_argument("--mixed_precision", type=str, default='no', help="mixed precision mode: 'no' or 'fp16'")
20 | parser.add_argument("--batch_size", type=int, default=1, help='batch size per gpu')
21 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help='gradient_accumulation_steps')
22 | parser.add_argument("--use_dm", type=int, default=1, help='use distribution matching loss')
23 | parser.add_argument("--use_regression", type=int, default=1, help='use regression loss')
24 | parser.add_argument("--one_step_maxt", type=int, default=999, help='maximum timestep of one step generator')
25 | parser.add_argument("--learning_rate", type=float, default=1e-5, help='learning rate')
26 | parser.add_argument("--lr_fake_multiplier", type=float, default=1.0, help='ratio of the fake model learning rate to the base learning rate')
27 | parser.add_argument("--max_grad_norm", type=int, default=10, help='max gradient norm for clipping')
28 | parser.add_argument("--save_image_interval", type=int, default=500, help='iteration interval to save image')
29 | parser.add_argument("--checkpointing_steps", type=int, default=500,
30 | help=(
31 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
32 | " training using `--resume_from_checkpoint`."
33 | ),
34 | )
35 |
36 |
37 | args, args_run = parser.parse_known_args()
38 |
39 | if args.world_size != '1':
40 | os.environ['MASTER_PORT'] = args.init_method.split(':')[-1]
41 | os.environ['MASTER_ADDR'] = args.init_method.split(':')[1][2:]
42 | os.environ['WORLD_SIZE'] = args.world_size
43 | os.environ['RANK'] = args.rank
44 | for k in ['MASTER_PORT', 'MASTER_ADDR', 'WORLD_SIZE', 'RANK']:
45 | print (k, ':', os.environ[k])
46 |
47 | is_debugging = args.is_debugging
48 |
49 | use_dm = args.use_dm == 1
50 | use_regression = args.use_regression == 1
51 | laion_subset = 'subset6.25'
52 |
53 | if len(args.resume_from) == 0:
54 | resume_from = None
55 | else:
56 | resume_from = args.resume_from
57 |
58 | # activate environment
59 | print('setting up env')
60 | print('env set')
61 |
62 | if args.mixed_precision == 'no':
63 | acc_common_args = '--use_fsdp --fsdp_offload_params="False" --fsdp_sharding_strategy=2'
64 | else:
65 | acc_common_args = '--mixed_precision="fp16" --use_fsdp --fsdp_offload_params="False" --fsdp_sharding_strategy=2'
66 |
67 | main_args = (
68 | f'--config="{args.config}" '
69 | f'--train_batch_size={args.batch_size} '
70 | f'--one_step_maxt={args.one_step_maxt} '
71 | f'--output_dir={args.work_dir} '
72 | f'--learning_rate={args.learning_rate} '
73 | f'--max_samples={args.max_samples} '
74 | f'--node_id={int(args.rank)} '
75 | f'--gradient_accumulation_steps={args.gradient_accumulation_steps} '
76 | f'--checkpointing_steps={args.checkpointing_steps} '
77 | f'--lr_fake_multiplier={args.lr_fake_multiplier} '
78 | f'--max_grad_norm={args.max_grad_norm} '
79 | f'--save_image_interval={args.save_image_interval} '
80 | '--max_train_steps=1000000 '
81 | '--di_steps=1 '
82 | '--start_ts=999 '
83 | '--cfg=3 '
84 | '--dataloader_num_workers=16 '
85 | '--resolution=512 '
86 |     '--center_crop '
87 | '--random_flip '
88 | '--use_ema '
89 | '--lr_scheduler="constant" '
90 | '--lr_warmup_steps=0 '
91 | '--logging_dir="_logs" '
92 | '--report_to=tensorboard '
93 | '--adam_epsilon=1e-06 '
94 | '--seed=0 '
95 | )
96 |
97 | if args.mixed_precision == 'fp16':
98 | main_args += '--mixed_precision="fp16" '
99 |
100 | if use_dm:
101 | main_args += '--use_dm '
102 | if use_regression:
103 | main_args += f'--use_regression --regression_weight={args.regression_weight} '
104 |
105 | if resume_from is not None:
106 | main_args += f'--resume_from_checkpoint="{resume_from}" '
107 |
108 | if is_debugging:
109 | num_gpus_per_node = 2
110 | else:
111 | num_gpus_per_node = 8
112 |
113 | if args.world_size != '1':
114 | num_processes = int(args.world_size) * num_gpus_per_node
115 | print('num_processes', num_processes)
116 | run_cmd = (f'accelerate launch {acc_common_args} '
117 | f'--num_machines={args.world_size} '
118 | f'--num_processes={num_processes} '
119 | f'--machine_rank={os.environ["RANK"]} '
120 | f'--main_process_ip={os.environ["MASTER_ADDR"]} '
121 | f'--main_process_port={os.environ["MASTER_PORT"]} '
122 | f'train_scripts/train_pixart_dmd.py {main_args}'
123 | )
124 | else:
125 | run_cmd = (f'accelerate launch {acc_common_args} '
126 | f'--num_machines={args.world_size} '
127 | f'--num_processes={num_gpus_per_node} '
128 | 'train_scripts/train_pixart_dmd.py '
129 | f'{main_args}'
130 | )
131 |
132 | print('run_cmd', run_cmd)
133 |
134 | print('running')
135 | os.system(run_cmd)
136 | print('done')
137 |
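138 | # Hypothetical launch sketches (values are illustrative). The script assembles an
139 | # `accelerate launch ... train_scripts/train_pixart_dmd.py` command and executes it with os.system.
140 | # Single-node debugging run (2 GPUs because of --is_debugging):
141 | #   python scripts/run_pixart_dmd.py configs/pixart_app_config/PixArt-DMD_xl2_img512_internalms.py \
142 | #       --work_dir output/dmd_debug --is_debugging
143 | # Two nodes (rank 0 shown); MASTER_ADDR and MASTER_PORT are parsed out of --init_method:
144 | #   python scripts/run_pixart_dmd.py <config> --world_size 2 --rank 0 --init_method tcp://10.0.0.1:6666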
--------------------------------------------------------------------------------
/scripts/style.css:
--------------------------------------------------------------------------------
1 | /*.gradio-container{width:680px!important}*/
2 | /* style.css */
3 | .gradio_group, .gradio_row, .gradio_column {
4 | display: flex;
5 | flex-direction: row;
6 | justify-content: flex-start;
7 | align-items: flex-start;
8 | flex-wrap: wrap;
9 | }
--------------------------------------------------------------------------------
/tools/convert_diffusers_to_pipeline.py:
--------------------------------------------------------------------------------
1 |
2 | from safetensors import safe_open
3 | from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel
4 | from transformers import T5EncoderModel, T5Tokenizer
5 | import pathlib
6 | import argparse
7 | import gc
8 | import torch
9 | import sys
10 |
11 | from pathlib import Path
12 | current_file_path = Path(__file__).resolve()
13 | sys.path.insert(0, str(current_file_path.parent.parent))
14 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs
15 |
16 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
17 | ckpt_id = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
18 |
19 | def flush_memory():
20 | gc.collect()
21 | torch.cuda.empty_cache()
22 |
23 | def main(args):
24 | safetensors_file = pathlib.Path(args.safetensors_path)
25 | image_size = args.image_size
26 |
27 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
28 | pathlib.Path(args.output_folder).mkdir(parents=True, exist_ok=True)
29 | transformer = Transformer2DModel(
30 | sample_size=image_size // 8,
31 | num_layers=28,
32 | attention_head_dim=72,
33 | in_channels=4,
34 | out_channels=8,
35 | patch_size=2,
36 | attention_bias=True,
37 | num_attention_heads=16,
38 | cross_attention_dim=1152,
39 | activation_fn="gelu-approximate",
40 | num_embeds_ada_norm=1000,
41 | norm_type="ada_norm_single",
42 | norm_elementwise_affine=False,
43 | norm_eps=1e-6,
44 | caption_channels=4096,
45 | interpolation_scale=interpolation_scale_sigma[image_size],
46 | ).to('cuda')
47 |
48 | state_dict = {}
49 | with safe_open(safetensors_file, framework='pt') as f:
50 | for k in f.keys():
51 | state_dict[k] = f.get_tensor(k)
52 | transformer.load_state_dict(state_dict, strict=True)
53 |
54 | transformer.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'transformer'))
55 | scheduler = DPMSolverMultistepScheduler()
56 | vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae")
57 | tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
58 | text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="text_encoder")
59 |
60 | pipe = PixArtAlphaPipeline(transformer=transformer, scheduler=scheduler, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder)
61 | pipe.save_config(pathlib.Path(args.output_folder))
62 | del pipe
63 | del transformer
64 | del scheduler
65 | del vae
66 | del tokenizer
67 | del text_encoder
68 | flush_memory()
69 |
70 | scheduler = DPMSolverMultistepScheduler()
71 | scheduler.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'scheduler'))
72 | del scheduler
73 | flush_memory()
74 |
75 | vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae").to('cuda')
76 | vae.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'vae'))
77 | del vae
78 | flush_memory()
79 |
80 | tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
81 | tokenizer.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'tokenizer'))
82 | del tokenizer
83 | flush_memory()
84 |
85 | text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="text_encoder")
86 | text_encoder.save_pretrained(pathlib.Path.joinpath(pathlib.Path(args.output_folder), 'text_encoder'))
87 | del text_encoder
88 | flush_memory()
89 |
90 | if __name__ == '__main__':
91 | parser = argparse.ArgumentParser()
92 |
93 | parser.add_argument('--safetensors_path', required=True, type=str, help='Path to the .safetensors file to convert to diffusers folder structure')
94 | parser.add_argument('--image_size', required=False, default=512, type=int, choices=[256, 512, 1024, 2048], help='Image size of pretrained model')
95 | parser.add_argument('--output_folder', required=True, type=str, help='Path to the output folder')
96 | parser.add_argument('--multistep', required=False, type=bool, default=True, help='Multistep option')
97 | args = parser.parse_args()
98 | main(args)
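99 | 
100 | # Hypothetical usage sketch (file names are placeholders):
101 | #   python tools/convert_diffusers_to_pipeline.py \
102 | #       --safetensors_path output/transformer/diffusion_pytorch_model.safetensors \
103 | #       --image_size 512 --output_folder output/my_pixart_pipeline
104 | # The transformer weights come from the given .safetensors file, while the VAE, tokenizer and
105 | # text encoder are pulled from the `ckpt_id` repo defined above and saved into the output folder.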
--------------------------------------------------------------------------------
/tools/convert_diffusers_to_pixart.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import annotations
3 |
4 | import argparse
5 | import sys
6 | from pathlib import Path
7 |
8 | from safetensors import safe_open
9 |
10 | current_file_path = Path(__file__).resolve()
11 | sys.path.insert(0, str(current_file_path.parent.parent))
12 |
13 | import torch
14 |
15 | ckpt_id = "PixArt-alpha"
16 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
17 | interpolation_scale_alpha = {256: 1, 512: 1, 1024: 2}
18 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
19 |
20 | def main(args):
21 | # load the pipe, but only the transformer
22 | repo_path = args.safetensor_path
23 | output_path = args.pth_path
24 |
25 | transformer = safe_open(repo_path, framework='pt')
26 |
27 |     state_dict = transformer.keys()  # just the tensor names; used to check which keys exist
28 |
29 | layer_depth = sum([key.endswith("attn1.to_out.0.weight") for key in state_dict]) or 28
30 |
31 |     # check for micro-conditioning (resolution / aspect-ratio embedders)
32 | converted_state_dict = {
33 | 'x_embedder.proj.weight': transformer.get_tensor('pos_embed.proj.weight'),
34 | 'x_embedder.proj.bias': transformer.get_tensor('pos_embed.proj.bias'),
35 | 'y_embedder.y_proj.fc1.weight': transformer.get_tensor('caption_projection.linear_1.weight'),
36 | 'y_embedder.y_proj.fc1.bias': transformer.get_tensor('caption_projection.linear_1.bias'),
37 | 'y_embedder.y_proj.fc2.weight': transformer.get_tensor('caption_projection.linear_2.weight'),
38 | 'y_embedder.y_proj.fc2.bias': transformer.get_tensor('caption_projection.linear_2.bias'),
39 | 't_embedder.mlp.0.weight': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_1.weight'),
40 | 't_embedder.mlp.0.bias': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_1.bias'),
41 | 't_embedder.mlp.2.weight': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_2.weight'),
42 | 't_embedder.mlp.2.bias': transformer.get_tensor('adaln_single.emb.timestep_embedder.linear_2.bias')
43 | }
44 | if 'adaln_single.emb.resolution_embedder.linear_1.weight' in state_dict:
45 | converted_state_dict['csize_embedder.mlp.0.weight'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_1.weight')
46 | converted_state_dict['csize_embedder.mlp.0.bias'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_1.bias')
47 | converted_state_dict['csize_embedder.mlp.2.weight'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_2.weight')
48 | converted_state_dict['csize_embedder.mlp.2.bias'] = transformer.get_tensor('adaln_single.emb.resolution_embedder.linear_2.bias')
49 | converted_state_dict['ar_embedder.mlp.0.weight'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_1.weight')
50 | converted_state_dict['ar_embedder.mlp.0.bias'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_1.bias')
51 | converted_state_dict['ar_embedder.mlp.2.weight'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_2.weight')
52 | converted_state_dict['ar_embedder.mlp.2.bias'] = transformer.get_tensor('adaln_single.emb.aspect_ratio_embedder.linear_2.bias')
53 |
54 | # shared norm
55 | converted_state_dict['t_block.1.weight'] = transformer.get_tensor('adaln_single.linear.weight')
56 | converted_state_dict['t_block.1.bias'] = transformer.get_tensor('adaln_single.linear.bias')
57 |
58 | for depth in range(layer_depth):
59 | print(f"Converting layer {depth}")
60 | converted_state_dict[f"blocks.{depth}.scale_shift_table"] = transformer.get_tensor(f"transformer_blocks.{depth}.scale_shift_table")
61 |
62 | # self attention
63 | q = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_q.weight')
64 | q_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_q.bias')
65 | k = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_k.weight')
66 | k_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_k.bias')
67 | v = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_v.weight')
68 | v_bias = transformer.get_tensor(f'transformer_blocks.{depth}.attn1.to_v.bias')
69 | converted_state_dict[f'blocks.{depth}.attn.qkv.weight'] = torch.cat((q, k, v))
70 | converted_state_dict[f'blocks.{depth}.attn.qkv.bias'] = torch.cat((q_bias, k_bias, v_bias))
71 |
72 | # projection
73 | converted_state_dict[f"blocks.{depth}.attn.proj.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.to_out.0.weight")
74 | converted_state_dict[f"blocks.{depth}.attn.proj.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.to_out.0.bias")
75 |
76 | # check for qk norm
77 | if f'transformer_blocks.{depth}.attn1.q_norm.weight' in state_dict:
78 | converted_state_dict[f"blocks.{depth}.attn.q_norm.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.q_norm.weight")
79 | converted_state_dict[f"blocks.{depth}.attn.q_norm.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.q_norm.bias")
80 | converted_state_dict[f"blocks.{depth}.attn.k_norm.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.k_norm.weight")
81 | converted_state_dict[f"blocks.{depth}.attn.k_norm.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn1.k_norm.bias")
82 |
83 | # feed-forward
84 | converted_state_dict[f"blocks.{depth}.mlp.fc1.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
85 | converted_state_dict[f"blocks.{depth}.mlp.fc1.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.0.proj.bias")
86 | converted_state_dict[f"blocks.{depth}.mlp.fc2.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.2.weight")
87 | converted_state_dict[f"blocks.{depth}.mlp.fc2.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.ff.net.2.bias")
88 |
89 | # cross-attention
90 | q = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_q.weight")
91 | q_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_q.bias")
92 | k = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_k.weight")
93 | k_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_k.bias")
94 | v = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_v.weight")
95 | v_bias = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_v.bias")
96 |
97 | converted_state_dict[f"blocks.{depth}.cross_attn.q_linear.weight"] = q
98 | converted_state_dict[f"blocks.{depth}.cross_attn.q_linear.bias"] = q_bias
99 | converted_state_dict[f"blocks.{depth}.cross_attn.kv_linear.weight"] = torch.cat((k, v))
100 | converted_state_dict[f"blocks.{depth}.cross_attn.kv_linear.bias"] = torch.cat((k_bias, v_bias))
101 |
102 | converted_state_dict[f"blocks.{depth}.cross_attn.proj.weight"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_out.0.weight")
103 | converted_state_dict[f"blocks.{depth}.cross_attn.proj.bias"] = transformer.get_tensor(f"transformer_blocks.{depth}.attn2.to_out.0.bias")
104 |
105 | # final block
106 | converted_state_dict["final_layer.linear.weight"] = transformer.get_tensor("proj_out.weight")
107 | converted_state_dict["final_layer.linear.bias"] = transformer.get_tensor("proj_out.bias")
108 | converted_state_dict["final_layer.scale_shift_table"] = transformer.get_tensor("scale_shift_table")
109 |
110 | # save the state_dict
111 | to_save = {}
112 | to_save['state_dict'] = converted_state_dict
113 | torch.save(to_save, output_path)
114 |
115 | if __name__ == '__main__':
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--safetensor_path', type=str, required=True, help='Path and filename of a safetensor file to convert. i.e. output/mymodel.safetensors')
118 | parser.add_argument('--pth_path', type=str, required=True, help='Path and filename to the output file i.e. output/mymodel.pth')
119 |
120 | args = parser.parse_args()
121 | main(args)
122 |
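123 | # Hypothetical usage sketch (file names are placeholders):
124 | #   python tools/convert_diffusers_to_pixart.py \
125 | #       --safetensor_path output/transformer/diffusion_pytorch_model.safetensors \
126 | #       --pth_path output/mymodel.pth
127 | # This reverses the diffusers conversion: the separate to_q/to_k/to_v projections are re-packed
128 | # into the fused `attn.qkv` and `cross_attn.kv_linear` tensors expected by PixArt .pth checkpoints.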
--------------------------------------------------------------------------------
/tools/convert_images_to_json.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | from os import walk
4 | from os import path
5 | from PIL import Image
6 | import json
7 | import tqdm
8 |
9 | def print_usage():
10 |     print('convert_images_to_json images_path [params] output_path')
11 |     print('  --caption_extension  caption file extension (default: .txt)')
12 |
13 | def main():
14 | args = sys.argv
15 | if len(args) < 3:
16 | print_usage()
17 | return
18 |
19 | input_folder = args[1]
20 | output_folder = args[-1]
21 |
22 | caption_extension = '.txt'
23 | try:
24 | caption_arg = args.index('--caption_extension')
25 | caption_extension = args[caption_arg + 1]
26 |     except (ValueError, IndexError):
27 | pass
28 |
29 | # create a folder with the output path
30 | output_folder = Path(output_folder)
31 | output_folder.mkdir(parents=True, exist_ok=True)
32 |
33 | # create a InternData and a InternImgs inside the output path
34 | intern_data_folder = output_folder.joinpath('InternData')
35 | intern_data_folder.mkdir(parents=True, exist_ok=True)
36 |
37 | intern_imgs_folder = output_folder.joinpath('InternImgs')
38 | intern_imgs_folder.mkdir(parents=True, exist_ok=True)
39 |
40 | # create a data_info.json inside InternData
41 | data_info_path = intern_data_folder.joinpath('data_info.json')
42 |
43 | # create a table which will contain all the entries
44 | json_entries = []
45 | with open(data_info_path, 'w') as json_file:
46 | for (dirpath, dirnames, filenames) in walk(input_folder):
47 | for filename in tqdm.tqdm(filenames):
48 |             if not filename.endswith(caption_extension):
49 | continue
50 |
51 | # check if an image exists for this caption
52 | image_filename = filename[:-len(caption_extension)]
53 |
54 |             for image_extension in ['.jpg', '.png', '.jpeg', '.webp', '.JPEG', '.JPG']:
55 | image_path = Path(dirpath).joinpath(image_filename + image_extension)
56 | if path.exists(image_path):
57 | write_entry(json_entries, dirpath, image_path, Path(dirpath).joinpath(filename), image_filename + image_extension, intern_imgs_folder)
58 | break
59 |
60 | # use the entries
61 | json_file.write(json.dumps(json_entries))
62 |
63 | def write_entry(json_entries, folder, image_path, caption_path, image_filename, intern_imgs_path):
64 | # open the file containing the prompt
65 | with open(caption_path) as prompt_file:
66 | prompt = prompt_file.read()
67 |
68 | # read the images info
69 | image = Image.open(image_path)
70 | image_width = image.width
71 | image_height = image.height
72 | ratio = image_height / image_width
73 |
74 | entry = {}
75 | entry['width'] = image_width
76 | entry['height'] = image_height
77 | entry['ratio'] = ratio
78 | entry['path'] = image_filename
79 | entry['prompt'] = prompt
80 | entry['sharegpt4v'] = ''
81 |
82 | json_entries.append(entry)
83 |
84 | # make sure to copy the image to the internimgs folder with the new filename!
85 | image_output_path = intern_imgs_path.joinpath(image_filename)
86 | image.save(image_output_path)
87 |
88 | if __name__ == '__main__':
89 | main()
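90 | 
91 | # Hypothetical usage sketch (paths are placeholders):
92 | #   python tools/convert_images_to_json.py /data/my_images --caption_extension .txt /data/pixart_dataset
93 | # This writes /data/pixart_dataset/InternData/data_info.json with one entry per caption/image pair, e.g.
94 | #   {"width": 1024, "height": 768, "ratio": 0.75, "path": "cat.png",
95 | #    "prompt": "a cat sitting on a sofa", "sharegpt4v": ""}
96 | # and copies the matching images into /data/pixart_dataset/InternImgs/.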
--------------------------------------------------------------------------------
/tools/download.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """
9 | Functions for downloading pre-trained PixArt models
10 | """
11 | from torchvision.datasets.utils import download_url
12 | import torch
13 | import os
14 | import argparse
15 |
16 |
17 | pretrained_models = {
18 | 'PixArt-Sigma-XL-2-512-MS.pth', 'PixArt-Sigma-XL-2-256x256.pth', 'PixArt-Sigma-XL-2-1024-MS.pth'
19 | }
20 |
21 |
22 | def find_model(model_name):
23 | """
24 |     Finds a pre-trained PixArt model, downloading it if necessary. Alternatively, loads a model from a local path.
25 | """
26 |     if model_name in pretrained_models:  # Find/download our pre-trained PixArt checkpoints
27 | return download_model(model_name)
28 | else: # Load a custom PixArt checkpoint:
29 | assert os.path.isfile(model_name), f'Could not find PixArt checkpoint at {model_name}'
30 | return torch.load(model_name, map_location=lambda storage, loc: storage)
31 |
32 |
33 | def download_model(model_name):
34 | """
35 | Downloads a pre-trained PixArt model from the web.
36 | """
37 | assert model_name in pretrained_models
38 | local_path = f'output/pretrained_models/{model_name}'
39 | if not os.path.isfile(local_path):
40 | hf_endpoint = os.environ.get("HF_ENDPOINT")
41 | if hf_endpoint is None:
42 | hf_endpoint = "https://huggingface.co"
43 | os.makedirs('output/pretrained_models', exist_ok=True)
44 | web_path = f'{hf_endpoint}/PixArt-alpha/PixArt-Sigma/resolve/main/{model_name}'
45 | download_url(web_path, 'output/pretrained_models/')
46 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
47 | return model
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument('--model_names', nargs='+', type=str, default=pretrained_models)
53 | args = parser.parse_args()
54 | model_names = args.model_names
55 | model_names = set(model_names)
56 |
57 | # Download PixArt checkpoints
58 | for model in model_names:
59 | download_model(model)
60 | print('Done.')
61 |
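62 | # Hypothetical usage sketch:
63 | #   python tools/download.py --model_names PixArt-Sigma-XL-2-512-MS.pth
64 | # Omitting --model_names downloads all three checkpoints listed in `pretrained_models`.
65 | # Files are cached under output/pretrained_models/; set the HF_ENDPOINT environment
66 | # variable to fetch from a Hugging Face mirror instead of huggingface.co.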
--------------------------------------------------------------------------------
/tools/generate_dmd_data_noise_pairs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import annotations
3 | import argparse
4 | import os
5 | import sys
6 | import json
7 | from pathlib import Path
8 |
9 | current_file_path = Path(__file__).resolve()
10 | sys.path.insert(0, str(current_file_path.parent.parent))
11 |
12 | import numpy as np
13 | import random
14 | import torch
15 | from diffusers import PixArtAlphaPipeline, Transformer2DModel
16 | from tqdm import tqdm
17 |
18 | MAX_SEED = np.iinfo(np.int32).max
19 |
20 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
21 | if randomize_seed:
22 | seed = random.randint(0, MAX_SEED)
23 | return seed
24 |
25 | @torch.no_grad()
26 | @torch.inference_mode()
27 | def generate(items):
28 |
29 | seed = int(randomize_seed_fn(0, randomize_seed=False))
30 | generator = torch.Generator().manual_seed(seed)
31 |
32 | for item in tqdm(items, "Generating: "):
33 | prompt = item['prompt']
34 | save_name = item['path'].split('.')[0]
35 |
36 | # noise
37 | latent_size = pipe.transformer.config.sample_size
38 | noise = torch.randn(
39 | (1, 4, latent_size, latent_size), generator=generator, dtype=torch.float32
40 | ).to(weight_dtype).to(device)
41 |
42 | # image
43 | img_latent = pipe(
44 | prompt=prompt,
45 | latents=noise.to(weight_dtype),
46 | generator=generator,
47 | output_type="latent",
48 | max_sequence_length=T5_token_max_length,
49 | ).images[0]
50 |
51 | # save noise-denoised latent features
52 | noise_save_path = os.path.join(noise_save_dir, f"{save_name}.npy")
53 | np.save(noise_save_path, noise[0].cpu().numpy())
54 | img_latent_save_path = os.path.join(img_latent_save_dir, f"{save_name}.npy")
55 | np.save(img_latent_save_path, img_latent.cpu().numpy())
56 |
57 | if args.save_img:
58 |             # img_latent has shape (4, h, w); add a batch dim before decoding
59 |             image = pipe.vae.decode(img_latent[None] / pipe.vae.config.scaling_factor, return_dict=False)[0]
60 |             image = pipe.image_processor.postprocess(image, output_type="pil")[0]  # postprocess returns a list of PIL images
61 |             image.save(os.path.join(img_save_dir, f"{save_name}.png"))
62 |
63 |
64 | def parse_args():
65 | parser = argparse.ArgumentParser(description="Process some integers.")
66 | parser.add_argument('--work-dir', help='the dir to save logs and models')
67 | parser.add_argument('--save_img', action='store_true', help='if save latents and images at the same time')
68 | parser.add_argument('--sample_nums', default=640_000, type=int, help='sample numbers')
69 | parser.add_argument('--T5_token_max_length', default=120, type=int, choices=[120, 300], help='T5 token length')
70 | parser.add_argument(
71 | '--model_path', default="PixArt-alpha/PixArt-XL-2-512x512", help='the dir to load a ckpt for teacher model')
72 | parser.add_argument(
73 | '--pipeline_load_from', default="PixArt-alpha/PixArt-XL-2-1024-MS", type=str,
74 | help="Download for loading text_encoder, "
75 | "tokenizer and vae from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS")
76 | args = parser.parse_args()
77 | return args
78 |
79 |
80 | # Use PixArt-Alpha to generate PixArt-Alpha-DMD training data (noise-image pairs).
81 | if __name__ == '__main__':
82 | args = parse_args()
83 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
84 |
85 | metadata_json_list = ["pixart-sigma-toy-dataset/InternData/data_info.json",]
86 | # dataset
87 | meta_data_clean = []
88 | for json_file in metadata_json_list:
89 | meta_data = json.load(open(json_file, 'r'))
90 | meta_data_clean.extend([item for item in meta_data if item['ratio'] <= 4.5])
91 |
92 | weight_dtype = torch.float16
93 | T5_token_max_length = args.T5_token_max_length
94 | if torch.cuda.is_available():
95 |
96 | # Teacher Model
97 | pipe = PixArtAlphaPipeline.from_pretrained(
98 | args.pipeline_load_from,
99 | transformer=None,
100 | torch_dtype=weight_dtype,
101 | )
102 | pipe.transformer = Transformer2DModel.from_pretrained(
103 | args.model_path, subfolder="transformer", torch_dtype=weight_dtype
104 | )
105 | pipe.to(device)
106 |
107 | print(f"INFO: Select only first {args.sample_nums} samples")
108 | meta_data_clean = meta_data_clean[:args.sample_nums]
109 |
110 | # save path
111 | if args.save_img:
112 | img_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_images')
113 | os.makedirs(img_save_dir, exist_ok=True)
114 | img_latent_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_latents')
115 | noise_save_dir = os.path.join(f'pixart-sigma-toy-dataset/InternData/InternImgs_DMD_noises')
116 | os.umask(0o000) # file permission: 666; dir permission: 777
117 | os.makedirs(img_latent_save_dir, exist_ok=True)
118 | os.makedirs(noise_save_dir, exist_ok=True)
119 |
120 | # generate
121 | generate(meta_data_clean)
122 |
123 |
124 |
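125 | # Hypothetical usage sketch (assumes the pixart-sigma-toy-dataset layout hard-coded above):
126 | #   python tools/generate_dmd_data_noise_pairs.py --sample_nums 1000 --save_img
127 | # For each prompt the teacher pipeline produces a matched pair of .npy files,
128 | #   InternData/InternImgs_DMD_noises/<name>.npy   (starting noise)
129 | #   InternData/InternImgs_DMD_latents/<name>.npy  (denoised latent)
130 | # i.e. the noise-image pairs used as PixArt-DMD training data.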
--------------------------------------------------------------------------------
/tools/merge_transformers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import annotations
3 | import argparse
4 | import os
5 | import sys
6 | from pathlib import Path
7 | import gc
8 |
9 | current_file_path = Path(__file__).resolve()
10 | sys.path.insert(0, str(current_file_path.parent.parent))
11 |
12 | import torch
13 | from transformers import T5EncoderModel, T5Tokenizer
14 | import pathlib
15 |
16 | from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel
17 | from scripts.diffusers_patches import pixart_sigma_init_patched_inputs
18 |
19 | interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
20 |
21 | def main(args):
22 | # load the first checkpoint
23 | repo_path_a = args.repo_path_a
24 | repo_path_b = args.repo_path_b
25 | output_folder = pathlib.Path(args.output_folder)
26 | ratio = args.ratio
27 |
28 | setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
29 | transformer_a = Transformer2DModel.from_pretrained(repo_path_a, subfolder='transformer')
30 | state_dict_a = transformer_a.state_dict()
31 |
32 | # load the second checkpoint
33 | transformer_b = Transformer2DModel.from_pretrained(repo_path_b, subfolder='transformer')
34 | state_dict_b = transformer_b.state_dict()
35 |
36 | new_state_dict = {}
37 |
38 |     for key, value_a in state_dict_a.items():
39 |         value_b = state_dict_b[key]
40 |         # linear interpolation of the two checkpoints: ratio=0 keeps model A, ratio=1 gives model B
41 |         new_val = torch.lerp(value_a, value_b, ratio)
42 |         new_state_dict[key] = new_val
43 |
44 | # delete the transformers to reduce RAM requirements
45 | del transformer_a
46 | del transformer_b
47 | del state_dict_a
48 | del state_dict_b
49 | gc.collect()
50 |
51 | # save the new transformer
52 | new_transformer = Transformer2DModel.from_pretrained(repo_path_a, subfolder='transformer')
53 | new_transformer.load_state_dict(new_state_dict)
54 | new_transformer.save_pretrained(output_folder)
55 |
56 | if __name__ == '__main__':
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--repo_path_a', required=True, type=str)
59 | parser.add_argument('--repo_path_b', required=True, type=str)
60 | parser.add_argument('--output_folder', required=True, type=str)
61 |     parser.add_argument('--ratio', required=True, type=float, help='interpolation weight: 0.0 keeps model A, 1.0 gives model B')
62 | parser.add_argument('--version', required=False, default='sigma', type=str)
63 |
64 | args = parser.parse_args()
65 | main(args)
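66 | 
67 | # Hypothetical usage sketch (paths are placeholders):
68 | #   python tools/merge_transformers.py \
69 | #       --repo_path_a output/pipeline_a --repo_path_b output/pipeline_b \
70 | #       --ratio 0.3 --output_folder output/merged_transformer
71 | # Every weight is torch.lerp(a, b, ratio): ratio=0.0 reproduces model A, ratio=1.0
72 | # reproduces model B, and values in between blend the two transformers.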
--------------------------------------------------------------------------------
/train_scripts/train_dmd.sh:
--------------------------------------------------------------------------------
1 | config=configs/pixart_app_config/PixArt-DMD_xl2_img512_internalms.py
2 | work_dir=output/debug/
3 |
4 | # machine
5 | machine_num=1
6 | np=8
7 |
8 | # training settings
9 | max_samples=500000
10 | batchsize=1
11 | mix_precision='no'
12 | use_dm=1
13 | use_regression=1
14 | regression_weight=0.25
15 | one_step_maxt=400
16 | learning_rate=1e-6
17 | lr_fake_multiplier=1.0
18 | max_grad_norm=10
19 | gradient_accumulation_steps=2
20 | save_image_interval=5000
21 |
22 | # resume from checkpoint
23 | resume_from=""
24 |
25 | # train
26 | python_command="python scripts/run_pixart_dmd.py --world_size 1 ${config} "
27 | python_command+="--is_debugging "
28 | python_command+="--work_dir=${work_dir} --machine_num=${machine_num} --np=${np} "
29 | python_command+="--max_samples=${max_samples} "
30 | python_command+="--batch_size=${batchsize} --mixed_precision=${mix_precision} "
31 | python_command+="--use_dm=${use_dm} "
32 | python_command+="--use_regression=${use_regression} --regression_weight=${regression_weight} "
33 | python_command+="--one_step_maxt=${one_step_maxt} "
34 | python_command+="--learning_rate=${learning_rate} "
35 | python_command+="--lr_fake_multiplier=${lr_fake_multiplier} "
36 | python_command+="--max_grad_norm=${max_grad_norm} "
37 | python_command+="--gradient_accumulation_steps=${gradient_accumulation_steps} "
38 | python_command+="--save_image_interval=${save_image_interval} "
39 |
40 | python_command+="--resume_from=${resume_from} "
41 |
42 | eval $python_command
--------------------------------------------------------------------------------
/train_scripts/train_pixart_lora.sh:
--------------------------------------------------------------------------------
1 | pip install -U peft
2 |
3 | dataset_id=svjack/pokemon-blip-captions-en-zh
4 | model_id=PixArt-alpha/PixArt-XL-2-512x512
5 |
6 | accelerate launch --num_processes=1 --main_process_port=36667 train_scripts/train_pixart_lora_hf.py \
7 | --mixed_precision="fp16" \
8 | --pretrained_model_name_or_path=$model_id \
9 | --dataset_name=$dataset_id \
10 | --caption_column="text" \
11 | --resolution=512 \
12 | --random_flip \
13 | --train_batch_size=16 \
14 | --num_train_epochs=80 \
15 | --checkpointing_steps=1000 \
16 | --learning_rate=1e-05 \
17 | --lr_scheduler="constant" \
18 | --lr_warmup_steps=0 \
19 | --seed=42 \
20 | --output_dir="output/pixart-pokemon-model" \
21 | --validation_prompt="cute dragon creature" \
22 | --report_to="tensorboard" \
23 | --gradient_checkpointing \
24 | --checkpoints_total_limit=10 \
25 | --validation_epochs=5 \
26 | --max_token_length=120 \
27 | --rank=16
--------------------------------------------------------------------------------