├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── flux_dev_config.json ├── flux_schnell_config.json ├── flux_vae │ └── config.json ├── gemma_2_2b │ ├── config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json ├── hy_vae_config.json └── t5_old │ ├── config.json │ ├── spiece.model │ └── tokenizer.json ├── docs └── supported_models.md ├── examples ├── cosmos_dataset.toml ├── dataset.toml ├── main_example.toml ├── recommended_lumina_dataset_config.toml └── wan_14b_min_vram.toml ├── models ├── base.py ├── chroma.py ├── cosmos.py ├── flux.py ├── hidream.py ├── hunyuan_video.py ├── ltx_video.py ├── lumina_2.py ├── sdxl.py └── wan.py ├── optimizers ├── adamw_8bit.py ├── automagic.py ├── gradient_release.py └── optimizer_utils.py ├── requirements.txt ├── tools ├── cosmos_vae_test.py ├── hunyuan_video_vae_test.py ├── image_resize_test.py └── wan_vae_test.py ├── train.py └── utils ├── common.py ├── dataset.py ├── isolate_rng.py ├── offloading.py ├── patches.py ├── pipeline.py ├── saver.py └── unsloth_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/HunyuanVideo"] 2 | path = submodules/HunyuanVideo 3 | url = https://github.com/Tencent/HunyuanVideo 4 | [submodule "submodules/Cosmos"] 5 | path = submodules/Cosmos 6 | url = https://github.com/NVIDIA/Cosmos 7 | [submodule "submodules/Lumina_2"] 8 | path = submodules/Lumina_2 9 | url = https://github.com/Alpha-VLLM/Lumina-Image-2.0 10 | [submodule "submodules/Wan2_1"] 11 | path = submodules/Wan2_1 12 | url = https://github.com/Wan-Video/Wan2.1 13 | [submodule "submodules/flow"] 14 | path = submodules/flow 15 | url = https://github.com/lodestone-rock/flow 16 | [submodule "submodules/HiDream"] 17 | path = submodules/HiDream 18 | url = https://github.com/HiDream-ai/HiDream-I1 19 | [submodule "submodules/LTX_Video"] 20 | path = submodules/LTX_Video 21 | url = https://github.com/Lightricks/LTX-Video 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 tdrussell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffusion-pipe
2 | A pipeline parallel training script for diffusion models.
3 | 
4 | Currently supports SDXL, Flux, LTX-Video, HunyuanVideo (t2v), Cosmos, Lumina Image 2.0, Wan2.1 (t2v and i2v), Chroma, HiDream.
5 | 
6 | **Work in progress.** This is a side project for me and my time is limited. I will try to add new models and features when I can.
7 | 
8 | ## Features
9 | - Pipeline parallelism, for training models larger than can fit on a single GPU
10 | - Useful metrics logged to Tensorboard
11 | - Compute metrics on a held-out eval set, for measuring generalization
12 | - Training state checkpointing and resuming from checkpoint
13 | - Efficient multi-process, multi-GPU pre-caching of latents and text embeddings
14 | - Seamlessly supports both image and video models in a unified way
15 | - Easily add new models by implementing a single subclass
16 | 
17 | ## Recent changes
18 | - 2025-05-22
19 |   - Add Automagic optimizer
20 |   - Support i2v training for LTX-Video. Thanks @GallenShao for the PR!
21 |   - Support multiple shuffling of tags when caching text embeddings. Credit to @gitmylo for the PR.
22 | - 2025-05-07
23 |   - Switch to official implementation of LTX-Video. Allows training the 13b LTX-Video model.
24 | - 2025-04-19
25 |   - Add support for first-frame-last-frame Wan model. Credit to @kabachuha for the PR.
26 |   - Add wandb support. Credit to @ecarmen16 for the PR.
27 | - 2025-04-18
28 |   - Fix block swapping for HiDream. With ```blocks_to_swap = 24``` you can train rank 32 LoRA on a single 4090.
29 |   - Support nf4 quantization for HiDream. With nf4 transformer, you can train LoRA on a single 4090 even without block swapping. See supported models doc for how to enable.
30 | - 2025-04-15
31 |   - Support HiDream.
32 | - 2025-03-18
33 |   - Add unsloth activation checkpointing. Reduces VRAM for a small performance hit.
34 |   - Add partition_split option for manually controlling how layers are divided across multiple GPUs. Thanks @arczewski for the PR!
35 | - 2025-03-16
36 |   - Support loading any optimizer from the pytorch-optimizer library.
37 |   - Wan transformer and UMT5 can now be loaded from ComfyUI files. Thanks to @qiwang1996 for the PR!
38 | - 2025-03-09
39 |   - Block swapping is supported for Wan, HunyuanVideo, Flux, and Chroma.
40 |   - Big thanks to @kohya-ss and [Musubi Tuner](https://github.com/kohya-ss/musubi-tuner) from which most of the implementation is taken.
41 |   - See the example hunyuan_video.toml file for how to configure.
42 |   - Reduced memory use of Wan by removing some forced casts to float32. I am able to measure a very small, but consistent increase in validation loss, so there is at least some decrease in quality. But the memory savings are large when training on videos, and it is likely worth it.
43 |   - On the 14B t2v model, by using fp8 transformer, AdamW8bitKahan optimizer, and offloading most of the blocks (e.g. blocks_to_swap=32), you can (just barely) train 512x512x81 sized videos on a single 4090.
44 | 
45 | ## Windows support
46 | It will be difficult or impossible to make training work on native Windows. This is because Deepspeed only has [partial Windows support](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/README.md). Deepspeed is a hard requirement because the entire training script is built around Deepspeed pipeline parallelism. However, it will work on Windows Subsystem for Linux, specifically WSL 2. If you must use Windows, I recommend trying WSL 2.
47 | 
48 | ## Installing
49 | Clone the repository:
50 | ```
51 | git clone --recurse-submodules https://github.com/tdrussell/diffusion-pipe
52 | ```
53 | 
54 | If you already cloned it and forgot to do --recurse-submodules:
55 | ```
56 | git submodule init
57 | git submodule update
58 | ```
59 | 
60 | Install Miniconda: https://docs.anaconda.com/miniconda/
61 | 
62 | Create the environment:
63 | ```
64 | conda create -n diffusion-pipe python=3.12
65 | conda activate diffusion-pipe
66 | ```
67 | 
68 | Install PyTorch first. As of this writing (May 5, 2025), you need the PyTorch 2.6.0 CUDA 12.4 version (or earlier) for flash attention to work:
69 | ```
70 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
71 | ```
72 | 
73 | Install nvcc: https://anaconda.org/nvidia/cuda-nvcc. Probably try to make it match the CUDA version of PyTorch.
74 | 
75 | Install the rest of the dependencies:
76 | ```
77 | pip install -r requirements.txt
78 | ```
79 | 
80 | ### Cosmos requirements
81 | NVIDIA Cosmos additionally requires TransformerEngine. This dependency isn't in the requirements file. Installing this was a bit tricky for me. On Ubuntu 24.04, I had to install GCC version 12 (13 is the default in the package manager), and make sure GCC 12 and CUDNN were set during installation like this:
82 | ```
83 | CC=/usr/bin/gcc-12 CUDNN_PATH=/home/anon/miniconda3/envs/diffusion-pipe/lib/python3.12/site-packages/nvidia/cudnn pip install transformer_engine[pytorch]
84 | ```
85 | 
86 | ## Dataset preparation
87 | A dataset consists of one or more directories containing image or video files, and corresponding captions. You can mix images and videos in the same directory, but it's probably a good idea to separate them in case you need to specify certain settings on a per-directory basis. Caption files should be .txt files with the same base name as the corresponding media file, e.g. image1.png should have caption file image1.txt in the same directory. If a media file doesn't have a matching caption file, a warning is printed, but training will proceed with an empty caption.
88 | 
89 | For images, any image format that can be loaded by Pillow should work. For videos, any format that can be loaded by ImageIO should work. Note that this means **WebP videos are not supported**, because ImageIO can't load multi-frame WebPs.
90 | 
91 | ## Supported models
92 | See the [supported models doc](./docs/supported_models.md) for more information on how to configure each model, the options it supports, and the format of the saved LoRAs.
93 | 
94 | ## Training
95 | **Start by reading through the config files in the examples directory.** Almost everything is commented, explaining what each setting does.
[This config file](./examples/main_example.toml) is the main example with all of the comments. [This dataset config file](./examples/dataset.toml) has the documentation for the dataset options. 96 | 97 | Once you've familiarized yourself with the config file format, go ahead and make a copy and edit to your liking. At minimum, change all the paths to conform to your setup, including the paths in the dataset config file. 98 | 99 | Launch training like this: 100 | ``` 101 | NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" deepspeed --num_gpus=1 train.py --deepspeed --config examples/hunyuan_video.toml 102 | ``` 103 | RTX 4000 series needs those 2 environment variables set. Other GPUs may not need them. You can try without them, Deepspeed will complain if it's wrong. 104 | 105 | If you enabled checkpointing, you can resume training from the latest checkpoint by simply re-running the exact same command but with the `--resume_from_checkpoint` flag. You can also specify a specific checkpoint folder name after the flag to resume from that particular checkpoint (e.g. `--resume_from_checkpoint "20250212_07-06-40"`). This option is particularly useful if you have run multiple training sessions with different datasets and want to resume from a specific training folder. 106 | 107 | Please note that resuming from checkpoint uses the **config file on the command line**, not the config file saved into the output directory. You are responsible for making sure that the config file you pass in matches what was previously used. 108 | 109 | ## Output files 110 | A new directory will be created in ```output_dir``` for each training run. This contains the checkpoints, saved models, and Tensorboard metrics. Saved models/LoRAs will be in directories named like epoch1, epoch2, etc. Deepspeed checkpoints are in directories named like global_step1234. These checkpoints contain all training state, including weights, optimizer, and dataloader state, but can't be used directly for inference. The saved model directory will have the safetensors weights, PEFT adapter config JSON, as well as the diffusion-pipe config file for easier tracking of training run settings. 111 | 112 | ## Reducing VRAM requirements 113 | The [wan_14b_min_vram.toml](./examples/wan_14b_min_vram.toml) example file has all of these settings enabled. 114 | - Use AdamW8BitKahan optimizer: 115 | ``` 116 | [optimizer] 117 | type = 'AdamW8bitKahan' 118 | lr = 5e-5 119 | betas = [0.9, 0.99] 120 | weight_decay = 0.01 121 | stabilize = false 122 | ``` 123 | - Use block swapping if the model supports it: ```blocks_to_swap = 32``` 124 | - Try the expandable_segments feature in the CUDA memory allocator: 125 | - ```PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" deepspeed --num_gpus=1 train.py --deepspeed --config /home/you/path/to/config.toml``` 126 | - I've seen this help a lot when training on video with multiple aspect ratio buckets. 127 | - On my system, sometimes this causes random CUDA failures. If training gets through a few steps though, it will train indefinitely without failures. Very weird. 128 | - Use unsloth activation checkpointing: ```activation_checkpointing = 'unsloth'``` 129 | 130 | ## Parallelism 131 | This code uses hybrid data- and pipeline-parallelism. Set the ```--num_gpus``` flag appropriately for your setup. Set ```pipeline_stages``` in the config file to control the degree of pipeline parallelism. 
Then the data parallelism degree will automatically be set to use all GPUs (number of GPUs must be divisible by pipeline_stages). For example, with 4 GPUs and pipeline_stages=2, you will run two instances of the model, each divided across two GPUs. 132 | 133 | ## Pre-caching 134 | Latents and text embeddings are cached to disk before training happens. This way, the VAE and text encoders don't need to be kept loaded during training. The Huggingface Datasets library is used for all the caching. Cache files are reused between training runs if they exist. All cache files are written into a directory named "cache" inside each dataset directory. 135 | 136 | This caching also means that training LoRAs for text encoders is not currently supported. 137 | 138 | Two flags are relevant for caching. ```--cache_only``` does the caching flow, then exits without training anything. ```--regenerate_cache``` forces cache regeneration. If you edit the dataset in-place (like changing a caption), you need to force regenerate the cache (or delete the cache dir) for the changes to be picked up. 139 | 140 | ## Extra 141 | You can check out my [qlora-pipe](https://github.com/tdrussell/qlora-pipe) project, which is basically the same thing as this but for LLMs. 142 | -------------------------------------------------------------------------------- /configs/flux_dev_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "FluxTransformer2DModel", 3 | "_diffusers_version": "0.30.0.dev0", 4 | "_name_or_path": "../checkpoints/flux-dev/transformer", 5 | "attention_head_dim": 128, 6 | "guidance_embeds": true, 7 | "in_channels": 64, 8 | "joint_attention_dim": 4096, 9 | "num_attention_heads": 24, 10 | "num_layers": 19, 11 | "num_single_layers": 38, 12 | "patch_size": 1, 13 | "pooled_projection_dim": 768 14 | } 15 | -------------------------------------------------------------------------------- /configs/flux_schnell_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "FluxTransformer2DModel", 3 | "_diffusers_version": "0.30.0.dev0", 4 | "attention_head_dim": 128, 5 | "guidance_embeds": false, 6 | "in_channels": 64, 7 | "joint_attention_dim": 4096, 8 | "num_attention_heads": 24, 9 | "num_layers": 19, 10 | "num_single_layers": 38, 11 | "patch_size": 1, 12 | "pooled_projection_dim": 768 13 | } -------------------------------------------------------------------------------- /configs/flux_vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.30.0.dev0", 4 | "_name_or_path": "../checkpoints/flux-dev", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "force_upcast": true, 19 | "in_channels": 3, 20 | "latent_channels": 16, 21 | "latents_mean": null, 22 | "latents_std": null, 23 | "layers_per_block": 2, 24 | "mid_block_add_attention": true, 25 | "norm_num_groups": 32, 26 | "out_channels": 3, 27 | "sample_size": 1024, 28 | "scaling_factor": 0.3611, 29 | "shift_factor": 0.1159, 30 | "up_block_types": [ 31 | "UpDecoderBlock2D", 32 | "UpDecoderBlock2D", 33 | "UpDecoderBlock2D", 34 | "UpDecoderBlock2D" 35 | ], 36 | "use_post_quant_conv": false, 37 | "use_quant_conv": false 38 | } 39 | 
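The VAE config above only defines the architecture; the weights are loaded separately at runtime. Purely as an illustration (this snippet is not part of the repository, the weight path is a placeholder, and it assumes the safetensors keys already match the diffusers layout), a config like this could be instantiated with the diffusers library as follows:

```python
# Illustrative sketch, not repository code.
import json

import safetensors.torch
from diffusers import AutoencoderKL

with open('configs/flux_vae/config.json') as f:
    vae_config = json.load(f)

# from_config builds the module with randomly initialized weights matching the architecture.
vae = AutoencoderKL.from_config(vae_config)

# Attach real weights from a separately downloaded file (placeholder path).
state_dict = safetensors.torch.load_file('/path/to/flux_vae.safetensors')
vae.load_state_dict(state_dict)
```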
-------------------------------------------------------------------------------- /configs/gemma_2_2b/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Gemma2ForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "attn_logit_softcapping": 50.0, 8 | "bos_token_id": 2, 9 | "cache_implementation": "hybrid", 10 | "eos_token_id": 1, 11 | "final_logit_softcapping": 30.0, 12 | "head_dim": 256, 13 | "hidden_act": "gelu_pytorch_tanh", 14 | "hidden_activation": "gelu_pytorch_tanh", 15 | "hidden_size": 2304, 16 | "initializer_range": 0.02, 17 | "intermediate_size": 9216, 18 | "max_position_embeddings": 8192, 19 | "model_type": "gemma2", 20 | "num_attention_heads": 8, 21 | "num_hidden_layers": 26, 22 | "num_key_value_heads": 4, 23 | "pad_token_id": 0, 24 | "query_pre_attn_scalar": 256, 25 | "rms_norm_eps": 1e-06, 26 | "rope_theta": 10000.0, 27 | "sliding_window": 4096, 28 | "torch_dtype": "float32", 29 | "transformers_version": "4.42.4", 30 | "use_cache": true, 31 | "vocab_size": 256000 32 | } 33 | -------------------------------------------------------------------------------- /configs/gemma_2_2b/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "", 4 | "" 5 | ], 6 | "bos_token": { 7 | "content": "", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false 12 | }, 13 | "eos_token": { 14 | "content": "", 15 | "lstrip": false, 16 | "normalized": false, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "pad_token": { 21 | "content": "", 22 | "lstrip": false, 23 | "normalized": false, 24 | "rstrip": false, 25 | "single_word": false 26 | }, 27 | "unk_token": { 28 | "content": "", 29 | "lstrip": false, 30 | "normalized": false, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/gemma_2_2b/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdrussell/diffusion-pipe/7fd5d8c361f64714ac63462b1f28148abad199b7/configs/gemma_2_2b/tokenizer.model -------------------------------------------------------------------------------- /configs/hy_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKLCausal3D", 3 | "_diffusers_version": "0.4.2", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlockCausal3D", 13 | "DownEncoderBlockCausal3D", 14 | "DownEncoderBlockCausal3D", 15 | "DownEncoderBlockCausal3D" 16 | ], 17 | "in_channels": 3, 18 | "latent_channels": 16, 19 | "layers_per_block": 2, 20 | "norm_num_groups": 32, 21 | "out_channels": 3, 22 | "sample_size": 256, 23 | "sample_tsize": 64, 24 | "up_block_types": [ 25 | "UpDecoderBlockCausal3D", 26 | "UpDecoderBlockCausal3D", 27 | "UpDecoderBlockCausal3D", 28 | "UpDecoderBlockCausal3D" 29 | ], 30 | "scaling_factor": 0.476986, 31 | "time_compression_ratio": 4, 32 | "mid_block_add_attention": true 33 | } 34 | -------------------------------------------------------------------------------- /configs/t5_old/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 65536, 6 
| "d_kv": 128, 7 | "d_model": 1024, 8 | "decoder_start_token_id": 0, 9 | "dropout_rate": 0.1, 10 | "eos_token_id": 1, 11 | "initializer_factor": 1.0, 12 | "is_encoder_decoder": true, 13 | "layer_norm_epsilon": 1e-06, 14 | "model_type": "t5", 15 | "n_positions": 512, 16 | "num_heads": 128, 17 | "num_layers": 24, 18 | "output_past": true, 19 | "pad_token_id": 0, 20 | "relative_attention_num_buckets": 32, 21 | "task_specific_params": { 22 | "summarization": { 23 | "early_stopping": true, 24 | "length_penalty": 2.0, 25 | "max_length": 200, 26 | "min_length": 30, 27 | "no_repeat_ngram_size": 3, 28 | "num_beams": 4, 29 | "prefix": "summarize: " 30 | }, 31 | "translation_en_to_de": { 32 | "early_stopping": true, 33 | "max_length": 300, 34 | "num_beams": 4, 35 | "prefix": "translate English to German: " 36 | }, 37 | "translation_en_to_fr": { 38 | "early_stopping": true, 39 | "max_length": 300, 40 | "num_beams": 4, 41 | "prefix": "translate English to French: " 42 | }, 43 | "translation_en_to_ro": { 44 | "early_stopping": true, 45 | "max_length": 300, 46 | "num_beams": 4, 47 | "prefix": "translate English to Romanian: " 48 | } 49 | }, 50 | "vocab_size": 32128 51 | } 52 | -------------------------------------------------------------------------------- /configs/t5_old/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tdrussell/diffusion-pipe/7fd5d8c361f64714ac63462b1f28148abad199b7/configs/t5_old/spiece.model -------------------------------------------------------------------------------- /docs/supported_models.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | | Model | LoRA | Full Fine Tune | fp8/quantization | 4 | |----------------|------|----------------|------------------| 5 | |SDXL |✅ |✅ |❌ | 6 | |Flux |✅ |✅ |✅ | 7 | |LTX-Video |✅ |❌ |❌ | 8 | |HunyuanVideo |✅ |❌ |✅ | 9 | |Cosmos |✅ |❌ |❌ | 10 | |Lumina Image 2.0|✅ |✅ |❌ | 11 | |Wan2.1 |✅ |❌ |✅ | 12 | |Chroma |✅ |✅ |✅ | 13 | |HiDream |✅ |❌ |✅ | 14 | 15 | 16 | ## SDXL 17 | ``` 18 | [model] 19 | type = 'sdxl' 20 | checkpoint_path = '/data2/imagegen_models/sdxl/sd_xl_base_1.0_0.9vae.safetensors' 21 | dtype = 'bfloat16' 22 | # You can train v-prediction models (e.g. NoobAI vpred) by setting this option. 23 | #v_pred = true 24 | # Min SNR is supported. Same meaning as sd-scripts 25 | #min_snr_gamma = 5 26 | # Debiased estimation loss is supported. Same meaning as sd-scripts. 27 | #debiased_estimation_loss = true 28 | # You can set separate learning rates for unet and text encoders. If one of these isn't set, the optimizer learning rate will apply. 29 | unet_lr = 4e-5 30 | text_encoder_1_lr = 2e-5 31 | text_encoder_2_lr = 2e-5 32 | ``` 33 | Unlike other models, for SDXL the text embeddings are not cached, and the text encoders are trained. 34 | 35 | SDXL can be full fine tuned. Just remove the [adapter] table in the config file. You will need 48GB VRAM. 2x24GB GPUs works with pipeline_stages=2. 36 | 37 | SDXL LoRAs are saved in Kohya sd-scripts format. SDXL full fine tune models are saved in the original SDXL checkpoint format. 38 | 39 | ## Flux 40 | ``` 41 | [model] 42 | type = 'flux' 43 | # Path to Huggingface Diffusers directory for Flux 44 | diffusers_path = '/data2/imagegen_models/FLUX.1-dev' 45 | # You can override the transformer from a BFL format checkpoint. 
46 | #transformer_path = '/data2/imagegen_models/flux-dev-single-files/consolidated_s6700-schnell.safetensors' 47 | dtype = 'bfloat16' 48 | # Flux supports fp8 for the transformer when training LoRA. 49 | transformer_dtype = 'float8' 50 | # Resolution-dependent timestep shift towards more noise. Same meaning as sd-scripts. 51 | flux_shift = true 52 | # For FLEX.1-alpha, you can bypass the guidance embedding which is the recommended way to train that model. 53 | #bypass_guidance_embedding = true 54 | ``` 55 | For Flux, you can override the transformer weights by setting transformer_path to an original Black Forest Labs (BFL) format checkpoint. For example, the above config loads the model from Diffusers format FLUX.1-dev, but the transformer_path, if uncommented, loads the transformer from Flux Dev De-distill. 56 | 57 | Flux LoRAs are saved in Diffusers format. 58 | 59 | ## LTX-Video 60 | ``` 61 | [model] 62 | type = 'ltx-video' 63 | diffusers_path = '/data2/imagegen_models/LTX-Video' 64 | # Point this to one of the single checkpoint files to load the transformer and VAE from it. 65 | single_file_path = '/data2/imagegen_models/LTX-Video/ltx-video-2b-v0.9.1.safetensors' 66 | dtype = 'bfloat16' 67 | # Can load the transformer in fp8. 68 | #transformer_dtype = 'float8' 69 | timestep_sample_method = 'logit_normal' 70 | # Probability to use the first video frame as conditioning (i.e. i2v training). 71 | #first_frame_conditioning_p = 1.0 72 | ``` 73 | You can train the more recent LTX-Video versions by using single_file_path. Note that you will still need to set diffusers_path to the original model folder (it gets the text encoder from here). Only t2i and t2v training is supported. 74 | 75 | LTX-Video LoRAs are saved in ComfyUI format. 76 | 77 | ## HunyuanVideo 78 | ``` 79 | [model] 80 | type = 'hunyuan-video' 81 | # Can load Hunyuan Video entirely from the ckpt path set up for the official inference scripts. 82 | #ckpt_path = '/home/anon/HunyuanVideo/ckpts' 83 | # Or you can load it by pointing to all the ComfyUI files. 84 | transformer_path = '/data2/imagegen_models/hunyuan_video_comfyui/hunyuan_video_720_cfgdistill_fp8_e4m3fn.safetensors' 85 | vae_path = '/data2/imagegen_models/hunyuan_video_comfyui/hunyuan_video_vae_bf16.safetensors' 86 | llm_path = '/data2/imagegen_models/hunyuan_video_comfyui/llava-llama-3-8b-text-encoder-tokenizer' 87 | clip_path = '/data2/imagegen_models/hunyuan_video_comfyui/clip-vit-large-patch14' 88 | # Base dtype used for all models. 89 | dtype = 'bfloat16' 90 | # Hunyuan Video supports fp8 for the transformer when training LoRA. 91 | transformer_dtype = 'float8' 92 | # How to sample timesteps to train on. Can be logit_normal or uniform. 93 | timestep_sample_method = 'logit_normal' 94 | ``` 95 | HunyuanVideo LoRAs are saved in a Diffusers-style format. The keys are named according to the original model, and prefixed with "transformer.". This format will directly work with ComfyUI. 96 | 97 | ## Cosmos 98 | ``` 99 | [model] 100 | type = 'cosmos' 101 | # Point these paths at the ComfyUI files. 102 | transformer_path = '/data2/imagegen_models/cosmos/cosmos-1.0-diffusion-7b-text2world.pt' 103 | vae_path = '/data2/imagegen_models/cosmos/cosmos_cv8x8x8_1.0.safetensors' 104 | text_encoder_path = '/data2/imagegen_models/cosmos/oldt5_xxl_fp16.safetensors' 105 | dtype = 'bfloat16' 106 | ``` 107 | Tentative support is added for Cosmos (text2world diffusion variants). Compared to HunyuanVideo, Cosmos is not good for fine-tuning on commodity hardware. 108 | 109 | 1. 
Cosmos supports a fixed, limited set of resolutions and frame lengths. Because of this, the 7b model is actually slower to train than HunyuanVideo (12b parameters), because you can't get away with training on lower-resolution images like you can with Hunyuan. And video training is nearly impossible unless you have enormous amounts of VRAM, because for videos you must use the full 121 frame length. 110 | 2. Cosmos seems much worse at generalizing from image-only training to video. 111 | 3. The Cosmos base model is much more limited in the types of content that it knows, which makes fine tuning for most concepts more difficult. 112 | 113 | I will likely not be actively supporting Cosmos going forward. All the pieces are there, and if you really want to try training it you can. But don't expect me to spend time trying to fix things if something doesn't work right. 114 | 115 | Cosmos LoRAs are saved in ComfyUI format. 116 | 117 | ## Lumina Image 2.0 118 | ``` 119 | [model] 120 | type = 'lumina_2' 121 | # Point these paths at the ComfyUI files. 122 | transformer_path = '/data2/imagegen_models/lumina-2-single-files/lumina_2_model_bf16.safetensors' 123 | llm_path = '/data2/imagegen_models/lumina-2-single-files/gemma_2_2b_fp16.safetensors' 124 | vae_path = '/data2/imagegen_models/lumina-2-single-files/flux_vae.safetensors' 125 | dtype = 'bfloat16' 126 | lumina_shift = true 127 | ``` 128 | See the [Lumina 2 example dataset config](../examples/recommended_lumina_dataset_config.toml) which shows how to add a caption prefix and contains the recommended resolution settings. 129 | 130 | In addition to LoRA, Lumina 2 supports full fine tuning. It can be fine tuned at 1024x1024 resolution on a single 24GB GPU. For FFT, delete or comment out the [adapter] block in the config. If doing FFT with 24GB VRAM, you will need to use an alternative optimizer to lower VRAM use: 131 | ``` 132 | [optimizer] 133 | type = 'adamw8bitkahan' 134 | lr = 5e-6 135 | betas = [0.9, 0.99] 136 | weight_decay = 0.01 137 | eps = 1e-8 138 | gradient_release = true 139 | ``` 140 | 141 | This uses a custom AdamW8bit optimizer with Kahan summation (required for proper bf16 training), and it enables an experimental gradient release for more VRAM saving. If you are training only at 512 resolution, you can remove the gradient release part. If you have a >24GB GPU, or multiple GPUs and use pipeline parallelism, you can perhaps just use the normal adamw_optimi optimizer type. 142 | 143 | Lumina 2 LoRAs are saved in ComfyUI format. 144 | 145 | ## Wan2.1 146 | ``` 147 | [model] 148 | type = 'wan' 149 | ckpt_path = '/data2/imagegen_models/Wan2.1-T2V-1.3B' 150 | dtype = 'bfloat16' 151 | # You can use fp8 for the transformer when training LoRA. 152 | #transformer_dtype = 'float8' 153 | timestep_sample_method = 'logit_normal' 154 | ``` 155 | 156 | Both t2v and i2v Wan2.1 variants are supported. Set ckpt_path to the original model checkpoint directory, e.g. [Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B). 157 | 158 | (Optional) You may skip downloading the transformer and UMT5 text encoder from the original checkpoint, and instead pass in paths to the ComfyUI safetensors files instead. 
159 | 
160 | Download checkpoint but skip the transformer and UMT5:
161 | ```
162 | huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir Wan2.1-T2V-1.3B --exclude "diffusion_pytorch_model*" "models_t5*"
163 | ```
164 | 
165 | Then use this config:
166 | ```
167 | [model]
168 | type = 'wan'
169 | ckpt_path = '/data2/imagegen_models/Wan2.1-T2V-1.3B'
170 | transformer_path = '/data2/imagegen_models/wan_comfyui/wan2.1_t2v_1.3B_bf16.safetensors'
171 | llm_path = '/data2/imagegen_models/wan_comfyui/wrapper/umt5-xxl-enc-bf16.safetensors'
172 | dtype = 'bfloat16'
173 | # You can use fp8 for the transformer when training LoRA.
174 | #transformer_dtype = 'float8'
175 | timestep_sample_method = 'logit_normal'
176 | ```
177 | You still need ckpt_path; it's just that it can be missing the transformer files and/or UMT5. The transformer/UMT5 can be loaded from the native ComfyUI repackaged file, or from the file for Kijai's wrapper extension. You can also mix and match components, for example using the transformer from the ComfyUI repackaged repository alongside the UMT5 safetensors from Kijai's wrapper repository.
178 | 
179 | For i2v training, you **MUST** train on a dataset of only videos. The training script will crash with an error otherwise. The first frame of each video clip is used as the image conditioning, and the model is trained to predict the rest of the video. Please pay attention to the video_clip_mode setting. It defaults to 'single_beginning' if unset, which is reasonable for i2v training, but if you set it to something else during t2v training it may not be what you want for i2v. Only the 14B model has an i2v variant, and it requires training on videos, so VRAM requirements are high. Use block swapping as needed if you don't have enough VRAM.
180 | 
181 | Wan2.1 LoRAs are saved in ComfyUI format.
182 | 
183 | ## Chroma
184 | ```
185 | [model]
186 | type = 'chroma'
187 | diffusers_path = '/data2/imagegen_models/FLUX.1-dev'
188 | transformer_path = '/data2/imagegen_models/chroma/chroma-unlocked-v10.safetensors'
189 | dtype = 'bfloat16'
190 | # You can optionally load the transformer in fp8 when training LoRAs.
191 | transformer_dtype = 'float8'
192 | flux_shift = true
193 | ```
194 | Chroma is a model that is architecturally modified and finetuned from Flux Schnell. The modifications are significant enough that it has its own model type. Set transformer_path to the Chroma single model file, and set diffusers_path to either the Flux Dev or Schnell Diffusers folder (the Diffusers model is needed for loading the VAE and text encoder).
195 | 
196 | Chroma LoRAs are saved in ComfyUI format.
197 | 
198 | ## HiDream
199 | ```
200 | [model]
201 | type = 'hidream'
202 | diffusers_path = '/data/imagegen_models/HiDream-I1-Full'
203 | llama3_path = '/data2/models/Meta-Llama-3.1-8B-Instruct'
204 | llama3_4bit = true
205 | dtype = 'bfloat16'
206 | transformer_dtype = 'float8'
207 | # Can use nf4 quantization for even more VRAM saving.
208 | #transformer_dtype = 'nf4'
209 | max_llama3_sequence_length = 128
210 | # Can use a resolution-dependent timestep shift, like Flux. Unsure if results are better.
211 | #flux_shift = true
212 | ```
213 | 
214 | Only the Full version is tested. Dev and Fast likely will not work properly due to being distilled, and because you can't set the guidance value.
215 | 
216 | **HiDream doesn't perform well at resolutions under 1024**.
The model uses the same training objective and VAE as Flux, so the loss values are directly comparable between the two. When I compare with Flux, there is moderate degradation in the loss value at 768 resolution. There is severe degradation in the loss value at 512 resolution, and inference at 512 produces completely fried images. 217 | 218 | The official inference code uses a max sequence length of 128 for all text encoders. You can change the sequence length of llama3 (which carries almost all the weight) by changing max_llama3_sequence_length. A value of 256 causes a slight increase in stabilized validation loss of the model before any training happens, so there is some quality degradation. If you have many captions longer than 128 tokens, it may be worth increasing this value, but this is untested. I would not increase it beyond 256. 219 | 220 | Due to how the Llama3 text embeddings are computed, the Llama3 text encoder must be kept loaded and its embeddings computed during training, rather than being pre-cached. Otherwise the cache would use an enormous amount of space on disk. This increases memory use, but you can have Llama3 in 4bit with essentially 0 measurable effect on validation loss. 221 | 222 | Without block swapping, you will need 48GB VRAM, or 2x24GB with pipeline parallelism. With enough block swapping you can train on a single 24GB GPU. Using nf4 quantization also allows training with 24GB, but there may be some quality decrease. 223 | 224 | HiDream LoRAs are saved in ComfyUI format. -------------------------------------------------------------------------------- /examples/cosmos_dataset.toml: -------------------------------------------------------------------------------- 1 | # Cosmos will only work properly with a fixed set of resolutions and frame lengths. Rather than setting up flexible aspect ratio buckets 2 | # and target resolutions, for this model you must directly specify the desired final size_buckets. This is for image training. For video 3 | # training, it seems you must use the full 121 frame length at these resolutions, which I can't even test because it would take too much 4 | # VRAM. 5 | size_buckets = [ 6 | [960, 960, 1], 7 | [960, 704, 1], 8 | [704, 960, 1], 9 | [1280, 704, 1], 10 | [704, 1280, 1], 11 | ] 12 | 13 | [[directory]] 14 | path = '/home/anon/data/images/grayscale' 15 | num_repeats = 10 -------------------------------------------------------------------------------- /examples/dataset.toml: -------------------------------------------------------------------------------- 1 | # Resolutions to train on, given as the side length of a square image. You can have multiple sizes here. 2 | # !!!WARNING!!!: this might work differently to how you think it does. Images are first grouped to aspect ratio 3 | # buckets, then each image is resized to ALL of the areas specified by the resolutions list. This is a way to do 4 | # multi-resolution training, i.e. training on multiple total pixel areas at once. Your dataset is effectively duplicated 5 | # as many times as the length of this list. 6 | # If you just want to use predetermined (width, height, frames) size buckets, see the example cosmos_dataset.toml 7 | # file for how you can do that. 8 | resolutions = [512] 9 | 10 | # You can give resolutions as (width, height) pairs also. This doesn't do anything different, it's just 11 | # another way of specifying the area(s) (i.e. total number of pixels) you want to train on. 12 | # resolutions = [[1280, 720]] 13 | 14 | # Enable aspect ratio bucketing. 
For the different AR buckets, the final size will be such that 15 | # the areas match the resolutions you configured above. 16 | enable_ar_bucket = true 17 | 18 | # The aspect ratio and frame bucket settings may be specified for each [[directory]] entry as well. 19 | # Directory-level settings will override top-level settings. 20 | 21 | # Min and max aspect ratios, given as width/height ratio. 22 | min_ar = 0.5 23 | max_ar = 2.0 24 | # Total number of aspect ratio buckets, evenly spaced (in log space) between min_ar and max_ar. 25 | num_ar_buckets = 7 26 | 27 | # Can manually specify ar_buckets instead of using the range-style config above. 28 | # Each entry can be width/height ratio, or (width, height) pair. But you can't mix them, because of TOML. 29 | # ar_buckets = [[512, 512], [448, 576]] 30 | # ar_buckets = [1.0, 1.5] 31 | 32 | # For video training, you need to configure frame buckets (similar to aspect ratio buckets). There will always 33 | # be a frame bucket of 1 for images. Videos will be assigned to the longest frame bucket possible, such that the video 34 | # is still greater than or equal to the frame bucket length. 35 | # But videos are never assigned to the image frame bucket (1); if the video is very short it would just be dropped. 36 | frame_buckets = [1, 33] 37 | # If you have >24GB VRAM, or multiple GPUs and use pipeline parallelism, or lower the spatial resolution, you could maybe train with longer frame buckets 38 | # frame_buckets = [1, 33, 65, 97] 39 | 40 | # Shuffle tags before caching, 0 to keep original caption (unless shuffle_tags is set to true). Increases caching time a tiny bit for higher values. 41 | # cache_shuffle_num = 10 42 | # Delimiter for tags, only used if cache_shuffle_num is not 0. Defaults to ", ". 43 | # "tag1, tag2, tag3" has ", " as delimiter and will possibly be shuffled like "tag3, tag1, tag2". "tag1;tag2;tag3" has ";" as delimiter and will possibly be shuffled like "tag2;tag1;tag3". 44 | # cache_shuffle_delimiter = ", " 45 | 46 | [[directory]] 47 | # Path to directory of images/videos, and corresponding caption files. The caption files should match the media file name, but with a .txt extension. 48 | # A missing caption file will log a warning, but then just train using an empty caption. 49 | path = '/home/anon/data/images/grayscale' 50 | 51 | # You can do masked training, where the mask indicates which parts of the image to train on. The masking is done in the loss function. The mask directory should have mask 52 | # images with the same names (ignoring the extension) as the training images. E.g. training image 1.jpg could have mask image 1.jpg, 1.png, etc. If a training image doesn't 53 | # have a corresponding mask, a warning is printed but training proceeds with no mask for that image. In the mask, white means train on this, black means mask it out. Values 54 | # in between black and white become a weight between 0 and 1, i.e. you can use a suitable value of grey for mask weight of 0.5. In actuality, only the R channel is extracted 55 | # and converted to the mask weight. 56 | # The mask_path can point to any directory containing mask images. 57 | #mask_path = '/home/anon/data/images/grayscale/masks' 58 | 59 | # How many repeats for 1 epoch. The dataset will act like it is duplicated this many times. 60 | # The semantics of this are the same as sd-scripts: num_repeats=1 means one epoch is a single pass over all examples (no duplication). 
61 | num_repeats = 1
62 | 
63 | # Example of overriding some settings, and using ar_buckets to directly specify ARs.
64 | # ar_buckets = [[448, 576]]
65 | # resolutions = [[448, 576]]
66 | # frame_buckets = [1]
67 | 
68 | 
69 | # You can list multiple directories.
70 | 
71 | # [[directory]]
72 | # path = '/home/anon/data/images/something_else'
73 | # num_repeats = 5
74 | 
--------------------------------------------------------------------------------
/examples/main_example.toml:
--------------------------------------------------------------------------------
1 | # Output path for training runs. Each training run makes a new directory in here.
2 | output_dir = '/data/diffusion_pipe_training_runs/hunyuan_video_test'
3 | 
4 | # Dataset config file.
5 | dataset = 'examples/dataset.toml'
6 | # You can have separate eval datasets. Give them a name for Tensorboard metrics.
7 | # eval_datasets = [
8 | #   {name = 'something', config = 'path/to/eval_dataset.toml'},
9 | # ]
10 | 
11 | # training settings
12 | 
13 | # I usually set this to a really high value because I don't know how long I want to train.
14 | epochs = 1000
15 | # Batch size of a single forward/backward pass for one GPU.
16 | micro_batch_size_per_gpu = 1
17 | # For mixed video / image training, you can have a different batch size for images.
18 | #image_micro_batch_size_per_gpu = 4
19 | # Pipeline parallelism degree. A single instance of the model is divided across this many GPUs.
20 | pipeline_stages = 1
21 | # Number of micro-batches sent through the pipeline for each training step.
22 | # If pipeline_stages > 1, a higher GAS means better GPU utilization due to smaller pipeline bubbles (where GPUs aren't overlapping computation).
23 | gradient_accumulation_steps = 1
24 | # Grad norm clipping.
25 | gradient_clipping = 1.0
26 | # Learning rate warmup.
27 | warmup_steps = 100
28 | # Force the learning rate to be this value, regardless of what the optimizer or anything else says.
29 | # Can be used to change learning rate even when resuming from checkpoint.
30 | #force_constant_lr = 1e-5
31 | 
32 | # Block swapping is supported for Wan, HunyuanVideo, Flux, and Chroma. This value controls the number
33 | # of blocks kept offloaded to RAM. Increasing it lowers VRAM use, but has a performance penalty. The
34 | # exact performance penalty depends on the model and the type of training you are doing (e.g. images vs video).
35 | # Block swapping only works for LoRA training, and requires pipeline_stages=1.
36 | #blocks_to_swap = 20
37 | 
38 | # eval settings
39 | 
40 | eval_every_n_epochs = 1
41 | eval_before_first_step = true
42 | # Might want to set these lower for eval so that fewer images get dropped (eval dataset size is usually much smaller than training set).
43 | # Each size bucket of images/videos is rounded down to the nearest multiple of the global batch size, so higher global batch size means
44 | # more dropped images. Usually doesn't matter for training but the eval set is much smaller so it can matter.
45 | eval_micro_batch_size_per_gpu = 1
46 | # Batch size for images when doing mixed image / video training. Will be micro_batch_size_per_gpu if not set.
47 | #image_eval_micro_batch_size_per_gpu = 4
48 | eval_gradient_accumulation_steps = 1
49 | # If using block swap, you can disable it for eval. Eval uses less memory, so depending on block swapping amount you can maybe get away with
50 | # doing this, and then eval is much faster.
51 | #disable_block_swap_for_eval = true 52 | 53 | # misc settings 54 | 55 | # Probably want to set this a bit higher if you have a smaller dataset so you don't end up with a million saved models. 56 | save_every_n_epochs = 2 57 | # Can checkpoint the training state every n number of epochs or minutes. Set only one of these. You can resume from checkpoints using the --resume_from_checkpoint flag. 58 | #checkpoint_every_n_epochs = 1 59 | checkpoint_every_n_minutes = 120 60 | # Always set to true unless you have a huge amount of VRAM. 61 | # This can also be 'unsloth' to reduce VRAM even more, with a slight performance hit. 62 | activation_checkpointing = true 63 | 64 | # Controls how Deepspeed decides how to divide layers across GPUs. Probably don't change this. 65 | partition_method = 'parameters' 66 | # Alternatively you can use 'manual' in combination with partition_split, which specifies the split points for dividing 67 | # layers between GPUs. For example, with two GPUs, partition_split=[10] puts layers 0-9 on GPU 0, and the rest on GPU 1. 68 | # With three GPUs, partition_split=[10, 20] puts layers 0-9 on GPU 0, layers 10-19 on GPU 1, and the rest on GPU 2. 69 | # Length of partition_split must be pipeline_stages-1. 70 | #partition_split = [N] 71 | 72 | # dtype for saving the LoRA or model, if different from training dtype 73 | save_dtype = 'bfloat16' 74 | # Batch size for caching latents and text embeddings. Increasing can lead to higher GPU utilization during caching phase but uses more memory. 75 | caching_batch_size = 1 76 | # How often deepspeed logs to console. 77 | steps_per_print = 1 78 | # How to extract video clips for training from a single input video file. 79 | # The video file is first assigned to one of the configured frame buckets, but then we must extract one or more clips of exactly the right 80 | # number of frames for that bucket. 81 | # single_beginning: one clip starting at the beginning of the video 82 | # single_middle: one clip from the middle of the video (cutting off the start and end equally) 83 | # multiple_overlapping: extract the minimum number of clips to cover the full range of the video. They might overlap some. 84 | # default is single_beginning 85 | video_clip_mode = 'single_beginning' 86 | 87 | # This is how you configure HunyuanVideo. Other models will be different. See docs/supported_models.md for 88 | # details on the configuration and options for each model. 89 | [model] 90 | type = 'hunyuan-video' 91 | # Can load HunyuanVideo entirely from the ckpt path set up for the official inference scripts. 92 | #ckpt_path = '/home/anon/HunyuanVideo/ckpts' 93 | # Or you can load it by pointing to all the ComfyUI files. 94 | transformer_path = '/data2/imagegen_models/hunyuan_video_comfyui/hunyuan_video_720_cfgdistill_fp8_e4m3fn.safetensors' 95 | vae_path = '/data2/imagegen_models/hunyuan_video_comfyui/hunyuan_video_vae_bf16.safetensors' 96 | llm_path = '/data2/imagegen_models/hunyuan_video_comfyui/llava-llama-3-8b-text-encoder-tokenizer' 97 | clip_path = '/data2/imagegen_models/hunyuan_video_comfyui/clip-vit-large-patch14' 98 | # Base dtype used for all models. 99 | dtype = 'bfloat16' 100 | # Hunyuan Video supports fp8 for the transformer when training LoRA. 101 | transformer_dtype = 'float8' 102 | # How to sample timesteps to train on. Can be logit_normal or uniform. 103 | timestep_sample_method = 'logit_normal' 104 | 105 | # For models that support full fine tuning, simply delete or comment out the [adapter] table to FFT. 
106 | [adapter] 107 | type = 'lora' 108 | rank = 32 109 | # Dtype for the LoRA weights you are training. 110 | dtype = 'bfloat16' 111 | # You can initialize the lora weights from a previously trained lora. 112 | #init_from_existing = '/data/diffusion_pipe_training_runs/something/epoch50' 113 | # Experimental. Can fuse LoRAs into the base weights before training. Right now only for Flux. 114 | #fuse_adapters = [ 115 | # {path = '/data2/imagegen_models/loras/some_lora.safetensors', weight = 1.0} 116 | #] 117 | 118 | [optimizer] 119 | # AdamW from the optimi library is a good default since it automatically uses Kahan summation when training bfloat16 weights. 120 | # Look at train.py for other options. You could also easily edit the file and add your own. 121 | type = 'adamw_optimi' 122 | lr = 2e-5 123 | betas = [0.9, 0.99] 124 | weight_decay = 0.01 125 | eps = 1e-8 126 | 127 | # Can use this optimizer for a bit less memory usage. 128 | # [optimizer] 129 | # type = 'AdamW8bitKahan' 130 | # lr = 2e-5 131 | # betas = [0.9, 0.99] 132 | # weight_decay = 0.01 133 | # stabilize = false 134 | 135 | # Any optimizer not explicitly supported will be dynamically loaded from the pytorch-optimizer library. 136 | # [optimizer] 137 | # type = 'Prodigy' 138 | # lr = 1 139 | # betas = [0.9, 0.99] 140 | # weight_decay = 0.01 141 | 142 | [monitoring] 143 | # Set to true and fill in these fields to enable wandb 144 | enable_wandb = false 145 | wandb_api_key = '' 146 | wandb_tracker_name = '' 147 | wandb_run_name = '' 148 | -------------------------------------------------------------------------------- /examples/recommended_lumina_dataset_config.toml: -------------------------------------------------------------------------------- 1 | # Lumina certainly works at low res (512), and is able to learn from it, but the quality at lower res 2 | # seems not as good as a model like Flux. So we will train 50/50 on low res and standard res. 3 | # I'm not sure what is best and urge the community to experiment and try things themselves. 4 | resolutions = [512, 1024] 5 | 6 | enable_ar_bucket = true 7 | min_ar = 0.5 8 | max_ar = 2.0 9 | num_ar_buckets = 9 10 | 11 | # It seems Lumina was trained with this caption prefix for most images. Probably good to fine tune with it. 12 | # But I am still unsure, and haven't done direct comparisons. 13 | caption_prefix = 'You are an assistant designed to generate high-quality images based on user prompts. ' 14 | 15 | [[directory]] 16 | path = '/home/anon/data/something' -------------------------------------------------------------------------------- /examples/wan_14b_min_vram.toml: -------------------------------------------------------------------------------- 1 | # This configuration should allow you to train Wan 14b t2v on 512x512x81 sized videos (or varying aspect ratios of the same size), with 24GB VRAM. 
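# Example launch command for this config, adapted from the README (the NCCL variables are needed
# on RTX 4000 series GPUs, and the expandable_segments allocator setting is optional but tends to
# help when training on video with multiple aspect ratio buckets):
#   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" \
#     deepspeed --num_gpus=1 train.py --deepspeed --config examples/wan_14b_min_vram.toml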
2 | 3 | # change this 4 | output_dir = '/data/diffusion_pipe_training_runs/tmp' 5 | # and this 6 | dataset = '/home/anon/code/diffusion-pipe-configs/datasets/wan/video.toml' 7 | 8 | # training settings 9 | epochs = 1000 10 | micro_batch_size_per_gpu = 1 11 | pipeline_stages = 1 12 | gradient_accumulation_steps = 1 13 | gradient_clipping = 1 14 | warmup_steps = 10 15 | 16 | # eval settings 17 | eval_every_n_epochs = 1 18 | eval_before_first_step = true 19 | eval_micro_batch_size_per_gpu = 1 20 | eval_gradient_accumulation_steps = 1 21 | 22 | # misc settings 23 | save_every_n_epochs = 5 24 | checkpoint_every_n_minutes = 120 25 | activation_checkpointing = 'unsloth' 26 | partition_method = 'parameters' 27 | save_dtype = 'bfloat16' 28 | caching_batch_size = 1 29 | steps_per_print = 1 30 | video_clip_mode = 'single_beginning' 31 | blocks_to_swap = 32 32 | 33 | [model] 34 | type = 'wan' 35 | ckpt_path = '/data2/imagegen_models/Wan2.1-T2V-14B' 36 | dtype = 'bfloat16' 37 | transformer_dtype = 'float8' 38 | timestep_sample_method = 'logit_normal' 39 | 40 | [adapter] 41 | type = 'lora' 42 | rank = 32 43 | dtype = 'bfloat16' 44 | 45 | [optimizer] 46 | type = 'AdamW8bitKahan' 47 | lr = 2e-5 48 | betas = [0.9, 0.99] 49 | weight_decay = 0.01 50 | stabilize = false 51 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import re 3 | 4 | import peft 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | import safetensors.torch 9 | import torchvision 10 | from PIL import Image, ImageOps 11 | from torchvision import transforms 12 | import imageio 13 | 14 | from utils.common import is_main_process, VIDEO_EXTENSIONS, round_to_nearest_multiple, round_down_to_multiple 15 | 16 | 17 | def make_contiguous(*tensors): 18 | return tuple(x.contiguous() for x in tensors) 19 | 20 | 21 | def extract_clips(video, target_frames, video_clip_mode): 22 | # video is (channels, num_frames, height, width) 23 | frames = video.shape[1] 24 | if frames < target_frames: 25 | # TODO: think about how to handle this case. Maybe the video should have already been thrown out? 26 | print(f'video with shape {video.shape} is being skipped because it has less than the target_frames') 27 | return [] 28 | 29 | if video_clip_mode == 'single_beginning': 30 | return [video[:, :target_frames, ...]] 31 | elif video_clip_mode == 'single_middle': 32 | start = int((frames - target_frames) / 2) 33 | assert frames-start >= target_frames 34 | return [video[:, start:start+target_frames, ...]] 35 | elif video_clip_mode == 'multiple_overlapping': 36 | # Extract multiple clips so we use the whole video for training. 37 | # The clips might overlap a little bit. We never cut anything off the end of the video. 38 | num_clips = ((frames - 1) // target_frames) + 1 39 | start_indices = torch.linspace(0, frames-target_frames, num_clips).int() 40 | return [video[:, i:i+target_frames, ...] 
for i in start_indices] 41 | else: 42 | raise NotImplementedError(f'video_clip_mode={video_clip_mode} is not recognized') 43 | 44 | 45 | def convert_crop_and_resize(pil_img, width_and_height): 46 | if pil_img.mode not in ['RGB', 'RGBA'] and 'transparency' in pil_img.info: 47 | pil_img = pil_img.convert('RGBA') 48 | 49 | # add white background for transparent images 50 | if pil_img.mode == 'RGBA': 51 | canvas = Image.new('RGBA', pil_img.size, (255, 255, 255)) 52 | canvas.alpha_composite(pil_img) 53 | pil_img = canvas.convert('RGB') 54 | else: 55 | pil_img = pil_img.convert('RGB') 56 | 57 | return ImageOps.fit(pil_img, width_and_height) 58 | 59 | 60 | class PreprocessMediaFile: 61 | def __init__(self, config, support_video=False, framerate=None, round_height=1, round_width=1, round_frames=1): 62 | self.config = config 63 | self.video_clip_mode = config.get('video_clip_mode', 'single_beginning') 64 | print(f'using video_clip_mode={self.video_clip_mode}') 65 | self.pil_to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 66 | self.support_video = support_video 67 | self.framerate = framerate 68 | print(f'using framerate={self.framerate}') 69 | self.round_height = round_height 70 | self.round_width = round_width 71 | self.round_frames = round_frames 72 | if self.support_video: 73 | assert self.framerate 74 | 75 | def __call__(self, filepath, mask_filepath, size_bucket=None): 76 | is_video = (Path(filepath).suffix in VIDEO_EXTENSIONS) 77 | if is_video: 78 | assert self.support_video 79 | num_frames = 0 80 | for frame in imageio.v3.imiter(filepath, fps=self.framerate): 81 | num_frames += 1 82 | height, width = frame.shape[:2] 83 | video = imageio.v3.imiter(filepath, fps=self.framerate) 84 | else: 85 | num_frames = 1 86 | pil_img = Image.open(filepath) 87 | height, width = pil_img.height, pil_img.width 88 | video = [pil_img] 89 | 90 | if size_bucket is not None: 91 | size_bucket_width, size_bucket_height, size_bucket_frames = size_bucket 92 | else: 93 | size_bucket_width, size_bucket_height, size_bucket_frames = width, height, num_frames 94 | 95 | height_rounded = round_to_nearest_multiple(size_bucket_height, self.round_height) 96 | width_rounded = round_to_nearest_multiple(size_bucket_width, self.round_width) 97 | frames_rounded = round_down_to_multiple(size_bucket_frames - 1, self.round_frames) + 1 98 | resize_wh = (width_rounded, height_rounded) 99 | 100 | if mask_filepath: 101 | mask_img = Image.open(mask_filepath).convert('RGB') 102 | img_hw = (height, width) 103 | mask_hw = (mask_img.height, mask_img.width) 104 | if mask_hw != img_hw: 105 | raise ValueError( 106 | f'Mask shape {mask_hw} was not the same as image shape {img_hw}.\n' 107 | f'Image path: {filepath}\n' 108 | f'Mask path: {mask_filepath}' 109 | ) 110 | mask_img = ImageOps.fit(mask_img, resize_wh) 111 | mask = torchvision.transforms.functional.to_tensor(mask_img)[0].to(torch.float16) # use first channel 112 | else: 113 | mask = None 114 | 115 | resized_video = torch.empty((num_frames, 3, height_rounded, width_rounded)) 116 | for i, frame in enumerate(video): 117 | if not isinstance(frame, Image.Image): 118 | frame = torchvision.transforms.functional.to_pil_image(frame) 119 | cropped_image = convert_crop_and_resize(frame, resize_wh) 120 | resized_video[i, ...] 
= self.pil_to_tensor(cropped_image) 121 | 122 | if not self.support_video: 123 | return [(resized_video.squeeze(0), mask)] 124 | 125 | # (num_frames, channels, height, width) -> (channels, num_frames, height, width) 126 | resized_video = torch.permute(resized_video, (1, 0, 2, 3)) 127 | if not is_video: 128 | return [(resized_video, mask)] 129 | else: 130 | videos = extract_clips(resized_video, frames_rounded, self.video_clip_mode) 131 | return [(video, mask) for video in videos] 132 | 133 | 134 | class BasePipeline: 135 | framerate = None 136 | 137 | def load_diffusion_model(self): 138 | pass 139 | 140 | def get_vae(self): 141 | raise NotImplementedError() 142 | 143 | def get_text_encoders(self): 144 | raise NotImplementedError() 145 | 146 | def configure_adapter(self, adapter_config): 147 | target_linear_modules = set() 148 | for name, module in self.transformer.named_modules(): 149 | if module.__class__.__name__ not in self.adapter_target_modules: 150 | continue 151 | for full_submodule_name, submodule in module.named_modules(prefix=name): 152 | if isinstance(submodule, nn.Linear): 153 | target_linear_modules.add(full_submodule_name) 154 | target_linear_modules = list(target_linear_modules) 155 | 156 | adapter_type = adapter_config['type'] 157 | if adapter_type == 'lora': 158 | peft_config = peft.LoraConfig( 159 | r=adapter_config['rank'], 160 | lora_alpha=adapter_config['alpha'], 161 | lora_dropout=adapter_config['dropout'], 162 | bias='none', 163 | target_modules=target_linear_modules 164 | ) 165 | else: 166 | raise NotImplementedError(f'Adapter type {adapter_type} is not implemented') 167 | self.peft_config = peft_config 168 | self.lora_model = peft.get_peft_model(self.transformer, peft_config) 169 | if is_main_process(): 170 | self.lora_model.print_trainable_parameters() 171 | for name, p in self.transformer.named_parameters(): 172 | p.original_name = name 173 | if p.requires_grad: 174 | p.data = p.data.to(adapter_config['dtype']) 175 | 176 | def save_adapter(self, save_dir, peft_state_dict): 177 | raise NotImplementedError() 178 | 179 | def load_adapter_weights(self, adapter_path): 180 | if is_main_process(): 181 | print(f'Loading adapter weights from path {adapter_path}') 182 | safetensors_files = list(Path(adapter_path).glob('*.safetensors')) 183 | if len(safetensors_files) == 0: 184 | raise RuntimeError(f'No safetensors file found in {adapter_path}') 185 | if len(safetensors_files) > 1: 186 | raise RuntimeError(f'Multiple safetensors files found in {adapter_path}') 187 | adapter_state_dict = safetensors.torch.load_file(safetensors_files[0]) 188 | modified_state_dict = {} 189 | model_parameters = set(name for name, p in self.transformer.named_parameters()) 190 | for k, v in adapter_state_dict.items(): 191 | # Replace Diffusers or ComfyUI prefix 192 | k = re.sub(r'^(transformer|diffusion_model)\.', '', k) 193 | # Replace weight at end for LoRA format 194 | k = re.sub(r'\.weight$', '.default.weight', k) 195 | if k not in model_parameters: 196 | raise RuntimeError(f'modified_state_dict key {k} is not in the model parameters') 197 | modified_state_dict[k] = v 198 | self.transformer.load_state_dict(modified_state_dict, strict=False) 199 | 200 | def save_model(self, save_dir, diffusers_sd): 201 | raise NotImplementedError() 202 | 203 | def get_preprocess_media_file_fn(self): 204 | return PreprocessMediaFile(self.config, support_video=False) 205 | 206 | def get_call_vae_fn(self, vae): 207 | raise NotImplementedError() 208 | 209 | def get_call_text_encoder_fn(self, text_encoder): 210 | 
raise NotImplementedError() 211 | 212 | def prepare_inputs(self, inputs, timestep_quantile=None): 213 | raise NotImplementedError() 214 | 215 | def to_layers(self): 216 | raise NotImplementedError() 217 | 218 | def model_specific_dataset_config_validation(self, dataset_config): 219 | pass 220 | 221 | # Get param groups that will be passed into the optimizer. Models can override this, e.g. SDXL 222 | # supports separate learning rates for unet and text encoders. 223 | def get_param_groups(self, parameters): 224 | return parameters 225 | 226 | # Default loss_fn. MSE between output and target, with mask support. 227 | def get_loss_fn(self): 228 | def loss_fn(output, label): 229 | target, mask = label 230 | with torch.autocast('cuda', enabled=False): 231 | output = output.to(torch.float32) 232 | target = target.to(output.device, torch.float32) 233 | loss = F.mse_loss(output, target, reduction='none') 234 | # empty tensor means no masking 235 | if mask.numel() > 0: 236 | mask = mask.to(output.device, torch.float32) 237 | loss *= mask 238 | loss = loss.mean() 239 | return loss 240 | return loss_fn 241 | 242 | def enable_block_swap(self, blocks_to_swap): 243 | raise NotImplementedError('Block swapping is not implemented for this model') 244 | 245 | def prepare_block_swap_training(self): 246 | pass 247 | 248 | def prepare_block_swap_inference(self, disable_block_swap=False): 249 | pass 250 | -------------------------------------------------------------------------------- /models/chroma.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | import sys 4 | import os.path 5 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '../submodules/flow')) 6 | 7 | import diffusers 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from einops import rearrange 12 | from safetensors.torch import save_file 13 | from accelerate import init_empty_weights 14 | 15 | from models.base import BasePipeline, make_contiguous 16 | from utils.common import AUTOCAST_DTYPE, load_state_dict 17 | from utils.offloading import ModelOffloader 18 | from src.models.chroma.model import Chroma, chroma_params, modify_mask_to_attend_padding 19 | from src.models.chroma.module.layers import timestep_embedding, distribute_modulations, ModulationOut 20 | 21 | 22 | KEEP_IN_HIGH_PRECISION = ['norm', 'bias', 'img_in', 'txt_in', 'distilled_guidance_layer', 'final_layer'] 23 | 24 | 25 | def time_shift(mu: float, sigma: float, t: torch.Tensor): 26 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 27 | 28 | 29 | def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15): 30 | m = (y2 - y1) / (x2 - x1) 31 | b = y1 - m * x1 32 | return lambda x: m * x + b 33 | 34 | 35 | @dataclass 36 | class ModulationOutSpec: 37 | shift: slice 38 | scale: slice 39 | gate: slice 40 | 41 | 42 | # Adapted from the function of the same name in the original training code, but only computes the slices, 43 | # doesn't actually slice the tensor yet. I did this because pipeline parallelism makes it nearly impossible 44 | # to pass a dictionary between GPUs. So we have to pass the pre-sliced tensor then extract the slice on 45 | # the layer right before it's used. 46 | def distribute_modulations(): 47 | block_dict = {} 48 | 49 | # HARD CODED VALUES! lookup table for the generated vectors 50 | # TODO: move this into chroma config! 
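    # Each key below gets mapped to slice objects into the modulation tensor produced by
    # distilled_guidance_layer. Downstream (e.g. in TransformerWrapper.forward) a spec is applied as:
    #   ModulationOut(shift=mod_vectors[:, spec.shift, :],
    #                 scale=mod_vectors[:, spec.scale, :],
    #                 gate=mod_vectors[:, spec.gate, :])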
51 | # Add 38 single mod blocks 52 | for i in range(38): 53 | key = f"single_blocks.{i}.modulation.lin" 54 | block_dict[key] = None 55 | 56 | # Add 19 image double blocks 57 | for i in range(19): 58 | key = f"double_blocks.{i}.img_mod.lin" 59 | block_dict[key] = None 60 | 61 | # Add 19 text double blocks 62 | for i in range(19): 63 | key = f"double_blocks.{i}.txt_mod.lin" 64 | block_dict[key] = None 65 | 66 | # Add the final layer 67 | block_dict["final_layer.adaLN_modulation.1"] = None 68 | # 6.2b version 69 | block_dict["lite_double_blocks.4.img_mod.lin"] = None 70 | block_dict["lite_double_blocks.4.txt_mod.lin"] = None 71 | 72 | idx = 0 # Index to keep track of the vector slices 73 | 74 | for key in block_dict.keys(): 75 | if "single_blocks" in key: 76 | # Single block: 1 ModulationOut 77 | block_dict[key] = ModulationOutSpec( 78 | shift=slice(idx, idx+1), 79 | scale=slice(idx+1, idx+2), 80 | gate=slice(idx+2, idx+3), 81 | ) 82 | idx += 3 # Advance by 3 vectors 83 | 84 | elif "img_mod" in key: 85 | # Double block: List of 2 ModulationOut 86 | double_block = [] 87 | for _ in range(2): # Create 2 ModulationOut objects 88 | double_block.append( 89 | ModulationOutSpec( 90 | shift=slice(idx, idx+1), 91 | scale=slice(idx+1, idx+2), 92 | gate=slice(idx+2, idx+3), 93 | ) 94 | ) 95 | idx += 3 # Advance by 3 vectors per ModulationOut 96 | block_dict[key] = double_block 97 | 98 | elif "txt_mod" in key: 99 | # Double block: List of 2 ModulationOut 100 | double_block = [] 101 | for _ in range(2): # Create 2 ModulationOut objects 102 | double_block.append( 103 | ModulationOutSpec( 104 | shift=slice(idx, idx+1), 105 | scale=slice(idx+1, idx+2), 106 | gate=slice(idx+2, idx+3), 107 | ) 108 | ) 109 | idx += 3 # Advance by 3 vectors per ModulationOut 110 | block_dict[key] = double_block 111 | 112 | elif "final_layer" in key: 113 | # Final layer: 1 ModulationOut 114 | block_dict[key] = [ 115 | slice(idx, idx+1), 116 | slice(idx+1, idx+2), 117 | ] 118 | idx += 2 # Advance by 3 vectors 119 | 120 | return block_dict 121 | 122 | modulation_distribute_dict = distribute_modulations() 123 | 124 | 125 | class ChromaPipeline(BasePipeline): 126 | name = 'chroma' 127 | 128 | checkpointable_layers = [ 129 | 'TransformerWrapper', 130 | 'SingleTransformerWrapper', 131 | ] 132 | 133 | adapter_target_modules = ['DoubleStreamBlock', 'SingleStreamBlock'] 134 | 135 | def __init__(self, config): 136 | self.config = config 137 | self.model_config = self.config['model'] 138 | self.offloader_double = ModelOffloader('dummy', [], 0, 0, True, torch.device('cuda'), False, debug=False) 139 | self.offloader_single = ModelOffloader('dummy', [], 0, 0, True, torch.device('cuda'), False, debug=False) 140 | 141 | dtype = self.model_config['dtype'] 142 | self.diffusers_pipeline = diffusers.FluxPipeline.from_pretrained(self.model_config['diffusers_path'], torch_dtype=dtype, transformer=None) 143 | 144 | def __getattr__(self, name): 145 | return getattr(self.diffusers_pipeline, name) 146 | 147 | def load_diffusion_model(self): 148 | dtype = self.model_config['dtype'] 149 | transformer_dtype = self.model_config.get('transformer_dtype', dtype) 150 | with init_empty_weights(): 151 | transformer = Chroma(chroma_params) 152 | transformer.load_state_dict(load_state_dict(self.model_config['transformer_path']), assign=True) 153 | 154 | for name, p in transformer.named_parameters(): 155 | if not any(x in name for x in KEEP_IN_HIGH_PRECISION): 156 | p.data = p.data.to(transformer_dtype) 157 | 158 | self.diffusers_pipeline.transformer = transformer 159 
| self.transformer.train() 160 | for name, p in self.transformer.named_parameters(): 161 | p.original_name = name 162 | 163 | def get_vae(self): 164 | return self.vae 165 | 166 | def get_text_encoders(self): 167 | return [self.text_encoder_2] 168 | 169 | def save_adapter(self, save_dir, peft_state_dict): 170 | self.peft_config.save_pretrained(save_dir) 171 | # ComfyUI format. 172 | peft_state_dict = {'diffusion_model.'+k: v for k, v in peft_state_dict.items()} 173 | save_file(peft_state_dict, save_dir / 'adapter_model.safetensors', metadata={'format': 'pt'}) 174 | 175 | def save_model(self, save_dir, diffusers_sd): 176 | save_file(diffusers_sd, save_dir / 'model.safetensors', metadata={"format": "pt"}) 177 | 178 | def get_call_vae_fn(self, vae): 179 | def fn(tensor): 180 | latents = vae.encode(tensor.to(vae.device, vae.dtype)).latent_dist.sample() 181 | if hasattr(vae.config, 'shift_factor') and vae.config.shift_factor is not None: 182 | latents = latents - vae.config.shift_factor 183 | latents = latents * vae.config.scaling_factor 184 | return {'latents': latents} 185 | return fn 186 | 187 | def get_call_text_encoder_fn(self, text_encoder): 188 | def fn(caption, is_video): 189 | # args are lists 190 | assert not any(is_video) 191 | max_sequence_length = 512 192 | text_inputs = self.tokenizer_2( 193 | caption, 194 | padding="max_length", 195 | max_length=max_sequence_length, 196 | truncation=True, 197 | return_length=False, 198 | return_overflowing_tokens=False, 199 | return_tensors="pt", 200 | ) 201 | text_input_ids = text_inputs.input_ids 202 | untruncated_ids = self.tokenizer_2(caption, padding="longest", return_tensors="pt").input_ids 203 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 204 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 205 | print( 206 | "The following part of your input was truncated because `max_sequence_length` is set to " 207 | f" {max_sequence_length} tokens: {removed_text}" 208 | ) 209 | device = text_encoder.device 210 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), text_inputs.attention_mask.to(device), output_hidden_states=False)[0] 211 | return {'t5_embed': prompt_embeds, 't5_attention_mask': text_inputs.attention_mask} 212 | return fn 213 | 214 | def prepare_inputs(self, inputs, timestep_quantile=None): 215 | latents = inputs['latents'].float() 216 | t5_embed = inputs['t5_embed'] 217 | t5_attention_mask = inputs['t5_attention_mask'] 218 | mask = inputs['mask'] 219 | 220 | # The following code taken and slightly modified from x-flux (https://github.com/XLabs-AI/x-flux/tree/main) 221 | bs, c, h, w = latents.shape 222 | latents = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 223 | 224 | if mask is not None: 225 | mask = mask.unsqueeze(1).expand((-1, c, -1, -1)) # make mask (bs, c, img_h, img_w) 226 | mask = F.interpolate(mask, size=(h, w), mode='nearest-exact') # resize to latent spatial dimension 227 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 228 | 229 | img_ids = self._prepare_latent_image_ids(bs, h // 2, w // 2, latents.device, latents.dtype) 230 | if img_ids.ndim == 2: 231 | # This method must return tensors with batch dimension, since we proceed to split along batch dimension for pipelining. 
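            # The unsqueeze + repeat just adds the batch dimension: (num_img_tokens, 3) -> (bs, num_img_tokens, 3),
            # so img_ids can be split per-sample like the other tensors.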
232 | img_ids = img_ids.unsqueeze(0).repeat((bs, 1, 1)) 233 | txt_ids = torch.zeros(bs, t5_embed.shape[1], 3).to(latents.device, latents.dtype) 234 | 235 | timestep_sample_method = self.model_config.get('timestep_sample_method', 'logit_normal') 236 | 237 | if timestep_sample_method == 'logit_normal': 238 | dist = torch.distributions.normal.Normal(0, 1) 239 | elif timestep_sample_method == 'uniform': 240 | dist = torch.distributions.uniform.Uniform(0, 1) 241 | else: 242 | raise NotImplementedError() 243 | 244 | if timestep_quantile is not None: 245 | t = dist.icdf(torch.full((bs,), timestep_quantile, device=latents.device)) 246 | else: 247 | t = dist.sample((bs,)).to(latents.device) 248 | 249 | if timestep_sample_method == 'logit_normal': 250 | sigmoid_scale = self.model_config.get('sigmoid_scale', 1.0) 251 | t = t * sigmoid_scale 252 | t = torch.sigmoid(t) 253 | 254 | if shift := self.model_config.get('shift', None): 255 | t = (t * shift) / (1 + (shift - 1) * t) 256 | elif self.model_config.get('flux_shift', False): 257 | mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) 258 | t = time_shift(mu, 1.0, t) 259 | 260 | x_1 = latents 261 | x_0 = torch.randn_like(x_1) 262 | t_expanded = t.view(-1, 1, 1) 263 | x_t = (1 - t_expanded) * x_1 + t_expanded * x_0 264 | target = x_0 - x_1 265 | # guidance needs to be 0 on this model 266 | guidance_vec = torch.zeros((x_t.shape[0],), device=x_t.device, dtype=torch.float32) 267 | 268 | return (x_t, t5_embed, t5_attention_mask, t, img_ids, txt_ids, guidance_vec), (target, mask) 269 | 270 | def to_layers(self): 271 | transformer = self.transformer 272 | layers = [InitialLayer(transformer)] 273 | for i, block in enumerate(transformer.double_blocks): 274 | layers.append(TransformerWrapper(block, i, self.offloader_double)) 275 | layers.append(concatenate_hidden_states) 276 | for i, block in enumerate(transformer.single_blocks): 277 | layers.append(SingleTransformerWrapper(block, i, self.offloader_single)) 278 | layers.append(FinalLayer(transformer)) 279 | return layers 280 | 281 | def enable_block_swap(self, blocks_to_swap): 282 | transformer = self.transformer 283 | double_blocks = transformer.double_blocks 284 | single_blocks = transformer.single_blocks 285 | num_double_blocks = len(double_blocks) 286 | num_single_blocks = len(single_blocks) 287 | double_blocks_to_swap = blocks_to_swap // 2 288 | # This swaps more than blocks_to_swap total blocks. A bit odd, but the model does have twice as many 289 | # single blocks as double. I'm just replicating the behavior of Musubi Tuner. 290 | single_blocks_to_swap = (blocks_to_swap - double_blocks_to_swap) * 2 + 1 291 | 292 | assert double_blocks_to_swap <= num_double_blocks - 2 and single_blocks_to_swap <= num_single_blocks - 2, ( 293 | f'Cannot swap more than {num_double_blocks - 2} double blocks and {num_single_blocks - 2} single blocks. ' 294 | f'Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks.' 
295 | ) 296 | 297 | self.offloader_double = ModelOffloader( 298 | 'DoubleBlock', double_blocks, num_double_blocks, double_blocks_to_swap, True, torch.device('cuda'), self.config['reentrant_activation_checkpointing'] 299 | ) 300 | self.offloader_single = ModelOffloader( 301 | 'SingleBlock', single_blocks, num_single_blocks, single_blocks_to_swap, True, torch.device('cuda'), self.config['reentrant_activation_checkpointing'] 302 | ) 303 | transformer.double_blocks = None 304 | transformer.single_blocks = None 305 | transformer.to('cuda') 306 | transformer.double_blocks = double_blocks 307 | transformer.single_blocks = single_blocks 308 | self.prepare_block_swap_training() 309 | print( 310 | f'Block swap enabled. Swapping {blocks_to_swap} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}.' 311 | ) 312 | 313 | def prepare_block_swap_training(self): 314 | self.offloader_double.enable_block_swap() 315 | self.offloader_double.set_forward_only(False) 316 | self.offloader_double.prepare_block_devices_before_forward() 317 | self.offloader_single.enable_block_swap() 318 | self.offloader_single.set_forward_only(False) 319 | self.offloader_single.prepare_block_devices_before_forward() 320 | 321 | def prepare_block_swap_inference(self, disable_block_swap=False): 322 | if disable_block_swap: 323 | self.offloader_double.disable_block_swap() 324 | self.offloader_single.disable_block_swap() 325 | self.offloader_double.set_forward_only(True) 326 | self.offloader_double.prepare_block_devices_before_forward() 327 | self.offloader_single.set_forward_only(True) 328 | self.offloader_single.prepare_block_devices_before_forward() 329 | 330 | 331 | class InitialLayer(nn.Module): 332 | def __init__(self, model): 333 | super().__init__() 334 | self.img_in = model.img_in 335 | self.txt_in = model.txt_in 336 | self.distilled_guidance_layer = model.distilled_guidance_layer 337 | self.pe_embedder = model.pe_embedder 338 | self.mod_index = model.mod_index 339 | self.model = [model] 340 | 341 | def __getattr__(self, name): 342 | return getattr(self.model[0], name) 343 | 344 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 345 | def forward(self, inputs): 346 | for item in inputs: 347 | if torch.is_floating_point(item): 348 | item.requires_grad_(True) 349 | img, txt, txt_mask, timesteps, img_ids, txt_ids, guidance = inputs 350 | if img.ndim != 3 or txt.ndim != 3: 351 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 352 | 353 | img = self.img_in(img) 354 | txt = self.txt_in(txt) 355 | 356 | # See comments in original Chroma training code. This is supposed to be in a no_grad block. 
357 | with torch.no_grad(): 358 | distill_timestep = timestep_embedding(timesteps, 16) 359 | distil_guidance = timestep_embedding(guidance, 16) 360 | # get all modulation index 361 | modulation_index = timestep_embedding(self.mod_index.to(distill_timestep.device), 32) 362 | # we need to broadcast the modulation index here so each batch has all of the index 363 | modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) 364 | # and we need to broadcast timestep and guidance along too 365 | timestep_guidance = ( 366 | torch.cat([distill_timestep, distil_guidance], dim=1) 367 | .unsqueeze(1) 368 | .repeat(1, self.mod_index_length, 1) 369 | ) 370 | # then and only then we could concatenate it together 371 | input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) 372 | mod_vectors = self.distilled_guidance_layer(input_vec) 373 | # Need to force this to True for Deepspeed pipeline parallelism. 374 | mod_vectors.requires_grad_(True) 375 | 376 | ids = torch.cat((txt_ids, img_ids), dim=1) 377 | pe = self.pe_embedder(ids) 378 | 379 | max_len = txt.shape[1] 380 | 381 | with torch.no_grad(): 382 | txt_mask_w_padding = modify_mask_to_attend_padding( 383 | txt_mask, max_len, 1 384 | ) 385 | txt_img_mask = torch.cat( 386 | [ 387 | txt_mask_w_padding, 388 | torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), 389 | ], 390 | dim=1, 391 | ) 392 | txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() 393 | txt_img_mask = ( 394 | txt_img_mask[None, None, ...] 395 | .repeat(txt.shape[0], self.num_heads, 1, 1) 396 | .int() 397 | .bool() 398 | ) 399 | 400 | return make_contiguous(img, txt, pe, mod_vectors, txt_img_mask) 401 | 402 | 403 | class TransformerWrapper(nn.Module): 404 | def __init__(self, block, idx, offloader): 405 | super().__init__() 406 | self.block = block 407 | self.idx = idx 408 | self.offloader = offloader 409 | 410 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 411 | def forward(self, inputs): 412 | img, txt, pe, mod_vectors, txt_img_mask = inputs 413 | 414 | self.offloader.wait_for_block(self.idx) 415 | 416 | img_mod_spec = modulation_distribute_dict[f"double_blocks.{self.idx}.img_mod.lin"] 417 | txt_mod_spec = modulation_distribute_dict[f"double_blocks.{self.idx}.txt_mod.lin"] 418 | img_mod = [ 419 | ModulationOut( 420 | shift=mod_vectors[:, spec.shift, :], 421 | scale=mod_vectors[:, spec.scale, :], 422 | gate=mod_vectors[:, spec.gate, :], 423 | ) 424 | for spec in img_mod_spec 425 | ] 426 | txt_mod = [ 427 | ModulationOut( 428 | shift=mod_vectors[:, spec.shift, :], 429 | scale=mod_vectors[:, spec.scale, :], 430 | gate=mod_vectors[:, spec.gate, :], 431 | ) 432 | for spec in txt_mod_spec 433 | ] 434 | double_mod = [img_mod, txt_mod] 435 | img, txt = self.block( 436 | img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask 437 | ) 438 | 439 | self.offloader.submit_move_blocks_forward(self.idx) 440 | 441 | return make_contiguous(img, txt, pe, mod_vectors, txt_img_mask) 442 | 443 | 444 | def concatenate_hidden_states(inputs): 445 | img, txt, pe, mod_vectors, txt_img_mask = inputs 446 | img = torch.cat((txt, img), 1) 447 | return img, txt, pe, mod_vectors, txt_img_mask 448 | 449 | 450 | class SingleTransformerWrapper(nn.Module): 451 | def __init__(self, block, idx, offloader): 452 | super().__init__() 453 | self.block = block 454 | self.idx = idx 455 | self.offloader = offloader 456 | 457 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 458 | def forward(self, inputs): 459 | img, txt, pe, mod_vectors, txt_img_mask = inputs 460 | 461 | 
self.offloader.wait_for_block(self.idx) 462 | 463 | single_mod_spec = modulation_distribute_dict[f"single_blocks.{self.idx}.modulation.lin"] 464 | single_mod = ModulationOut( 465 | shift=mod_vectors[:, single_mod_spec.shift, :], 466 | scale=mod_vectors[:, single_mod_spec.scale, :], 467 | gate=mod_vectors[:, single_mod_spec.gate, :], 468 | ) 469 | img = self.block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) 470 | 471 | self.offloader.submit_move_blocks_forward(self.idx) 472 | 473 | return make_contiguous(img, txt, pe, mod_vectors, txt_img_mask) 474 | 475 | 476 | class FinalLayer(nn.Module): 477 | def __init__(self, model): 478 | super().__init__() 479 | self.final_layer = model.final_layer 480 | self.model = [model] 481 | 482 | def __getattr__(self, name): 483 | return getattr(self.model[0], name) 484 | 485 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 486 | def forward(self, inputs): 487 | img, txt, pe, mod_vectors, txt_img_mask = inputs 488 | img = img[:, txt.shape[1] :, ...] 489 | final_mod_spec = modulation_distribute_dict["final_layer.adaLN_modulation.1"] 490 | final_mod = [mod_vectors[:, s, :] for s in final_mod_spec] 491 | img = self.final_layer( 492 | img, distill_vec=final_mod 493 | ) # (N, T, patch_size ** 2 * out_channels) 494 | return img 495 | -------------------------------------------------------------------------------- /models/cosmos.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pprint import pprint 3 | import os.path 4 | sys.path.insert(0, os.path.abspath('submodules/Cosmos')) 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from transformers import T5TokenizerFast, T5EncoderModel 10 | import accelerate 11 | from einops import rearrange 12 | import safetensors 13 | 14 | from models.base import BasePipeline, PreprocessMediaFile, make_contiguous 15 | from utils.common import load_state_dict, AUTOCAST_DTYPE, is_main_process 16 | from cosmos1.models.diffusion.inference.inference_utils import load_model_by_config 17 | from cosmos1.models.autoregressive.tokenizer.modules import EncoderFactorized, DecoderFactorized, CausalConv3d 18 | 19 | 20 | FRAMERATE = 24 21 | SIGMA_DATA = 0.5 22 | 23 | SUPPORTED_SIZE_BUCKETS = [ 24 | [960, 960, 1], 25 | [960, 704, 1], 26 | [704, 960, 1], 27 | [1280, 704, 1], 28 | [704, 1280, 1], 29 | [960, 960, 121], 30 | [960, 704, 121], 31 | [704, 960, 121], 32 | [1280, 704, 121], 33 | [704, 1280, 121], 34 | ] 35 | 36 | 37 | def get_per_sigma_loss_weights(sigma: torch.Tensor): 38 | """ 39 | Args: 40 | sigma (tensor): noise level 41 | 42 | Returns: 43 | loss weights per sigma noise level 44 | """ 45 | return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2 46 | 47 | 48 | class CausalContinuousVideoTokenizer(nn.Module): 49 | def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: 50 | super().__init__() 51 | self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") 52 | self.embedding_dim = embedding_dim 53 | self.spatial_compression = kwargs['spatial_compression'] 54 | self.temporal_compression = kwargs['temporal_compression'] 55 | self.sigma_data = SIGMA_DATA 56 | self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) 57 | self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) 58 | 59 | self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) 60 | self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, 
padding=0) 61 | 62 | latent_temporal_chunk = 16 63 | self.latent_mean = nn.Parameter(torch.zeros([self.embedding_dim * latent_temporal_chunk], dtype=torch.float32)) 64 | self.latent_std = nn.Parameter(torch.ones([self.embedding_dim * latent_temporal_chunk], dtype=torch.float32)) 65 | 66 | def encode(self, x): 67 | h = self.encoder(x) 68 | z = self.quant_conv(h) 69 | latent_ch = z.shape[1] 70 | latent_t = z.shape[2] 71 | dtype = z.dtype 72 | mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) 73 | std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) 74 | return ((z - mean) / std) * self.sigma_data 75 | 76 | def decode(self, z): 77 | in_dtype = z.dtype 78 | latent_ch = z.shape[1] 79 | latent_t = z.shape[2] 80 | mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) 81 | std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) 82 | z = z / self.sigma_data 83 | z = z * std + mean 84 | z = self.post_quant_conv(z) 85 | return self.decoder(z) 86 | 87 | 88 | def load_custom_video_vae(path): 89 | with accelerate.init_empty_weights(): 90 | vae = CausalContinuousVideoTokenizer( 91 | attn_resolutions=[32], 92 | channels=128, 93 | channels_mult=[2, 4, 4], 94 | dropout=0.0, 95 | in_channels=3, 96 | num_res_blocks=2, 97 | out_channels=3, 98 | resolution=1024, 99 | patch_size=4, 100 | patch_method="haar", 101 | z_channels=16, 102 | z_factor=1, 103 | num_groups=1, 104 | legacy_mode=False, 105 | spatial_compression=8, 106 | temporal_compression=8, 107 | embedding_dim=16, 108 | ) 109 | missing_keys, unexpected_keys = vae.load_state_dict(load_state_dict(path), assign=True, strict=False) 110 | assert len(missing_keys) == 0 111 | vae.eval() 112 | return vae 113 | 114 | 115 | def vae_encode(tensor, vae): 116 | # tensor values already in range [-1, 1] here 117 | p = next(vae.encoder.parameters()) 118 | # TODO: the official code would call vae.encode_image() when it detects frames=1. 119 | # Should we use the image encoder (separate model)? 
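    # Note: encode() above already standardizes the latents with latent_mean / latent_std and scales
    # them by sigma_data (0.5), so the caller doesn't apply any additional scaling_factor to the result.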
120 | return vae.encode(tensor.to(p.device, p.dtype)) 121 | 122 | 123 | def dataset_config_validation(config): 124 | if 'min_ar' in config or 'max_ar' in config or 'num_ar_buckets' in config or 'resolutions' in config: 125 | return False 126 | size_buckets = config.get('size_buckets', []) 127 | if len(size_buckets) == 0: 128 | return False 129 | for size_bucket in size_buckets: 130 | if size_bucket not in SUPPORTED_SIZE_BUCKETS: 131 | return False 132 | return True 133 | 134 | 135 | class CosmosPipeline(BasePipeline): 136 | name = 'cosmos' 137 | framerate = FRAMERATE 138 | checkpointable_layers = ['InitialLayer', 'TransformerLayer', 'FinalLayer'] 139 | adapter_target_modules = ['GeneralDITTransformerBlock'] 140 | 141 | def __init__(self, config): 142 | self.config = config 143 | self.model_config = self.config['model'] 144 | 145 | # TODO: different model variants 146 | self.model = load_model_by_config( 147 | config_job_name='Cosmos_1_0_Diffusion_Text2World_7B', 148 | config_file='submodules/Cosmos/cosmos1/models/diffusion/config/config.py', 149 | ) 150 | 151 | self.vae = load_custom_video_vae(self.model_config['vae_path']) 152 | 153 | self.tokenizer = T5TokenizerFast( 154 | vocab_file='configs/t5_old/spiece.model', 155 | tokenizer_file='configs/t5_old/tokenizer.json', 156 | ) 157 | t5_state_dict = load_state_dict(self.model_config['text_encoder_path']) 158 | self.text_encoder = T5EncoderModel.from_pretrained( 159 | None, 160 | config='configs/t5_old/config.json', 161 | state_dict=t5_state_dict, 162 | torch_dtype='auto', 163 | local_files_only=True, 164 | ) 165 | 166 | def load_diffusion_model(self): 167 | with accelerate.init_empty_weights(): 168 | self.model.model = self.model.build_model() 169 | net_state_dict = load_state_dict(self.model_config['transformer_path']) 170 | incompatible = self.model.model.load_state_dict(net_state_dict, strict=False, assign=True) 171 | missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] 172 | assert len(missing_keys) == 0 173 | self.transformer = self.model.net 174 | 175 | def model_specific_dataset_config_validation(self, dataset_config): 176 | passes_validation = True 177 | passes_validation &= dataset_config_validation(dataset_config) 178 | for directory_config in dataset_config['directory']: 179 | passes_validation &= dataset_config_validation(directory_config) 180 | if not passes_validation: 181 | if is_main_process(): 182 | print('WARNING: Cosmos supports a limited set of resolutions. Anything else will likely not work correctly.' 183 | ' See the cosmos_dataset.toml example. If you still want to proceed with the current configuration,' 184 | ' run the script with the --i_know_what_i_am_doing flag.') 185 | quit() 186 | 187 | def get_vae(self): 188 | return self.vae 189 | 190 | def get_text_encoders(self): 191 | return [self.text_encoder] 192 | 193 | def save_adapter(self, save_dir, peft_state_dict): 194 | self.peft_config.save_pretrained(save_dir) 195 | # ComfyUI format. 
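        # Prepending 'diffusion_model.' gives the keys the prefix ComfyUI expects for these LoRAs,
        # e.g. (illustrative key name) 'blocks.0.attn.to_q.lora_A.weight' ->
        # 'diffusion_model.blocks.0.attn.to_q.lora_A.weight'.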
196 | peft_state_dict = {'diffusion_model.'+k: v for k, v in peft_state_dict.items()} 197 | safetensors.torch.save_file(peft_state_dict, save_dir / 'adapter_model.safetensors', metadata={'format': 'pt'}) 198 | 199 | def get_preprocess_media_file_fn(self): 200 | return PreprocessMediaFile( 201 | self.config, 202 | support_video=True, 203 | framerate=self.framerate, 204 | round_height=8, 205 | round_width=8, 206 | round_frames=8, 207 | ) 208 | 209 | def get_call_vae_fn(self, vae): 210 | def fn(tensor): 211 | return {'latents': vae_encode(tensor, vae)} 212 | return fn 213 | 214 | def get_call_text_encoder_fn(self, text_encoder): 215 | def fn(captions, is_video): 216 | # args are lists 217 | batch_encoding = self.tokenizer.batch_encode_plus( 218 | captions, 219 | return_tensors="pt", 220 | truncation=True, 221 | padding="max_length", 222 | max_length=512, 223 | return_length=True, 224 | return_offsets_mapping=False, 225 | ) 226 | 227 | input_ids = batch_encoding.input_ids 228 | attn_mask = batch_encoding.attention_mask 229 | 230 | device = text_encoder.device 231 | outputs = text_encoder(input_ids=input_ids.to(device), attention_mask=attn_mask.to(device)) 232 | 233 | encoded_text = outputs.last_hidden_state 234 | lengths = attn_mask.sum(dim=1).cpu() 235 | 236 | for batch_id in range(encoded_text.shape[0]): 237 | encoded_text[batch_id][lengths[batch_id] :] = 0 238 | 239 | return {'prompt_embeds': encoded_text, 'prompt_attention_mask': attn_mask} 240 | return fn 241 | 242 | def prepare_inputs(self, inputs, timestep_quantile=None): 243 | latents = inputs['latents'].float() 244 | prompt_embeds = inputs['prompt_embeds'] 245 | mask = inputs['mask'] 246 | 247 | bs, channels, num_frames, h, w = latents.shape 248 | device = latents.device 249 | 250 | if mask is not None: 251 | mask = mask.unsqueeze(1) # make mask (bs, 1, img_h, img_w) 252 | mask = F.interpolate(mask, size=(h, w), mode='nearest-exact') # resize to latent spatial dimension 253 | mask = mask.unsqueeze(2) # make mask same number of dims as target 254 | 255 | noise = torch.randn_like(latents) 256 | 257 | dist = torch.distributions.normal.Normal(0, 1) 258 | if timestep_quantile is not None: 259 | log_sigma = dist.icdf(torch.full((bs,), timestep_quantile, device=device)) 260 | else: 261 | log_sigma = dist.sample((bs,)).to(device) 262 | sigma = torch.exp(log_sigma) 263 | 264 | x_t = latents + sigma.view(-1, 1, 1, 1, 1) * noise 265 | 266 | c_skip, c_out, c_in, c_noise = self.model.scaling(sigma=sigma) 267 | x = x_t * c_in.view(-1, 1, 1, 1, 1) 268 | timesteps = c_noise 269 | target = latents 270 | 271 | return (x, x_t, timesteps, prompt_embeds, sigma), (target, mask) 272 | 273 | def to_layers(self): 274 | layers = [InitialLayer(self.transformer, self.vae.spatial_compression)] 275 | for name, block in self.transformer.blocks.items(): 276 | layers.append(TransformerLayer(block)) 277 | layers.append(FinalLayer(self)) 278 | return layers 279 | 280 | def get_loss_fn(self): 281 | def loss_fn(output, label): 282 | output, weights_per_sigma = output 283 | target, mask = label 284 | with torch.autocast('cuda', enabled=False): 285 | output = output.to(torch.float32) 286 | target = target.to(output.device, torch.float32) 287 | loss = F.mse_loss(output, target, reduction='none') 288 | # empty tensor means no masking 289 | if mask.numel() > 0: 290 | mask = mask.to(output.device, torch.float32) 291 | loss *= mask 292 | loss = loss * weights_per_sigma 293 | loss = loss.mean() 294 | return loss 295 | return loss_fn 296 | 297 | 298 | class 
InitialLayer(nn.Module): 299 | def __init__(self, transformer, spatial_compression_factor): 300 | super().__init__() 301 | self.transformer = [transformer] 302 | self.spatial_compression_factor = spatial_compression_factor 303 | self.x_embedder = transformer.x_embedder 304 | self.extra_pos_embedder = transformer.extra_pos_embedder 305 | self.pos_embedder = transformer.pos_embedder 306 | self.t_embedder = transformer.t_embedder 307 | 308 | def __getattr__(self, name): 309 | return getattr(self.transformer[0], name) 310 | 311 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 312 | def forward(self, inputs): 313 | for item in inputs: 314 | if torch.is_floating_point(item): 315 | item.requires_grad_(True) 316 | 317 | x, x_t, timesteps, crossattn_emb, sigma = inputs 318 | original_shape = x.shape 319 | dtype = x.dtype 320 | device = x.device 321 | 322 | # Official code sets up these inputs in prepare_data_batch. 323 | fps = torch.tensor([FRAMERATE] * 1, dtype=dtype, device=device) 324 | height = x.shape[-2] * self.spatial_compression_factor 325 | width = x.shape[-1] * self.spatial_compression_factor 326 | image_size = torch.tensor([[height, width, height, width]] * 1, dtype=dtype, device=device) 327 | padding_mask = torch.zeros((1, 1, height, width), dtype=dtype, device=device) 328 | 329 | # x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( 330 | # x, 331 | # fps=fps, 332 | # padding_mask=padding_mask, 333 | # latent_condition=None, 334 | # latent_condition_sigma=None, 335 | # ) 336 | inputs = self.forward_before_blocks( 337 | x=x, 338 | timesteps=timesteps, 339 | crossattn_emb=crossattn_emb, 340 | # Model should have use_cross_attn_mask=False. Will assert fail otherwise. 341 | crossattn_mask=None, 342 | fps=fps, 343 | image_size=image_size, 344 | padding_mask=padding_mask, 345 | ) 346 | x, affline_emb_B_D, crossattn_emb, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( 347 | inputs["x"], 348 | inputs["affline_emb_B_D"], 349 | inputs["crossattn_emb"], 350 | inputs["rope_emb_L_1_1_D"], 351 | inputs["adaln_lora_B_3D"], 352 | inputs["original_shape"], 353 | ) 354 | extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] 355 | if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: 356 | assert ( 357 | x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape 358 | ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" 359 | 360 | original_shape = torch.tensor(original_shape) 361 | return make_contiguous( 362 | x, 363 | x_t, 364 | affline_emb_B_D, 365 | crossattn_emb, 366 | rope_emb_L_1_1_D, 367 | adaln_lora_B_3D, 368 | extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, 369 | original_shape, 370 | sigma, 371 | ) 372 | 373 | 374 | class TransformerLayer(nn.Module): 375 | def __init__(self, block): 376 | super().__init__() 377 | self.block = block 378 | 379 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 380 | def forward(self, inputs): 381 | x, x_t, affline_emb_B_D, crossattn_emb, rope_emb_L_1_1_D, adaln_lora_B_3D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, original_shape, sigma = inputs 382 | x = self.block( 383 | x, 384 | affline_emb_B_D, 385 | crossattn_emb, 386 | None, 387 | rope_emb_L_1_1_D=rope_emb_L_1_1_D, 388 | adaln_lora_B_3D=adaln_lora_B_3D, 389 | extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, 390 | ) 391 | return make_contiguous( 392 | x, 393 | x_t, 394 | affline_emb_B_D, 395 | crossattn_emb, 396 | rope_emb_L_1_1_D, 397 | adaln_lora_B_3D, 398 | extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, 399 | 
original_shape, 400 | sigma, 401 | ) 402 | 403 | 404 | class FinalLayer(nn.Module): 405 | def __init__(self, pipeline): 406 | super().__init__() 407 | self.pipeline = pipeline 408 | self.final_layer = self.pipeline.transformer.final_layer 409 | 410 | def __getattr__(self, name): 411 | return getattr(self.pipeline.transformer, name) 412 | 413 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 414 | def forward(self, inputs): 415 | x, x_t, affline_emb_B_D, crossattn_emb, rope_emb_L_1_1_D, adaln_lora_B_3D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, original_shape, sigma = inputs 416 | original_shape = original_shape.tolist() 417 | 418 | x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") 419 | output = self.decoder_head( 420 | x_B_T_H_W_D=x_B_T_H_W_D, 421 | emb_B_D=affline_emb_B_D, 422 | crossattn_emb=None, 423 | origin_shape=original_shape, 424 | crossattn_mask=None, 425 | adaln_lora_B_3D=adaln_lora_B_3D, 426 | ) 427 | 428 | c_skip, c_out, c_in, c_noise = self.pipeline.model.scaling(sigma=sigma) 429 | c_skip = c_skip.view(-1, 1, 1, 1, 1) 430 | c_out = c_out.view(-1, 1, 1, 1, 1) 431 | sigma = sigma.view(-1, 1, 1, 1, 1) 432 | x0_pred = c_skip*x_t + c_out*output 433 | weights_per_sigma = get_per_sigma_loss_weights(sigma) 434 | return x0_pred, weights_per_sigma 435 | -------------------------------------------------------------------------------- /models/ltx_video.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os.path 3 | import sys 4 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '../submodules/LTX_Video')) 5 | 6 | import random 7 | import safetensors 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from models.base import BasePipeline, PreprocessMediaFile, make_contiguous 13 | from utils.common import AUTOCAST_DTYPE 14 | from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline as OriginalLTXVideoPipeline 15 | from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier 16 | from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder 17 | from ltx_video.models.transformers.transformer3d import Transformer3DModel 18 | from ltx_video.models.autoencoders.vae_encode import vae_encode 19 | 20 | 21 | KEEP_IN_HIGH_PRECISION = ['norm', 'bias', 'scale_shift_table', 'patchify_proj', 'proj_out', 'adaln_single', 'caption_projection'] 22 | 23 | 24 | class LTXVideoPipeline(BasePipeline): 25 | name = 'ltx-video' 26 | framerate = 25 27 | checkpointable_layers = ['TransformerLayer'] 28 | adapter_target_modules = ['BasicTransformerBlock'] 29 | 30 | def __init__(self, config): 31 | self.config = config 32 | self.model_config = self.config['model'] 33 | dtype = self.model_config['dtype'] 34 | 35 | diffusers_path = self.model_config['diffusers_path'] 36 | single_file_path = Path(self.model_config['single_file_path']) 37 | 38 | # The VAE could be different for each model version, so we have to make sure to use different cache directories. 
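        # Path.stem keeps only the file name without its extension, e.g. (hypothetical path)
        # '/models/ltx-video-2b-v0.9.5.safetensors' gives self.name = 'ltx-video-2b-v0.9.5'.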
39 | self.name = single_file_path.stem 40 | 41 | vae = CausalVideoAutoencoder.from_pretrained(single_file_path) 42 | self.diffusers_pipeline = OriginalLTXVideoPipeline.from_pretrained( 43 | diffusers_path, 44 | transformer=None, 45 | vae=vae, 46 | patchifier=SymmetricPatchifier(patch_size=1), 47 | prompt_enhancer_image_caption_model=None, 48 | prompt_enhancer_image_caption_processor=None, 49 | prompt_enhancer_llm_model=None, 50 | prompt_enhancer_llm_tokenizer=None, 51 | torch_dtype=dtype, 52 | ) 53 | 54 | def __getattr__(self, name): 55 | return getattr(self.diffusers_pipeline, name) 56 | 57 | def load_diffusion_model(self): 58 | single_file_path = self.model_config['single_file_path'] 59 | dtype = self.model_config['dtype'] 60 | transformer_dtype = self.model_config.get('transformer_dtype', dtype) 61 | 62 | transformer = Transformer3DModel.from_pretrained(single_file_path, torch_dtype=dtype) 63 | for name, p in transformer.named_parameters(): 64 | if not (any(x in name for x in KEEP_IN_HIGH_PRECISION)): 65 | p.data = p.data.to(transformer_dtype) 66 | 67 | transformer.train() 68 | for name, p in transformer.named_parameters(): 69 | p.original_name = name 70 | self.diffusers_pipeline.transformer = transformer 71 | 72 | def get_vae(self): 73 | return self.vae 74 | 75 | def get_text_encoders(self): 76 | return [self.text_encoder] 77 | 78 | def save_adapter(self, save_dir, peft_state_dict): 79 | self.peft_config.save_pretrained(save_dir) 80 | # ComfyUI format 81 | peft_state_dict = {'diffusion_model.'+k: v for k, v in peft_state_dict.items()} 82 | safetensors.torch.save_file(peft_state_dict, save_dir / 'adapter_model.safetensors', metadata={'format': 'pt'}) 83 | 84 | def save_model(self, save_dir, diffusers_sd): 85 | raise NotImplementedError() 86 | 87 | def get_preprocess_media_file_fn(self): 88 | return PreprocessMediaFile( 89 | self.config, 90 | support_video=True, 91 | framerate=self.framerate, 92 | round_height=32, 93 | round_width=32, 94 | round_frames=8, 95 | ) 96 | 97 | def get_call_vae_fn(self, vae): 98 | def fn(tensor): 99 | latents = vae_encode( 100 | tensor.to(dtype=vae.dtype, device=vae.device), 101 | vae, 102 | vae_per_channel_normalize=True, 103 | ) 104 | return {'latents': latents} 105 | return fn 106 | 107 | def get_call_text_encoder_fn(self, text_encoder): 108 | def fn(caption, is_video): 109 | # args are lists 110 | ( 111 | prompt_embeds, 112 | prompt_attention_mask, 113 | negative_prompt_embeds, 114 | negative_prompt_attention_mask, 115 | ) = self.encode_prompt(caption, do_classifier_free_guidance=False, device=text_encoder.device) 116 | return {'prompt_embeds': prompt_embeds, 'prompt_attention_mask': prompt_attention_mask} 117 | return fn 118 | 119 | def prepare_inputs(self, inputs, timestep_quantile=None): 120 | latents = inputs['latents'].float() 121 | prompt_embeds = inputs['prompt_embeds'] 122 | prompt_attention_mask = inputs['prompt_attention_mask'] 123 | mask = inputs['mask'] 124 | 125 | bs, channels, num_frames, height, width = latents.shape 126 | 127 | temporal_downscale = self.vae.temporal_downscale_factor 128 | spatial_downscale = self.vae.spatial_downscale_factor 129 | latents, pixel_coords, conditioning_mask, num_cond_latents = ( 130 | self.prepare_conditioning( 131 | conditioning_items=[], 132 | init_latents=latents, 133 | num_frames=(num_frames-1)*temporal_downscale + 1, 134 | height=height*spatial_downscale, 135 | width=width*spatial_downscale, 136 | vae_per_channel_normalize=True, 137 | ) 138 | ) 139 | 140 | if mask is not None: 141 | # untested 142 | mask 
= mask.unsqueeze(1).unsqueeze(1).expand((-1, channels, num_frames, -1, -1)) # make mask (bs, c, f, img_h, img_w) 143 | mask = F.interpolate(mask, size=(height, width), mode='nearest-exact') # resize to latent spatial dimension 144 | mask, _ = self.patchifier.patchify( 145 | latents=mask 146 | ) 147 | 148 | timestep_sample_method = self.model_config.get('timestep_sample_method', 'logit_normal') 149 | 150 | if timestep_sample_method == 'logit_normal': 151 | dist = torch.distributions.normal.Normal(0, 1) 152 | elif timestep_sample_method == 'uniform': 153 | dist = torch.distributions.uniform.Uniform(0, 1) 154 | else: 155 | raise NotImplementedError() 156 | 157 | if timestep_quantile is not None: 158 | t = dist.icdf(torch.full((bs,), timestep_quantile, device=latents.device)) 159 | else: 160 | t = dist.sample((bs,)).to(latents.device) 161 | 162 | if timestep_sample_method == 'logit_normal': 163 | sigmoid_scale = self.model_config.get('sigmoid_scale', 1.0) 164 | t = t * sigmoid_scale 165 | t = torch.sigmoid(t) 166 | 167 | x_1 = latents 168 | x_0 = torch.randn_like(x_1) 169 | t_expanded = t.view(-1, 1, 1) 170 | 171 | # Copied and modified from https://github.com/Lightricks/LTX-Video-Trainer/blob/main/src/ltxv_trainer/trainer.py 172 | if mask is None: 173 | mask = torch.ones_like(x_1) 174 | first_frame_conditioning_p = self.model_config.get('first_frame_conditioning_p', 0) 175 | # If first frame conditioning is enabled, the first latent (first video frame) is left (almost) unchanged. 176 | if first_frame_conditioning_p and random.random() < first_frame_conditioning_p: 177 | t_expanded = t_expanded.repeat(1, x_1.shape[1], 1) 178 | first_frame_end_idx = height * width 179 | 180 | # if we only have one frame (e.g. when training on still images), 181 | # skip this step otherwise we have no target to train on. 182 | if first_frame_end_idx < x_1.shape[1]: 183 | t_expanded[:, :first_frame_end_idx] = 1e-5 # Small sigma close to 0 for the first frame. 184 | mask[:, :first_frame_end_idx] = 0.0 # Mask out the loss for the first frame. 185 | 186 | x_t = (1 - t_expanded) * x_1 + t_expanded * x_0 187 | target = x_0 - x_1 188 | 189 | fractional_coords = pixel_coords.to(torch.float32) 190 | fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / self.framerate) 191 | 192 | return (x_t, prompt_embeds, prompt_attention_mask, t, fractional_coords), (target, mask) 193 | 194 | def to_layers(self): 195 | transformer = self.transformer 196 | layers = [InitialLayer(transformer)] 197 | for block in transformer.transformer_blocks: 198 | layers.append(TransformerLayer(block)) 199 | layers.append(OutputLayer(transformer)) 200 | return layers 201 | 202 | def get_loss_fn(self): 203 | def loss_fn(output, label): 204 | target, mask = label 205 | with torch.autocast('cuda', enabled=False): 206 | output = output.to(torch.float32) 207 | target = target.to(output.device, torch.float32) 208 | loss = F.mse_loss(output, target, reduction='none') 209 | # empty tensor means no masking 210 | if mask.numel() > 0: 211 | mask = mask.to(output.device, torch.float32) 212 | # Copied and modified from https://github.com/Lightricks/LTX-Video-Trainer/blob/main/src/ltxv_trainer/trainer.py 213 | loss = loss.mul(mask).div(mask.mean()) # divide by mean to keep the loss scale unchanged. 
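                    # For a binary mask where a fraction p of the elements are 1, mask.mean() == p, so
                    # dividing by it turns the mean over all elements into the mean over only the
                    # unmasked elements, keeping the loss magnitude comparable to the unmasked case.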
214 | loss = loss.mean() 215 | return loss 216 | return loss_fn 217 | 218 | 219 | 220 | class InitialLayer(nn.Module): 221 | def __init__(self, transformer): 222 | super().__init__() 223 | self.transformer = [transformer] 224 | self.patchify_proj = transformer.patchify_proj 225 | self.timestep_scale_multiplier = transformer.timestep_scale_multiplier 226 | self.adaln_single = transformer.adaln_single 227 | self.caption_projection = transformer.caption_projection 228 | 229 | def __getattr__(self, name): 230 | return getattr(self.transformer[0], name) 231 | 232 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 233 | def forward(self, inputs): 234 | (hidden_states, encoder_hidden_states, encoder_attention_mask, timestep, indices_grid) = inputs 235 | 236 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 237 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 238 | encoder_attention_mask = ( 239 | 1 - encoder_attention_mask.to(hidden_states.dtype) 240 | ) * -10000.0 241 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 242 | 243 | hidden_states = self.patchify_proj(hidden_states) 244 | 245 | if self.timestep_scale_multiplier: 246 | timestep = self.timestep_scale_multiplier * timestep 247 | 248 | freqs_cos, freqs_sin = self.precompute_freqs_cis(indices_grid) 249 | 250 | batch_size = hidden_states.shape[0] 251 | timestep, embedded_timestep = self.adaln_single( 252 | timestep.flatten(), 253 | {"resolution": None, "aspect_ratio": None}, 254 | batch_size=batch_size, 255 | hidden_dtype=hidden_states.dtype, 256 | ) 257 | # Second dimension is 1 or number of tokens (if timestep_per_token) 258 | timestep = timestep.view(batch_size, -1, timestep.shape[-1]) 259 | embedded_timestep = embedded_timestep.view( 260 | batch_size, -1, embedded_timestep.shape[-1] 261 | ) 262 | 263 | if self.caption_projection is not None: 264 | batch_size = hidden_states.shape[0] 265 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 266 | encoder_hidden_states = encoder_hidden_states.view( 267 | batch_size, -1, hidden_states.shape[-1] 268 | ) 269 | 270 | outputs = make_contiguous(hidden_states, encoder_hidden_states, timestep, embedded_timestep, freqs_cos, freqs_sin, encoder_attention_mask) 271 | for tensor in outputs: 272 | if torch.is_floating_point(tensor): 273 | tensor.requires_grad_(True) 274 | return outputs 275 | 276 | 277 | class TransformerLayer(nn.Module): 278 | def __init__(self, block): 279 | super().__init__() 280 | self.block = block 281 | 282 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 283 | def forward(self, inputs): 284 | hidden_states, encoder_hidden_states, timestep, embedded_timestep, freqs_cos, freqs_sin, encoder_attention_mask = inputs 285 | hidden_states = self.block( 286 | hidden_states, 287 | freqs_cis=(freqs_cos, freqs_sin), 288 | attention_mask=None, 289 | encoder_hidden_states=encoder_hidden_states, 290 | encoder_attention_mask=encoder_attention_mask, 291 | timestep=timestep, 292 | ) 293 | return make_contiguous(hidden_states, encoder_hidden_states, timestep, embedded_timestep, freqs_cos, freqs_sin, encoder_attention_mask) 294 | 295 | 296 | class OutputLayer(nn.Module): 297 | def __init__(self, transformer): 298 | super().__init__() 299 | self.transformer = [transformer] 300 | self.scale_shift_table = transformer.scale_shift_table 301 | self.norm_out = transformer.norm_out 302 | self.proj_out = transformer.proj_out 303 | 304 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 305 | def forward(self, inputs): 306 
| hidden_states, encoder_hidden_states, timestep, embedded_timestep, freqs_cos, freqs_sin, encoder_attention_mask = inputs 307 | 308 | scale_shift_values = ( 309 | self.scale_shift_table[None, None] + embedded_timestep[:, :, None] 310 | ) 311 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 312 | hidden_states = self.norm_out(hidden_states) 313 | hidden_states = hidden_states * (1 + scale) + shift 314 | return self.proj_out(hidden_states) 315 | -------------------------------------------------------------------------------- /models/lumina_2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Tuple 3 | import sys 4 | import os.path 5 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '../submodules')) 6 | 7 | import diffusers 8 | import transformers 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import safetensors 13 | from accelerate import init_empty_weights 14 | from accelerate.utils import set_module_tensor_to_device 15 | 16 | from models.base import BasePipeline, make_contiguous 17 | from utils.common import AUTOCAST_DTYPE, load_state_dict 18 | 19 | from Lumina_2.models.model import NextDiT_2B_GQA_patch2_Adaln_Refiner 20 | 21 | 22 | def get_lin_function( 23 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 24 | ): 25 | m = (y2 - y1) / (x2 - x1) 26 | b = y1 - m * x1 27 | return lambda x: m * x + b 28 | 29 | 30 | def time_shift(mu: float, sigma: float, t: torch.Tensor): 31 | t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 32 | return t 33 | 34 | 35 | class Lumina2Pipeline(BasePipeline): 36 | name = 'lumina_2' 37 | checkpointable_layers = ['InitialLayer', 'TransformerLayer'] 38 | # This will also train the noise_refiner and context_refiner layers, which aren't part of the main stack of transformer 39 | # layers, since they also use this class. 
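    # configure_adapter() in models/base.py applies LoRA to every nn.Linear found inside modules whose
    # class name appears in this list, so targeting 'JointTransformerBlock' picks up the refiner blocks
    # as well as the main transformer layers.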
40 | adapter_target_modules = ['JointTransformerBlock'] 41 | 42 | def __init__(self, config): 43 | self.config = config 44 | self.model_config = self.config['model'] 45 | dtype = self.model_config['dtype'] 46 | 47 | self.vae = diffusers.AutoencoderKL.from_single_file(self.model_config['vae_path'], config='configs/flux_vae') 48 | 49 | self.tokenizer = transformers.AutoTokenizer.from_pretrained('configs/gemma_2_2b') 50 | self.tokenizer.padding_side = 'right' 51 | 52 | text_encoder_config = transformers.AutoConfig.from_pretrained('configs/gemma_2_2b') 53 | with init_empty_weights(): 54 | self.text_encoder = transformers.AutoModel.from_config(text_encoder_config) 55 | state_dict = load_state_dict(self.model_config['llm_path']) 56 | for name, param in self.text_encoder.named_parameters(): 57 | set_module_tensor_to_device(self.text_encoder, name, device='cpu', dtype=dtype, value=state_dict['model.'+name]) 58 | 59 | self.text_encoder.eval() 60 | cap_feat_dim = self.text_encoder.config.hidden_size 61 | 62 | with init_empty_weights(): 63 | self.transformer = NextDiT_2B_GQA_patch2_Adaln_Refiner( 64 | in_channels=16, 65 | qk_norm=True, 66 | cap_feat_dim=cap_feat_dim, 67 | ) 68 | state_dict = load_state_dict(self.model_config['transformer_path']) 69 | for name, param in self.transformer.named_parameters(): 70 | set_module_tensor_to_device(self.transformer, name, device='cpu', dtype=dtype, value=state_dict[name]) 71 | 72 | self.transformer.train() 73 | for name, p in self.transformer.named_parameters(): 74 | p.original_name = name 75 | 76 | def get_vae(self): 77 | return self.vae 78 | 79 | def get_text_encoders(self): 80 | return [self.text_encoder] 81 | 82 | def save_adapter(self, save_dir, peft_state_dict): 83 | self.peft_config.save_pretrained(save_dir) 84 | # ComfyUI format. 
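# Illustrative example (hypothetical key name): a PEFT key such as
#     'layers.0.attention.to_q.lora_A.weight'
# is saved as
#     'diffusion_model.layers.0.attention.to_q.lora_A.weight'
# matching the prefix ComfyUI uses for LoRAs applied to the diffusion model.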
85 | peft_state_dict = {'diffusion_model.'+k: v for k, v in peft_state_dict.items()} 86 | safetensors.torch.save_file(peft_state_dict, save_dir / 'adapter_model.safetensors', metadata={'format': 'pt'}) 87 | 88 | def save_model(self, save_dir, state_dict): 89 | safetensors.torch.save_file(state_dict, save_dir / 'model.safetensors', metadata={'format': 'pt'}) 90 | 91 | def get_call_vae_fn(self, vae): 92 | def fn(tensor): 93 | latents = vae.encode(tensor.to(vae.device, vae.dtype)).latent_dist.sample() 94 | if hasattr(vae.config, 'shift_factor') and vae.config.shift_factor is not None: 95 | latents = latents - vae.config.shift_factor 96 | latents = latents * vae.config.scaling_factor 97 | return {'latents': latents} 98 | return fn 99 | 100 | def get_call_text_encoder_fn(self, text_encoder): 101 | def fn(caption, is_video): 102 | # args are lists 103 | assert not any(is_video) 104 | text_inputs = self.tokenizer( 105 | caption, 106 | padding='max_length', 107 | max_length=256, 108 | truncation=True, 109 | return_tensors="pt", 110 | ) 111 | 112 | text_input_ids = text_inputs.input_ids 113 | prompt_masks = text_inputs.attention_mask 114 | 115 | device = self.text_encoder.device 116 | prompt_embeds = self.text_encoder( 117 | input_ids=text_input_ids.to(device), 118 | attention_mask=prompt_masks.to(device), 119 | output_hidden_states=True, 120 | ).hidden_states[-2] 121 | return {'prompt_embeds': prompt_embeds, 'prompt_masks': prompt_masks} 122 | return fn 123 | 124 | def prepare_inputs(self, inputs, timestep_quantile=None): 125 | latents = inputs['latents'].float() 126 | prompt_embeds = inputs['prompt_embeds'] 127 | prompt_masks = inputs['prompt_masks'] 128 | mask = inputs['mask'] 129 | 130 | bs, c, h, w = latents.shape 131 | 132 | if mask is not None: 133 | mask = mask.unsqueeze(1) # make mask (bs, 1, img_h, img_w) 134 | mask = F.interpolate(mask, size=(h, w), mode='nearest-exact') # resize to latent spatial dimension 135 | 136 | timestep_sample_method = self.model_config.get('timestep_sample_method', 'logit_normal') 137 | 138 | if timestep_sample_method == 'logit_normal': 139 | dist = torch.distributions.normal.Normal(0, 1) 140 | elif timestep_sample_method == 'uniform': 141 | dist = torch.distributions.uniform.Uniform(0, 1) 142 | else: 143 | raise NotImplementedError() 144 | 145 | if timestep_quantile is not None: 146 | t = dist.icdf(torch.full((bs,), timestep_quantile, device=latents.device)) 147 | else: 148 | t = dist.sample((bs,)).to(latents.device) 149 | 150 | if timestep_sample_method == 'logit_normal': 151 | sigmoid_scale = self.model_config.get('sigmoid_scale', 1.0) 152 | t = t * sigmoid_scale 153 | t = torch.sigmoid(t) 154 | 155 | if shift := self.model_config.get('shift', None): 156 | t = (t * shift) / (1 + (shift - 1) * t) 157 | elif self.model_config.get('lumina_shift', False): 158 | mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) 159 | t = time_shift(mu, 1.0, t) 160 | 161 | noise = torch.randn_like(latents) 162 | t_expanded = t.view(-1, 1, 1, 1) 163 | noisy_latents = (1 - t_expanded) * latents + t_expanded * noise 164 | target = latents - noise 165 | 166 | # If t is the amount of noise, then the timestep this model takes as input is 1-t. 
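# For reference, the flow-matching construction above is:
#     x_t    = (1 - t) * x0 + t * noise    # noisy_latents, t = noise level
#     target = x0 - noise                  # the velocity the model learns to predict
# so t=0 is clean data and t=1 is pure noise, and the model is conditioned on s = 1 - t
# (s=1 clean, s=0 pure noise), which is why 1-t is returned below.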
167 | return (noisy_latents, 1-t, prompt_embeds, prompt_masks), (target, mask) 168 | 169 | def to_layers(self): 170 | transformer = self.transformer 171 | layers = [InitialLayer(transformer)] 172 | for block in transformer.layers: 173 | layers.append(TransformerLayer(block)) 174 | layers.append(FinalLayer(transformer)) 175 | return layers 176 | 177 | 178 | class InitialLayer(nn.Module): 179 | def __init__(self, model): 180 | super().__init__() 181 | self.t_embedder = model.t_embedder 182 | self.cap_embedder = model.cap_embedder 183 | self.rope_embedder = model.rope_embedder 184 | self.context_refiner = model.context_refiner 185 | self.x_embedder = model.x_embedder 186 | self.noise_refiner = model.noise_refiner 187 | self.model = [model] 188 | 189 | def __getattr__(self, name): 190 | return getattr(self.model[0], name) 191 | 192 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 193 | def forward(self, inputs): 194 | for item in inputs: 195 | if torch.is_floating_point(item): 196 | item.requires_grad_(True) 197 | x, t, cap_feats, cap_mask = inputs 198 | 199 | t = self.t_embedder(t) 200 | adaln_input = t 201 | cap_feats = self.cap_embedder(cap_feats) 202 | x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t) 203 | img_size = torch.tensor(img_size).to(x.device) 204 | cap_size = torch.tensor(cap_size).to(x.device) 205 | freqs_cis = freqs_cis.to(x.device) 206 | return make_contiguous(x, mask, freqs_cis, adaln_input, img_size, cap_size) 207 | 208 | def patchify_and_embed( 209 | self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor 210 | ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: 211 | bsz = len(x) 212 | pH = pW = self.patch_size 213 | device = x[0].device 214 | 215 | l_effective_cap_len = cap_mask.sum(dim=1).tolist() 216 | img_sizes = [(img.size(1), img.size(2)) for img in x] 217 | l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] 218 | 219 | first_img_len = l_effective_img_len[0] 220 | for img_len in l_effective_img_len: 221 | assert img_len == first_img_len 222 | # Modification from original code: don't allow seq_len to vary dynamically. Pipeline parallelism requires that the tensors 223 | # passed between layers always be the same size! 
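# Illustrative numbers (assuming the 8x-downscaling Flux VAE and the patch_size=2 implied by the
# _patch2 model variant): a 1024x1024 image gives (1024/8/2) * (1024/8/2) = 64 * 64 = 4096 image
# tokens, and captions are padded to the tokenizer's max_length of 256, so max_seq_len would be
# 4096 + 256 for every micro-batch.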
224 | max_seq_len = first_img_len + cap_mask.shape[-1] 225 | max_cap_len = max(l_effective_cap_len) 226 | max_img_len = max(l_effective_img_len) 227 | 228 | position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) 229 | 230 | for i in range(bsz): 231 | cap_len = l_effective_cap_len[i] 232 | img_len = l_effective_img_len[i] 233 | H, W = img_sizes[i] 234 | H_tokens, W_tokens = H // pH, W // pW 235 | assert H_tokens * W_tokens == img_len 236 | 237 | position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) 238 | position_ids[i, cap_len:cap_len+img_len, 0] = cap_len 239 | row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() 240 | col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() 241 | position_ids[i, cap_len:cap_len+img_len, 1] = row_ids 242 | position_ids[i, cap_len:cap_len+img_len, 2] = col_ids 243 | 244 | freqs_cis = self.rope_embedder(position_ids) 245 | 246 | # build freqs_cis for cap and image individually 247 | cap_freqs_cis_shape = list(freqs_cis.shape) 248 | # cap_freqs_cis_shape[1] = max_cap_len 249 | cap_freqs_cis_shape[1] = cap_feats.shape[1] 250 | cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) 251 | 252 | img_freqs_cis_shape = list(freqs_cis.shape) 253 | img_freqs_cis_shape[1] = max_img_len 254 | img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) 255 | 256 | for i in range(bsz): 257 | cap_len = l_effective_cap_len[i] 258 | img_len = l_effective_img_len[i] 259 | cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] 260 | img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] 261 | 262 | # refine context 263 | for layer in self.context_refiner: 264 | cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) 265 | 266 | # refine image 267 | flat_x = [] 268 | for i in range(bsz): 269 | img = x[i] 270 | C, H, W = img.size() 271 | img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) 272 | flat_x.append(img) 273 | x = flat_x 274 | padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) 275 | padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) 276 | for i in range(bsz): 277 | padded_img_embed[i, :l_effective_img_len[i]] = x[i] 278 | padded_img_mask[i, :l_effective_img_len[i]] = True 279 | 280 | padded_img_embed = self.x_embedder(padded_img_embed) 281 | for layer in self.noise_refiner: 282 | padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) 283 | 284 | mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) 285 | padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) 286 | for i in range(bsz): 287 | cap_len = l_effective_cap_len[i] 288 | img_len = l_effective_img_len[i] 289 | 290 | mask[i, :cap_len+img_len] = True 291 | padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] 292 | padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] 293 | 294 | return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis 295 | 296 | 297 | class TransformerLayer(nn.Module): 298 | def __init__(self, block): 299 | super().__init__() 300 | self.block = block 301 | 302 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 303 | def forward(self, inputs): 304 | x, mask, freqs_cis, adaln_input, img_size, cap_size = inputs 305 | x = self.block(x, 
mask, freqs_cis, adaln_input) 306 | return make_contiguous(x, mask, freqs_cis, adaln_input, img_size, cap_size) 307 | 308 | 309 | class FinalLayer(nn.Module): 310 | def __init__(self, model): 311 | super().__init__() 312 | self.final_layer = model.final_layer 313 | # norm_final isn't used, but by registering it we will keep it in the saved model, preventing ComfyUI from logging a 314 | # warning that it's missing. 315 | self.norm_final = model.norm_final 316 | self.model = [model] 317 | 318 | def __getattr__(self, name): 319 | return getattr(self.model[0], name) 320 | 321 | @torch.autocast('cuda', dtype=AUTOCAST_DTYPE) 322 | def forward(self, inputs): 323 | x, mask, freqs_cis, adaln_input, img_size, cap_size = inputs 324 | x = self.final_layer(x, adaln_input) 325 | img_size = [(row[0].item(), row[1].item()) for row in img_size] 326 | cap_size = [row.item() for row in cap_size] 327 | return self.unpatchify(x, img_size, cap_size, return_tensor=True) 328 | -------------------------------------------------------------------------------- /optimizers/adamw_8bit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import bitsandbytes 3 | import bitsandbytes.functional as F 4 | 5 | 6 | class AdamW8bitKahan(bitsandbytes.optim.AdamW8bit): 7 | def __init__(self, *args, stabilize=True, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.stabilize = stabilize 10 | 11 | @torch.no_grad() 12 | def init_state(self, group, p, gindex, pindex): 13 | super().init_state(group, p, gindex, pindex) 14 | self.state[p]['shift'] = self.get_state_buffer(p, dtype=p.dtype) 15 | 16 | @torch.no_grad() 17 | def update_step(self, group, p, gindex, pindex): 18 | # avoid update error from non-contiguous memory layout 19 | p.data = p.data.contiguous() 20 | p.grad = p.grad.contiguous() 21 | 22 | state = self.state[p] 23 | grad = p.grad 24 | 25 | config = self.get_config(gindex, pindex, group) 26 | 27 | state["step"] += 1 28 | step = state["step"] 29 | 30 | if config["percentile_clipping"] < 100: 31 | current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( 32 | grad, 33 | state["gnorm_vec"], 34 | step, 35 | config["percentile_clipping"], 36 | ) 37 | else: 38 | gnorm_scale = 1.0 39 | 40 | shift = state['shift'] 41 | 42 | # StableAdamW 43 | if self.stabilize: 44 | exp_avg_sq = state['state2'] 45 | eps_sq = torch.tensor(config['eps']**2, dtype=exp_avg_sq.dtype, device=exp_avg_sq.device) 46 | rms = grad.pow(2).div_(exp_avg_sq.maximum(eps_sq)).mean().sqrt() 47 | lr = config['lr'] / max(1, rms.item()) 48 | else: 49 | lr = config['lr'] 50 | 51 | if state["state1"].dtype == torch.float: 52 | F.optimizer_update_32bit( 53 | self.optimizer_name, 54 | grad, 55 | shift, 56 | state["state1"], 57 | config["betas"][0], 58 | config["eps"], 59 | step, 60 | lr, 61 | state["state2"], 62 | config["betas"][1], 63 | config["betas"][2] if len(config["betas"]) >= 3 else 0.0, 64 | config["alpha"], 65 | config["weight_decay"], 66 | gnorm_scale, 67 | state["unorm_vec"] if config["max_unorm"] > 0.0 else None, 68 | max_unorm=config["max_unorm"], 69 | skip_zeros=config["skip_zeros"], 70 | ) 71 | 72 | elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: 73 | F.optimizer_update_8bit( 74 | self.optimizer_name, 75 | grad, 76 | shift, 77 | state["state1"], 78 | state["state2"], 79 | config["betas"][0], 80 | config["betas"][1], 81 | config["eps"], 82 | step, 83 | lr, 84 | state["qmap1"], 85 | state["qmap2"], 86 | state["max1"], 87 | state["max2"], 88 | state["new_max1"], 89 | 
state["new_max2"], 90 | config["weight_decay"], 91 | gnorm_scale=gnorm_scale, 92 | unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, 93 | max_unorm=config["max_unorm"], 94 | ) 95 | 96 | # swap maxes 97 | state["max1"], state["new_max1"] = state["new_max1"], state["max1"] 98 | state["max2"], state["new_max2"] = state["new_max2"], state["max2"] 99 | elif state["state1"].dtype == torch.uint8 and config["block_wise"]: 100 | F.optimizer_update_8bit_blockwise( 101 | self.optimizer_name, 102 | grad, 103 | shift, 104 | state["state1"], 105 | state["state2"], 106 | config["betas"][0], 107 | config["betas"][1], 108 | config["betas"][2] if len(config["betas"]) >= 3 else 0.0, 109 | config["alpha"], 110 | config["eps"], 111 | step, 112 | lr, 113 | state["qmap1"], 114 | state["qmap2"], 115 | state["absmax1"], 116 | state["absmax2"], 117 | config["weight_decay"], 118 | gnorm_scale=gnorm_scale, 119 | skip_zeros=config["skip_zeros"], 120 | ) 121 | 122 | buffer = p.clone() 123 | p.add_(shift) 124 | shift.add_(buffer.sub_(p)) 125 | -------------------------------------------------------------------------------- /optimizers/automagic.py: -------------------------------------------------------------------------------- 1 | # Copied from AI Toolkit. 2 | # I added Kahan summation for bfloat16 parameters. 3 | 4 | # MIT License 5 | 6 | # Copyright (c) 2024 Ostris, LLC 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | 27 | from typing import List 28 | import torch 29 | from optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation 30 | from optimum.quanto import QBytesTensor 31 | import random 32 | 33 | 34 | class Automagic(torch.optim.Optimizer): 35 | def __init__( 36 | self, 37 | params, 38 | lr=1e-6, # lr is start lr 39 | min_lr=1e-7, 40 | max_lr=1e-3, 41 | lr_bump=1e-6, # amount to bump the lr when adjusting 42 | eps=(1e-30, 1e-3), 43 | clip_threshold=1.0, 44 | beta2=0.999, 45 | weight_decay=0.0, 46 | do_paramiter_swapping=False, 47 | paramiter_swapping_factor=0.1, 48 | ): 49 | self.lr = lr 50 | if self.lr > 1e-3: 51 | print(f"Warning! Start lr is very high: {self.lr}. Forcing to 1e-6. 
this does not work like prodigy") 52 | self.lr = 1e-6 53 | self.min_lr = min_lr 54 | self.max_lr = max_lr 55 | self.lr_bump = lr_bump 56 | 57 | defaults = { 58 | "lr": lr, 59 | "eps": eps, 60 | "clip_threshold": clip_threshold, 61 | "beta2": beta2, 62 | "weight_decay": weight_decay, 63 | } 64 | super().__init__(params, defaults) 65 | 66 | self.base_lrs: List[float] = [ 67 | lr for group in self.param_groups 68 | ] 69 | 70 | self.is_stochastic_rounding_accumulation = False 71 | 72 | # setup stochastic grad accum hooks 73 | for group in self.param_groups: 74 | for param in group['params']: 75 | if param.requires_grad and param.dtype != torch.float32: 76 | self.is_stochastic_rounding_accumulation = True 77 | param.register_post_accumulate_grad_hook( 78 | stochastic_grad_accummulation 79 | ) 80 | 81 | self.do_paramiter_swapping = do_paramiter_swapping 82 | self.paramiter_swapping_factor = paramiter_swapping_factor 83 | self._total_paramiter_size = 0 84 | # count total paramiters 85 | for group in self.param_groups: 86 | for param in group['params']: 87 | self._total_paramiter_size += torch.numel(param) 88 | # pretty print total paramiters with comma seperation 89 | print(f"Total training paramiters: {self._total_paramiter_size:,}") 90 | 91 | # needs to be enabled to count paramiters 92 | if self.do_paramiter_swapping: 93 | self.enable_paramiter_swapping(self.paramiter_swapping_factor) 94 | 95 | def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): 96 | self.do_paramiter_swapping = True 97 | self.paramiter_swapping_factor = paramiter_swapping_factor 98 | # call it an initial time 99 | self.swap_paramiters() 100 | 101 | def swap_paramiters(self): 102 | all_params = [] 103 | # deactivate all paramiters 104 | for group in self.param_groups: 105 | for param in group['params']: 106 | param.requires_grad_(False) 107 | # remove any grad 108 | param.grad = None 109 | all_params.append(param) 110 | # shuffle all paramiters 111 | random.shuffle(all_params) 112 | 113 | # keep activating paramiters until we are going to go over the target paramiters 114 | target_paramiters = int( 115 | self._total_paramiter_size * self.paramiter_swapping_factor) 116 | total_paramiters = 0 117 | for param in all_params: 118 | total_paramiters += torch.numel(param) 119 | if total_paramiters >= target_paramiters: 120 | break 121 | else: 122 | param.requires_grad_(True) 123 | 124 | @staticmethod 125 | def _get_lr(param_group, param_state): 126 | if 'avg_lr' in param_state: 127 | lr = param_state["avg_lr"] 128 | else: 129 | lr = 0.0 130 | return lr 131 | 132 | def _get_group_lr(self, group): 133 | group_lrs = [] 134 | for p in group["params"]: 135 | group_lrs.append(self._get_lr(group, self.state[p])) 136 | # return avg 137 | if len(group_lrs) == 0: 138 | return self.lr 139 | return sum(group_lrs) / len(group_lrs) 140 | 141 | @staticmethod 142 | def _rms(tensor): 143 | return tensor.norm(2) / (tensor.numel() ** 0.5) 144 | 145 | @staticmethod 146 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 147 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=- 148 | 1, keepdim=True)).rsqrt_().unsqueeze(-1) 149 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 150 | return torch.mul(r_factor, c_factor) 151 | 152 | def step_hook(self): 153 | if not self.is_stochastic_rounding_accumulation: 154 | return 155 | # copy over stochastically rounded grads 156 | for group in self.param_groups: 157 | for param in group['params']: 158 | if param.requires_grad and hasattr(param, "_accum_grad"): 159 | param.grad = 
param._accum_grad 160 | del param._accum_grad 161 | 162 | # automagic manages its own lr 163 | def get_learning_rates(self): 164 | 165 | lrs = [ 166 | self._get_group_lr(group) 167 | for group in self.param_groups 168 | ] 169 | if len(lrs) == 0: 170 | lrs = self.base_lrs # if called before stepping 171 | return lrs 172 | 173 | def get_avg_learning_rate(self): 174 | lrs = self.get_learning_rates() 175 | return sum(lrs) / len(lrs) 176 | 177 | @torch.no_grad() 178 | def step(self, closure=None): 179 | """ 180 | Performs a single optimization step 181 | 182 | Arguments: 183 | closure (callable, optional): A closure that reevaluates the model 184 | and returns the loss. 185 | """ 186 | self.step_hook() 187 | loss = None 188 | if closure is not None: 189 | loss = closure() 190 | 191 | for group in self.param_groups: 192 | for p in group["params"]: 193 | if p.grad is None or not p.requires_grad: 194 | continue 195 | 196 | grad = p.grad 197 | if grad.dtype != torch.float32: 198 | grad = grad.to(torch.float32) 199 | if grad.is_sparse: 200 | raise RuntimeError( 201 | "Automagic does not support sparse gradients.") 202 | 203 | state = self.state[p] 204 | grad_shape = grad.shape 205 | 206 | factored = len(grad_shape) >= 2 207 | # State Initialization 208 | if len(state) == 0: 209 | self.initialize_state(p) 210 | else: 211 | # Check if exp_avg_sq_row and exp_avg_sq_col exist for factored case 212 | if factored: 213 | if "exp_avg_sq_row" not in state or "exp_avg_sq_col" not in state: 214 | state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(grad) 215 | state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad) 216 | else: 217 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) 218 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) 219 | # Check if exp_avg_sq exists for non-factored case 220 | else: 221 | if "exp_avg_sq" not in state: 222 | state["exp_avg_sq"] = torch.zeros_like(grad) 223 | else: 224 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) 225 | 226 | p_data_fp32 = p 227 | 228 | if isinstance(p_data_fp32, QBytesTensor): 229 | p_data_fp32 = p_data_fp32.dequantize() 230 | if p.dtype != torch.float32: 231 | p_data_fp32 = p_data_fp32.clone().float() 232 | 233 | # Initialize step if it doesn't exist 234 | if "step" not in state: 235 | state["step"] = 0 236 | state["step"] += 1 237 | state["RMS"] = self._rms(p_data_fp32) 238 | 239 | # Use fixed beta2 from group instead of decay_rate calculation 240 | beta2 = group["beta2"] 241 | eps = group["eps"] 242 | if isinstance(eps, tuple) or isinstance(eps, list): 243 | eps = eps[0] 244 | update = (grad**2) + eps 245 | if factored: 246 | exp_avg_sq_row = state["exp_avg_sq_row"] 247 | exp_avg_sq_col = state["exp_avg_sq_col"] 248 | 249 | exp_avg_sq_row.mul_(beta2).add_( 250 | update.mean(dim=-1), alpha=(1.0 - beta2)) 251 | exp_avg_sq_col.mul_(beta2).add_( 252 | update.mean(dim=-2), alpha=(1.0 - beta2)) 253 | 254 | # Approximation of exponential moving average of square of gradient 255 | update = self._approx_sq_grad( 256 | exp_avg_sq_row, exp_avg_sq_col) 257 | update.mul_(grad) 258 | else: 259 | exp_avg_sq = state["exp_avg_sq"] 260 | 261 | exp_avg_sq.mul_(beta2).add_(update, alpha=(1.0 - beta2)) 262 | update = exp_avg_sq.rsqrt().mul_(grad) 263 | 264 | update.div_( 265 | (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) 266 | 267 | # Ensure state is properly initialized 268 | if 'last_polarity' not in state or 'lr_mask' not in state: 269 | self.initialize_state(p) 270 | 271 | # Get signs of current last 
update and updates 272 | last_polarity = state['last_polarity'] 273 | current_polarity = (update > 0).to(torch.bool) 274 | sign_agreement = torch.where( 275 | last_polarity == current_polarity, 1, -1) 276 | state['last_polarity'] = current_polarity 277 | 278 | lr_mask = state['lr_mask'].to(torch.float32) 279 | 280 | # Update learning rate mask based on sign agreement 281 | new_lr = torch.where( 282 | sign_agreement > 0, 283 | lr_mask + self.lr_bump, # Increase lr 284 | lr_mask - self.lr_bump # Decrease lr 285 | ) 286 | 287 | # Clip learning rates to bounds 288 | new_lr = torch.clamp( 289 | new_lr, 290 | min=self.min_lr, 291 | max=self.max_lr 292 | ) 293 | 294 | # Apply the learning rate mask to the update 295 | update.mul_(new_lr) 296 | 297 | state['lr_mask'] = Auto8bitTensor(new_lr) 298 | state['avg_lr'] = torch.mean(new_lr) 299 | 300 | if group["weight_decay"] != 0: 301 | # Apply weight decay with per-parameter learning rates 302 | # Instead of using add_ with a tensor alpha (which isn't supported), 303 | # we'll use element-wise multiplication to apply the weight decay 304 | weight_decay_update = p_data_fp32 * (-group["weight_decay"]) * new_lr 305 | p_data_fp32.add_(weight_decay_update) 306 | 307 | if p.dtype == torch.bfloat16: 308 | # Kahan summation for bfloat16 309 | update.mul_(-1) 310 | update.add_(weight_decay_update) 311 | shift = state['shift'] 312 | shift.add_(update) 313 | # Use grad as temp buffer 314 | grad.copy_(p.detach()) 315 | p.add_(shift) 316 | shift.add_(grad.sub_(p)) 317 | else: 318 | p_data_fp32.add_(-update) 319 | if p.dtype != torch.float32: 320 | # apply stochastic rounding 321 | copy_stochastic(p, p_data_fp32) 322 | 323 | return loss 324 | 325 | def initialize_state(self, p): 326 | state = self.state[p] 327 | state["step"] = 0 328 | 329 | # store the lr mask 330 | if 'lr_mask' not in state: 331 | state['lr_mask'] = Auto8bitTensor(torch.ones( 332 | p.shape).to(p.device, dtype=torch.float32) * self.lr 333 | ) 334 | state['avg_lr'] = torch.mean( 335 | state['lr_mask'].to(torch.float32)) 336 | if 'last_polarity' not in state: 337 | state['last_polarity'] = torch.zeros( 338 | p.shape, dtype=torch.bool, device=p.device) 339 | 340 | factored = len(p.shape) >= 2 341 | if factored: 342 | state["exp_avg_sq_row"] = torch.zeros( 343 | p.shape[:-1]).to(p) 344 | state["exp_avg_sq_col"] = torch.zeros( 345 | p.shape[:-2] + p.shape[-1:]).to(p) 346 | else: 347 | state["exp_avg_sq"] = torch.zeros_like(p) 348 | 349 | state["RMS"] = 0 350 | # For Kahan summation. 
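# Background on the bf16 Kahan summation used in step() above: bfloat16 keeps only ~8 bits of
# significand, so small updates can round away when added directly into the weights. The
# 'shift' buffer carries the part that does not fit:
#     shift += update
#     p_new  = p_old + shift
#     shift -= (p_new - p_old)   # retain whatever was lost to bf16 rounding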
351 | if p.dtype == torch.bfloat16: 352 | state['shift'] = torch.zeros_like(p) 353 | 354 | # override the state_dict to save the lr_mask 355 | def state_dict(self, *args, **kwargs): 356 | orig_state_dict = super().state_dict(*args, **kwargs) 357 | # convert the state to quantized tensor to scale and quantized 358 | new_sace_state = {} 359 | for p, state in orig_state_dict['state'].items(): 360 | save_state = {k: v for k, v in state.items() if k != 'lr_mask'} 361 | 362 | # Check if lr_mask exists in the state before trying to access it 363 | if 'lr_mask' in state: 364 | save_state['lr_mask'] = state['lr_mask'].state_dict() 365 | 366 | new_sace_state[p] = save_state 367 | 368 | orig_state_dict['state'] = new_sace_state 369 | 370 | return orig_state_dict 371 | 372 | def load_state_dict(self, state_dict, strict=True): 373 | # Validate that the state_dict is from an Automagic optimizer 374 | is_valid_automagic_state = False 375 | 376 | # Check if state_dict has the expected structure 377 | if 'state' in state_dict and isinstance(state_dict['state'], dict): 378 | # Check if at least one state entry has an lr_mask, which is specific to Automagic 379 | for param_id, param_state in state_dict['state'].items(): 380 | if isinstance(param_state, dict) and 'lr_mask' in param_state: 381 | is_valid_automagic_state = True 382 | break 383 | 384 | if not is_valid_automagic_state: 385 | return 386 | 387 | # First, call the parent class's load_state_dict to load the basic optimizer state 388 | # We'll handle the lr_mask separately 389 | state_dict_copy = { 390 | 'state': {}, 391 | 'param_groups': state_dict['param_groups'] 392 | } 393 | 394 | # Copy all state entries except lr_mask 395 | for param_id, param_state in state_dict['state'].items(): 396 | state_dict_copy['state'][param_id] = { 397 | k: v for k, v in param_state.items() if k != 'lr_mask' 398 | } 399 | 400 | # Call parent class load_state_dict with the modified state dict 401 | super().load_state_dict(state_dict_copy) 402 | 403 | # Now handle the lr_mask separately 404 | # We need to map the saved parameters to the current parameters 405 | # This is tricky because the parameter IDs might be different 406 | 407 | # Get all current parameters that require gradients 408 | current_params = [] 409 | for group in self.param_groups: 410 | for p in group['params']: 411 | if p.requires_grad: 412 | current_params.append(p) 413 | 414 | # If the number of parameters doesn't match, we can't reliably map them 415 | if len(current_params) != len(state_dict['param_groups'][0]['params']): 416 | print(f"WARNING: Number of parameters doesn't match between saved state ({len(state_dict['param_groups'][0]['params'])}) " 417 | f"and current model ({len(current_params)}). 
Learning rate masks may not be correctly loaded.") 418 | 419 | # Map parameters by their position in the param_groups 420 | # This assumes the order of parameters is preserved between saving and loading 421 | saved_param_ids = list(state_dict['state'].keys()) 422 | 423 | for i, current_param in enumerate(current_params): 424 | if i >= len(saved_param_ids): 425 | break 426 | 427 | saved_param_id = saved_param_ids[i] 428 | saved_state = state_dict['state'][saved_param_id] 429 | 430 | # Skip if this saved state doesn't have an lr_mask 431 | if 'lr_mask' not in saved_state: 432 | continue 433 | 434 | # Initialize the state for this parameter if it doesn't exist 435 | if current_param not in self.state: 436 | self.initialize_state(current_param) 437 | 438 | # Get the current state for this parameter 439 | current_state = self.state[current_param] 440 | 441 | # Load the lr_mask from the saved state 442 | saved_lr_mask = saved_state['lr_mask'] 443 | 444 | # Reconstruct the Auto8bitTensor from its state dict 445 | try: 446 | # Make sure the shapes match 447 | if 'quantized' in saved_lr_mask and saved_lr_mask['quantized'].shape == current_param.shape: 448 | saved_lr_mask['quantized'] = saved_lr_mask['quantized'].to(current_param.device) 449 | current_state['lr_mask'] = Auto8bitTensor(saved_lr_mask) 450 | else: 451 | print(f"WARNING: Shape mismatch for parameter {i}. " 452 | f"Expected {current_param.shape}, got {saved_lr_mask['quantized'].shape if 'quantized' in saved_lr_mask else 'unknown'}. " 453 | f"Initializing new lr_mask.") 454 | # Initialize a new lr_mask 455 | current_state['lr_mask'] = Auto8bitTensor(torch.ones( 456 | current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr 457 | ) 458 | except Exception as e: 459 | print(f"ERROR: Failed to load lr_mask for parameter {i}: {e}") 460 | # Initialize a new lr_mask 461 | current_state['lr_mask'] = Auto8bitTensor(torch.ones( 462 | current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr 463 | ) 464 | -------------------------------------------------------------------------------- /optimizers/gradient_release.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Simple wrapper for use with gradient release. Grad hooks do the optimizer steps, so this no-ops 4 | # the step() and zero_grad() methods. It also handles state_dict. 5 | class GradientReleaseOptimizerWrapper(torch.optim.Optimizer): 6 | def __init__(self, optimizers): 7 | self.optimizers = optimizers 8 | 9 | @property 10 | def param_groups(self): 11 | ret = [] 12 | for opt in self.optimizers: 13 | ret.extend(opt.param_groups) 14 | return ret 15 | 16 | def state_dict(self): 17 | return {i: opt.state_dict() for i, opt in enumerate(self.optimizers)} 18 | 19 | def load_state_dict(self, state_dict): 20 | for i, sd in state_dict.items(): 21 | self.optimizers[i].load_state_dict(sd) 22 | 23 | def step(self): 24 | pass 25 | 26 | def zero_grad(self, set_to_none=True): 27 | pass -------------------------------------------------------------------------------- /optimizers/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copied from AI Toolkit. 
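# Note: the helpers below back the Automagic optimizer above: copy_stochastic() writes fp32
# results into low-precision parameters via stochastic rounding, and Auto8bitTensor stores the
# per-parameter lr mask as int8 values plus a single floating-point scale.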
2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2024 Ostris, LLC 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | 26 | import torch 27 | from torch import Tensor 28 | from typing import Optional 29 | from optimum.quanto import QBytesTensor 30 | 31 | 32 | def compute_scale_for_dtype(tensor, dtype): 33 | """ 34 | Compute appropriate scale for the given tensor and target dtype. 35 | 36 | Args: 37 | tensor: Input tensor to be quantized 38 | dtype: Target dtype for quantization 39 | Returns: 40 | Appropriate scale factor for the quantization 41 | """ 42 | if dtype == torch.int8: 43 | abs_max = torch.max(torch.abs(tensor)) 44 | return abs_max / 127.0 if abs_max > 0 else 1.0 45 | elif dtype == torch.uint8: 46 | max_val = torch.max(tensor) 47 | min_val = torch.min(tensor) 48 | range_val = max_val - min_val 49 | return range_val / 255.0 if range_val > 0 else 1.0 50 | elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): 51 | # For float8, we typically want to preserve the magnitude of the values 52 | # while fitting within the representable range of the format 53 | abs_max = torch.max(torch.abs(tensor)) 54 | if dtype == torch.float8_e4m3fn: 55 | # e4m3fn has range [-448, 448] with no infinities 56 | max_representable = 448.0 57 | else: # torch.float8_e5m2 58 | # e5m2 has range [-57344, 57344] with infinities 59 | max_representable = 57344.0 60 | 61 | return abs_max / max_representable if abs_max > 0 else 1.0 62 | else: 63 | raise ValueError(f"Unsupported dtype for quantization: {dtype}") 64 | 65 | def quantize_tensor(tensor, dtype): 66 | """ 67 | Quantize a floating-point tensor to the target dtype with appropriate scaling. 
68 | 69 | Args: 70 | tensor: Input tensor (float) 71 | dtype: Target dtype for quantization 72 | Returns: 73 | quantized_data: Quantized tensor 74 | scale: Scale factor used 75 | """ 76 | scale = compute_scale_for_dtype(tensor, dtype) 77 | 78 | if dtype == torch.int8: 79 | quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype) 80 | elif dtype == torch.uint8: 81 | quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype) 82 | elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): 83 | # For float8, we scale and then cast directly to the target type 84 | # The casting operation will handle the appropriate rounding 85 | scaled_tensor = tensor / scale 86 | quantized_data = scaled_tensor.to(dtype) 87 | else: 88 | raise ValueError(f"Unsupported dtype for quantization: {dtype}") 89 | 90 | return quantized_data, scale 91 | 92 | 93 | def update_parameter(target, result_float): 94 | """ 95 | Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases 96 | with proper rescaling for quantized tensors. 97 | 98 | Args: 99 | target: The parameter to update (either torch.Tensor or QBytesTensor) 100 | result_float: The new values to assign (torch.Tensor) 101 | """ 102 | if isinstance(target, QBytesTensor): 103 | # Get the target dtype from the existing quantized tensor 104 | target_dtype = target._data.dtype 105 | 106 | # Handle device placement 107 | device = target._data.device 108 | result_float = result_float.to(device) 109 | 110 | # Compute new quantized values and scale 111 | quantized_data, new_scale = quantize_tensor(result_float, target_dtype) 112 | 113 | # Update the internal tensors with newly computed values 114 | target._data.copy_(quantized_data) 115 | target._scale.copy_(new_scale) 116 | else: 117 | # Regular tensor update 118 | target.copy_(result_float) 119 | 120 | 121 | def get_format_params(dtype: torch.dtype) -> tuple[int, int]: 122 | """ 123 | Returns (mantissa_bits, total_bits) for each format. 124 | mantissa_bits excludes the implicit leading 1. 125 | """ 126 | if dtype == torch.float32: 127 | return 23, 32 128 | elif dtype == torch.bfloat16: 129 | return 7, 16 130 | elif dtype == torch.float16: 131 | return 10, 16 132 | elif dtype == torch.float8_e4m3fn: 133 | return 3, 8 134 | elif dtype == torch.float8_e5m2: 135 | return 2, 8 136 | elif dtype == torch.int8: 137 | return 0, 8 # Int8 doesn't have mantissa bits 138 | else: 139 | raise ValueError(f"Unsupported dtype: {dtype}") 140 | 141 | 142 | def copy_stochastic( 143 | target: torch.Tensor, 144 | source: torch.Tensor, 145 | eps: Optional[float] = None 146 | ) -> None: 147 | """ 148 | Performs stochastic rounding from source tensor to target tensor. 
149 | 150 | Args: 151 | target: Destination tensor (determines the target format) 152 | source: Source tensor (typically float32) 153 | eps: Optional minimum value for stochastic rounding (for numerical stability) 154 | """ 155 | with torch.no_grad(): 156 | # If target is float32, just copy directly 157 | if target.dtype == torch.float32: 158 | target.copy_(source) 159 | return 160 | 161 | # Special handling for int8 162 | if target.dtype == torch.int8: 163 | # Scale the source values to utilize the full int8 range 164 | scaled = source * 127.0 # Scale to [-127, 127] 165 | 166 | # Add random noise for stochastic rounding 167 | noise = torch.rand_like(scaled) - 0.5 168 | rounded = torch.round(scaled + noise) 169 | 170 | # Clamp to int8 range 171 | clamped = torch.clamp(rounded, -127, 127) 172 | target.copy_(clamped.to(torch.int8)) 173 | return 174 | 175 | mantissa_bits, _ = get_format_params(target.dtype) 176 | 177 | # Convert source to int32 view 178 | source_int = source.view(dtype=torch.int32) 179 | 180 | # Calculate number of bits to round 181 | bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits 182 | 183 | # Create random integers for stochastic rounding 184 | rand = torch.randint_like( 185 | source, 186 | dtype=torch.int32, 187 | low=0, 188 | high=(1 << bits_to_round), 189 | ) 190 | 191 | # Add random values to the bits that will be rounded off 192 | result = source_int.clone() 193 | result.add_(rand) 194 | 195 | # Mask to keep only the bits we want 196 | # Create mask with 1s in positions we want to keep 197 | mask = (-1) << bits_to_round 198 | result.bitwise_and_(mask) 199 | 200 | # Handle minimum value threshold if specified 201 | if eps is not None: 202 | eps_int = torch.tensor( 203 | eps, dtype=torch.float32).view(dtype=torch.int32) 204 | zero_mask = (result.abs() < eps_int) 205 | result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int 206 | 207 | # Convert back to float32 view 208 | result_float = result.view(dtype=torch.float32) 209 | 210 | # Special handling for float8 formats 211 | if target.dtype == torch.float8_e4m3fn: 212 | result_float.clamp_(-448.0, 448.0) 213 | elif target.dtype == torch.float8_e5m2: 214 | result_float.clamp_(-57344.0, 57344.0) 215 | 216 | # Copy the result to the target tensor 217 | update_parameter(target, result_float) 218 | # target.copy_(result_float) 219 | del result, rand, source_int 220 | 221 | 222 | class Auto8bitTensor: 223 | def __init__(self, data: Tensor, *args, **kwargs): 224 | if isinstance(data, dict): # Add constructor from state dict 225 | self._load_from_state_dict(data) 226 | else: 227 | abs_max = data.abs().max().item() 228 | scale = abs_max / 127.0 if abs_max > 0 else 1.0 229 | 230 | self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8) 231 | self.scale = scale 232 | self.orig_dtype = data.dtype 233 | 234 | def dequantize(self) -> Tensor: 235 | return self.quantized.to(dtype=torch.float32) * self.scale 236 | 237 | def to(self, *args, **kwargs): 238 | # Handle the dtype argument whether it's positional or keyword 239 | dtype = None 240 | if args and isinstance(args[0], torch.dtype): 241 | dtype = args[0] 242 | args = args[1:] 243 | elif 'dtype' in kwargs: 244 | dtype = kwargs['dtype'] 245 | del kwargs['dtype'] 246 | 247 | if dtype is not None: 248 | # First dequantize then convert to requested dtype 249 | return self.dequantize().to(dtype=dtype, *args, **kwargs) 250 | 251 | # If no dtype specified, just pass through to parent 252 | return self.dequantize().to(*args, **kwargs) 253 | 254 | def 
state_dict(self): 255 | """Returns a dictionary containing the current state of the tensor.""" 256 | return { 257 | 'quantized': self.quantized, 258 | 'scale': self.scale, 259 | 'orig_dtype': self.orig_dtype 260 | } 261 | 262 | def _load_from_state_dict(self, state_dict): 263 | """Loads the tensor state from a state dictionary.""" 264 | self.quantized = state_dict['quantized'] 265 | self.scale = state_dict['scale'] 266 | self.orig_dtype = state_dict['orig_dtype'] 267 | 268 | def __str__(self): 269 | return f"Auto8bitTensor({self.dequantize()})" 270 | 271 | 272 | def stochastic_grad_accummulation(param): 273 | if hasattr(param, "_accum_grad"): 274 | grad_fp32 = param._accum_grad.clone().to(torch.float32) 275 | grad_fp32.add_(param.grad.to(torch.float32)) 276 | copy_stochastic(param._accum_grad, grad_fp32) 277 | del grad_fp32 278 | del param.grad 279 | else: 280 | param._accum_grad = param.grad.clone() 281 | del param.grad -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepspeed 2 | toml 3 | transformers 4 | diffusers>=0.32.1 5 | datasets 6 | pillow 7 | sentencepiece 8 | protobuf 9 | peft 10 | torch-optimi 11 | tensorboard 12 | tqdm 13 | safetensors 14 | bitsandbytes 15 | imageio[ffmpeg] 16 | av 17 | einops 18 | accelerate 19 | loguru 20 | flash-attn 21 | omegaconf 22 | iopath 23 | termcolor 24 | hydra-core 25 | easydict 26 | ftfy 27 | pytorch-optimizer 28 | wandb -------------------------------------------------------------------------------- /tools/cosmos_vae_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import argparse 4 | import os.path 5 | sys.path.insert(0, os.path.abspath('submodules/Cosmos')) 6 | 7 | import torch 8 | import torchvision 9 | from torch import nn 10 | import accelerate 11 | 12 | from models.base import PreprocessMediaFile 13 | from utils.common import load_state_dict 14 | from cosmos1.models.diffusion.inference.inference_utils import load_model_by_config 15 | from cosmos1.utils.lazy_config import instantiate as lazy_instantiate 16 | from cosmos1.models.autoregressive.tokenizer.modules import EncoderFactorized, DecoderFactorized, CausalConv3d 17 | 18 | torch.set_grad_enabled(False) 19 | 20 | 21 | OFFICIAL_VIDEO_VAE_PATH = '/data2/imagegen_models/cosmos/Cosmos-1.0-Tokenizer-CV8x8x8' 22 | COMFYUI_VIDEO_VAE_WEIGHTS = '/data2/imagegen_models/cosmos/cosmos_cv8x8x8_1.0.safetensors' 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--input', type=Path, required=True) 27 | 28 | args = parser.parse_args() 29 | assert args.input.is_file() 30 | 31 | 32 | class CausalContinuousVideoTokenizer(nn.Module): 33 | def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: 34 | super().__init__() 35 | self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") 36 | self.embedding_dim = embedding_dim 37 | self.sigma_data = 0.5 38 | self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) 39 | self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) 40 | 41 | self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) 42 | self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) 43 | 44 | latent_temporal_chunk = 16 45 | self.latent_mean = nn.Parameter(torch.zeros([self.embedding_dim * latent_temporal_chunk], 
dtype=torch.float32)) 46 | self.latent_std = nn.Parameter(torch.ones([self.embedding_dim * latent_temporal_chunk], dtype=torch.float32)) 47 | 48 | 49 | def encode(self, x): 50 | h = self.encoder(x) 51 | z = self.quant_conv(h) 52 | latent_ch = z.shape[1] 53 | latent_t = z.shape[2] 54 | dtype = z.dtype 55 | mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) 56 | std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) 57 | return ((z - mean) / std) * self.sigma_data 58 | 59 | def decode(self, z): 60 | in_dtype = z.dtype 61 | latent_ch = z.shape[1] 62 | latent_t = z.shape[2] 63 | mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) 64 | std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) 65 | z = z / self.sigma_data 66 | z = z * std + mean 67 | z = self.post_quant_conv(z) 68 | return self.decoder(z) 69 | 70 | 71 | def load_official_video_vae(): 72 | model = load_model_by_config( 73 | config_job_name='Cosmos_1_0_Diffusion_Text2World_7B', 74 | config_file='submodules/Cosmos/cosmos1/models/diffusion/config/config.py', 75 | ) 76 | vae = lazy_instantiate(model.config.tokenizer) 77 | vae.load_weights(OFFICIAL_VIDEO_VAE_PATH) 78 | vae.sigma_data = model.sigma_data 79 | return vae 80 | 81 | 82 | def load_custom_video_vae(): 83 | with accelerate.init_empty_weights(): 84 | vae = CausalContinuousVideoTokenizer( 85 | attn_resolutions=[32], 86 | channels=128, 87 | channels_mult=[2, 4, 4], 88 | dropout=0.0, 89 | in_channels=3, 90 | num_res_blocks=2, 91 | out_channels=3, 92 | resolution=1024, 93 | patch_size=4, 94 | patch_method="haar", 95 | z_channels=16, 96 | z_factor=1, 97 | num_groups=1, 98 | legacy_mode=False, 99 | spatial_compression=8, 100 | temporal_compression=8, 101 | embedding_dim=16, 102 | ) 103 | missing_keys, unexpected_keys = vae.load_state_dict(load_state_dict(COMFYUI_VIDEO_VAE_WEIGHTS), assign=True, strict=False) 104 | assert len(missing_keys) == 0 105 | vae.eval() 106 | return vae 107 | 108 | 109 | if __name__ == '__main__': 110 | vae = load_custom_video_vae().to('cuda') 111 | preprocessor = PreprocessMediaFile({}, support_video=True, framerate=24, round_height=8, round_width=8, round_frames=8) 112 | 113 | target_frames = 33 if args.input.suffix == '.mp4' else 1 114 | tensor = preprocessor(args.input, size_bucket=(720, 720, target_frames))[0].unsqueeze(0) 115 | 116 | p = next(vae.encoder.parameters()) 117 | device, dtype = p.device, p.dtype 118 | print(f'Input shape: {tensor.shape}') 119 | latents = vae.encode(tensor.to(device, dtype)) 120 | print(f'Latents shape: {latents.shape}') 121 | decoded = vae.decode(latents).to('cpu', torch.float32) 122 | print(f'Decoded shape: {decoded.shape}') 123 | 124 | decoded = decoded.squeeze(0) 125 | decoded = ((decoded + 1) / 2).clamp(0, 1) 126 | 127 | if decoded.shape[1] == 1: 128 | img = decoded.squeeze(1) 129 | pil_img = torchvision.transforms.functional.to_pil_image(img) 130 | output_path = args.input.with_stem(args.input.stem + '_decoded') 131 | pil_img.save(output_path) 132 | -------------------------------------------------------------------------------- /tools/hunyuan_video_vae_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import argparse 4 | import os.path 5 | 
sys.path.insert(0, os.path.abspath('submodules/HunyuanVideo')) 6 | 7 | import torch 8 | from PIL import Image 9 | import torchvision 10 | 11 | from utils.common import VIDEO_EXTENSIONS 12 | from hyvideo.vae import load_vae 13 | 14 | 15 | MODEL_BASE = Path('/home/anon/HunyuanVideo/ckpts') 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--input', type=Path, required=True) 19 | parser.add_argument('--output', type=Path, required=True) 20 | 21 | args = parser.parse_args() 22 | assert args.input.is_file() 23 | assert args.output.is_dir() 24 | 25 | 26 | def vae_encode(tensor, vae): 27 | # tensor values already in range [-1, 1] here 28 | latents = vae.encode(tensor).latent_dist.sample() 29 | return latents * vae.config.scaling_factor 30 | 31 | 32 | def vae_decode(latents, vae): 33 | # tensor values already in range [-1, 1] here 34 | latents = latents / vae.config.scaling_factor 35 | tensor = vae.decode(latents, return_dict=False)[0] 36 | return tensor 37 | 38 | 39 | if __name__ == '__main__': 40 | vae, _, s_ratio, t_ratio = load_vae( 41 | '884-16c-hy', 42 | 'bf16', 43 | vae_path=MODEL_BASE / 'hunyuan-video-t2v-720p/vae', 44 | device='cuda', 45 | ) 46 | 47 | if args.input.suffix in VIDEO_EXTENSIONS: 48 | raise NotImplementedError() 49 | else: 50 | pil_img = Image.open(args.input) 51 | video = torchvision.transforms.functional.to_tensor(pil_img).unsqueeze(1).unsqueeze(0) 52 | 53 | video = (video * 2) - 1 54 | 55 | latents = vae_encode(video.to(vae.device, vae.dtype), vae) 56 | video = vae_decode(latents, vae).to('cpu', torch.float32) 57 | 58 | video = ((video + 1) / 2).clamp(0, 1) 59 | 60 | if video.shape[2] == 1: 61 | img = video.squeeze(2).squeeze(0) 62 | pil_img = torchvision.transforms.functional.to_pil_image(img) 63 | pil_img.save(args.output / args.input.name) 64 | -------------------------------------------------------------------------------- /tools/image_resize_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from PIL import Image, ImageOps, ImageFilter 5 | from tqdm import tqdm 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input', type=Path, required=True) 10 | 11 | args = parser.parse_args() 12 | assert args.input.is_dir() 13 | 14 | 15 | def convert_crop_and_resize(pil_img, width_and_height): 16 | if pil_img.mode not in ['RGB', 'RGBA'] and 'transparency' in pil_img.info: 17 | pil_img = pil_img.convert('RGBA') 18 | 19 | # add white background for transparent images 20 | if pil_img.mode == 'RGBA': 21 | canvas = Image.new('RGBA', pil_img.size, (255, 255, 255)) 22 | canvas.alpha_composite(pil_img) 23 | pil_img = canvas.convert('RGB') 24 | else: 25 | pil_img = pil_img.convert('RGB') 26 | 27 | pil_img = pil_img.filter(ImageFilter.GaussianBlur(2)) 28 | return ImageOps.fit(pil_img, width_and_height) 29 | 30 | 31 | if __name__ == '__main__': 32 | for path in tqdm(list(args.input.glob('*'))): 33 | if path.suffix == '.txt': 34 | continue 35 | if '_' in path.stem: 36 | continue 37 | try: 38 | img = Image.open(path) 39 | except Exception: 40 | print(f'Image {path.name} could not be opened. 
Skipping.') 41 | continue 42 | scaled_image = convert_crop_and_resize(img, (512, 512)) 43 | output_path = path.with_stem(path.stem + '_scaled1').with_suffix('.png') 44 | scaled_image.save(output_path) 45 | -------------------------------------------------------------------------------- /tools/wan_vae_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import argparse 4 | import os.path 5 | sys.path.insert(0, os.path.abspath('submodules/Wan2_1')) 6 | 7 | import torch 8 | import torchvision 9 | 10 | from models.base import PreprocessMediaFile 11 | from wan import configs as wan_configs 12 | from wan.modules.vae import WanVAE 13 | 14 | CKPT_DIR = '/data2/imagegen_models/Wan2.1-T2V-1.3B' 15 | 16 | torch.set_grad_enabled(False) 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--input', type=Path, required=True) 20 | 21 | args = parser.parse_args() 22 | assert args.input.is_file() 23 | 24 | 25 | def vae_encode(tensor, vae): 26 | return vae.model.encode(tensor, vae.scale) 27 | 28 | 29 | def vae_decode(tensor, vae): 30 | return vae.model.decode(tensor, vae.scale) 31 | 32 | 33 | def write_image(decoded, name): 34 | assert decoded.ndim == 5 and decoded.shape[2] == 1, decoded.shape 35 | decoded = decoded.squeeze(0) 36 | decoded = ((decoded + 1) / 2).clamp(0, 1) 37 | 38 | img = decoded.squeeze(1) 39 | pil_img = torchvision.transforms.functional.to_pil_image(img) 40 | output_path = args.input.with_name(name + '.jpg') 41 | pil_img.save(output_path) 42 | 43 | 44 | if __name__ == '__main__': 45 | wan_config = wan_configs.t2v_1_3B 46 | vae = WanVAE( 47 | vae_pth=os.path.join(CKPT_DIR, wan_config.vae_checkpoint), 48 | device='cuda', 49 | ) 50 | 51 | preprocessor = PreprocessMediaFile({}, support_video=True, framerate=16, round_height=8, round_width=8, round_frames=8) 52 | 53 | target_frames = 33 if args.input.suffix == '.mp4' else 1 54 | tensor = preprocessor(args.input, None, size_bucket=(624, 624, target_frames))[0][0].unsqueeze(0) 55 | 56 | p = next(vae.model.parameters()) 57 | device, dtype = p.device, p.dtype 58 | 59 | print(f'Input shape: {tensor.shape}') 60 | latents = vae_encode(tensor.to(device, dtype), vae) 61 | print(f'Latents shape: {latents.shape}') 62 | first_frame = latents[:, :, 0:1, ...] 63 | decoded = vae_decode(first_frame, vae) 64 | print(f'Decoded shape: {decoded.shape}') 65 | write_image(decoded, args.input.stem + '_decoded') 66 | 67 | tensor[:, :, 1:, ...] = 0 68 | latents = vae_encode(tensor.to(device, dtype), vae) 69 | print(latents[:, :, -1, :, :]) 70 | first_frame = latents[:, :, 0:1, ...] 
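# In this second pass, all frames after the first were zeroed before encoding, so if the Wan VAE
# is temporally causal the first latent frame (and the decode below) should match the first pass;
# the printed last latent frame shows how the zeroed tail gets encoded.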
71 | decoded = vae_decode(first_frame, vae) 72 | write_image(decoded, args.input.stem + '_decoded2') 73 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import gc 3 | import time 4 | 5 | import torch 6 | import deepspeed.comm.comm as dist 7 | import imageio 8 | from safetensors import safe_open 9 | 10 | 11 | DTYPE_MAP = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float8': torch.float8_e4m3fn} 12 | VIDEO_EXTENSIONS = set(x.extension for x in imageio.config.video_extensions) 13 | AUTOCAST_DTYPE = None 14 | 15 | 16 | def get_rank(): 17 | return dist.get_rank() 18 | 19 | 20 | def is_main_process(): 21 | return get_rank() == 0 22 | 23 | 24 | @contextmanager 25 | def zero_first(): 26 | if not is_main_process(): 27 | dist.barrier() 28 | yield 29 | if is_main_process(): 30 | dist.barrier() 31 | 32 | 33 | def empty_cuda_cache(): 34 | gc.collect() 35 | torch.cuda.empty_cache() 36 | 37 | 38 | @contextmanager 39 | def log_duration(name): 40 | start = time.time() 41 | try: 42 | yield 43 | finally: 44 | print(f'{name}: {time.time()-start:.3f}') 45 | 46 | 47 | def load_safetensors(path): 48 | tensors = {} 49 | with safe_open(path, framework="pt", device="cpu") as f: 50 | for key in f.keys(): 51 | tensors[key] = f.get_tensor(key) 52 | return tensors 53 | 54 | 55 | def load_state_dict(path): 56 | path = str(path) 57 | if path.endswith('.safetensors'): 58 | return load_safetensors(path) 59 | else: 60 | return torch.load(path, weights_only=True) 61 | 62 | 63 | def round_to_nearest_multiple(x, multiple): 64 | return int(round(x / multiple) * multiple) 65 | 66 | 67 | def round_down_to_multiple(x, multiple): 68 | return int((x // multiple) * multiple) 69 | -------------------------------------------------------------------------------- /utils/isolate_rng.py: -------------------------------------------------------------------------------- 1 | # copy/pasted from pytorch lightning 2 | # https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py 3 | # and 4 | # https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py 5 | 6 | # Copyright The Lightning team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from contextlib import contextmanager 21 | from typing import Generator, Dict, Any 22 | 23 | import torch 24 | import numpy as np 25 | from random import getstate as python_get_rng_state 26 | from random import setstate as python_set_rng_state 27 | 28 | 29 | def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: 30 | """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" 31 | states = { 32 | "torch": torch.get_rng_state(), 33 | "numpy": np.random.get_state(), 34 | "python": python_get_rng_state(), 35 | } 36 | if include_cuda: 37 | try: 38 | states["torch.cuda"] = torch.cuda.get_rng_state_all() 39 | except RuntimeError: 40 | # CUDA initialization failure. 41 | pass 42 | return states 43 | 44 | 45 | def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: 46 | """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current 47 | process.""" 48 | torch.set_rng_state(rng_state_dict["torch"]) 49 | # torch.cuda rng_state is only included since v1.8. 50 | if "torch.cuda" in rng_state_dict: 51 | torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) 52 | np.random.set_state(rng_state_dict["numpy"]) 53 | version, state, gauss = rng_state_dict["python"] 54 | python_set_rng_state((version, tuple(state), gauss)) 55 | 56 | 57 | @contextmanager 58 | def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: 59 | """A context manager that resets the global random state on exit to what it was before entering. 60 | It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators. 61 | Args: 62 | include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator. 63 | Set this to ``False`` when using the function in a forked process where CUDA re-initialization is 64 | prohibited. 65 | Example: 66 | >>> import torch 67 | >>> torch.manual_seed(1) # doctest: +ELLIPSIS 68 | 69 | >>> with isolate_rng(): 70 | ... [torch.rand(1) for _ in range(3)] 71 | [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])] 72 | >>> torch.rand(1) 73 | tensor([0.7576]) 74 | """ 75 | states = _collect_rng_states(include_cuda) 76 | yield 77 | _set_rng_states(states) 78 | -------------------------------------------------------------------------------- /utils/offloading.py: -------------------------------------------------------------------------------- 1 | # Copied from Musubi Tuner with modifications: 2 | # https://github.com/kohya-ss/musubi-tuner/blob/main/modules/custom_offloading_utils.py 3 | 4 | # NOTE: this is modified to work with LoRA training only, and checks for 'lora' in the parameter 5 | # names when doing some of the swapping. It does this because for the optimizer step, all the trained 6 | # params need to be on GPU. Musubi Tuner and sd-scripts don't have this problem because the LoRA modules 7 | # are completely separate in those projects, while here we use PEFT which replaces the linear layers with 8 | # LoRA modules, and therefore when moving parts of the model to/from the GPU we have to take special consideration 9 | # of the LoRA params which are what is being trained. 10 | 11 | from concurrent.futures import ThreadPoolExecutor 12 | import gc 13 | import time 14 | from typing import Optional 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | def clean_memory_on_device(device: torch.device): 20 | r""" 21 | Clean memory on the specified device, will be called from training scripts. 
22 | """ 23 | gc.collect() 24 | 25 | # device may "cuda" or "cuda:0", so we need to check the type of device 26 | if device.type == "cuda": 27 | torch.cuda.empty_cache() 28 | if device.type == "xpu": 29 | torch.xpu.empty_cache() 30 | if device.type == "mps": 31 | torch.mps.empty_cache() 32 | 33 | 34 | def synchronize_device(device: torch.device): 35 | if device.type == "cuda": 36 | torch.cuda.synchronize() 37 | elif device.type == "xpu": 38 | torch.xpu.synchronize() 39 | elif device.type == "mps": 40 | torch.mps.synchronize() 41 | 42 | 43 | def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): 44 | assert layer_to_cpu.__class__ == layer_to_cuda.__class__ 45 | 46 | weight_swap_jobs = [] 47 | 48 | # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules 49 | # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): 50 | # print(module_to_cpu.__class__, module_to_cuda.__class__) 51 | # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: 52 | # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) 53 | 54 | modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} 55 | for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): 56 | if 'lora' in module_to_cuda_name: 57 | continue 58 | if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: 59 | module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) 60 | if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: 61 | weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) 62 | else: 63 | if module_to_cuda.weight.data.device.type != device.type: 64 | # print( 65 | # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" 66 | # ) 67 | module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) 68 | 69 | torch.cuda.current_stream().synchronize() # this prevents the illegal loss value 70 | 71 | stream = torch.cuda.Stream() 72 | with torch.cuda.stream(stream): 73 | # cuda to cpu 74 | for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: 75 | cuda_data_view.record_stream(stream) 76 | module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) 77 | 78 | stream.synchronize() 79 | 80 | # cpu to cuda 81 | for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: 82 | cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) 83 | module_to_cuda.weight.data = cuda_data_view 84 | 85 | stream.synchronize() 86 | torch.cuda.current_stream().synchronize() # this prevents the illegal loss value 87 | 88 | 89 | def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): 90 | """ 91 | not tested 92 | """ 93 | assert layer_to_cpu.__class__ == layer_to_cuda.__class__ 94 | 95 | weight_swap_jobs = [] 96 | for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): 97 | if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: 98 | weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) 99 | 100 | # device to cpu 101 | for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: 102 | module_to_cpu.weight.data = 
cuda_data_view.data.to("cpu", non_blocking=True) 103 | 104 | synchronize_device() 105 | 106 | # cpu to device 107 | for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: 108 | cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) 109 | module_to_cuda.weight.data = cuda_data_view 110 | 111 | synchronize_device() 112 | 113 | 114 | def weights_to_device(layer: nn.Module, device: torch.device): 115 | for name, module in layer.named_modules(): 116 | if device.type == 'cpu' and 'lora' in name: 117 | continue 118 | if hasattr(module, "weight") and module.weight is not None: 119 | module.weight.data = module.weight.data.to(device, non_blocking=True) 120 | 121 | 122 | class Offloader: 123 | """ 124 | common offloading class 125 | """ 126 | 127 | def __init__(self, block_type: str, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): 128 | self.block_type = block_type 129 | self.blocks = blocks 130 | self.num_blocks = num_blocks 131 | self.blocks_to_swap = blocks_to_swap 132 | self.blocks_to_swap_tmp = None 133 | self.device = device 134 | self.debug = debug 135 | 136 | self.thread_pool = ThreadPoolExecutor(max_workers=1) 137 | self.futures = {} 138 | self.cuda_available = device.type == "cuda" 139 | 140 | def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): 141 | if self.cuda_available: 142 | swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) 143 | else: 144 | swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) 145 | 146 | def _submit_move_blocks(self, block_idx_to_cpu, block_idx_to_cuda): 147 | def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): 148 | if self.debug: 149 | start_time = time.perf_counter() 150 | print( 151 | f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}" 152 | ) 153 | 154 | self.swap_weight_devices(block_to_cpu, block_to_cuda) 155 | 156 | if self.debug: 157 | print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") 158 | return bidx_to_cpu, bidx_to_cuda # , event 159 | 160 | block_to_cpu = self.blocks[block_idx_to_cpu] 161 | block_to_cuda = self.blocks[block_idx_to_cuda] 162 | 163 | self.futures[block_idx_to_cuda] = self.thread_pool.submit( 164 | move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda 165 | ) 166 | 167 | def _wait_blocks_move(self, block_idx): 168 | if block_idx not in self.futures: 169 | return 170 | 171 | if self.debug: 172 | print(f"[{self.block_type}] Wait for block {block_idx}") 173 | start_time = time.perf_counter() 174 | 175 | future = self.futures.pop(block_idx) 176 | _, bidx_to_cuda = future.result() 177 | 178 | assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" 179 | 180 | if self.debug: 181 | print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") 182 | 183 | 184 | class ModelOffloader(Offloader): 185 | """ 186 | supports forward offloading 187 | """ 188 | 189 | def __init__( 190 | self, 191 | block_type: str, 192 | blocks: list[nn.Module], 193 | num_blocks: int, 194 | blocks_to_swap: int, 195 | supports_backward: bool, 196 | device: torch.device, 197 | reentrant_activation_checkpointing: bool, 198 | debug: bool = False, 199 | ): 200 | super().__init__(block_type, blocks, num_blocks, blocks_to_swap, device, debug) 201 | 202 | 
self.supports_backward = supports_backward 203 | self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference 204 | self.reentrant_activation_checkpointing = reentrant_activation_checkpointing 205 | 206 | if self.supports_backward: 207 | # register backward hooks 208 | self.remove_handles = [] 209 | for i, block in enumerate(blocks): 210 | hook = self.create_backward_hook(i) 211 | if hook is not None: 212 | handle = block.register_full_backward_hook(hook) 213 | self.remove_handles.append(handle) 214 | 215 | def disable_block_swap(self): 216 | self.blocks_to_swap_tmp = self.blocks_to_swap 217 | self.blocks_to_swap = None 218 | 219 | def enable_block_swap(self): 220 | if self.blocks_to_swap_tmp is not None: 221 | self.blocks_to_swap = self.blocks_to_swap_tmp 222 | 223 | def set_forward_only(self, forward_only: bool): 224 | self.forward_only = forward_only 225 | 226 | def __del__(self): 227 | if self.supports_backward: 228 | for handle in self.remove_handles: 229 | handle.remove() 230 | 231 | def create_backward_hook(self, block_index: int) -> Optional[callable]: 232 | # -1 for 0-based index 233 | num_blocks_propagated = self.num_blocks - block_index - 1 234 | swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap 235 | waiting = block_index > 0 and block_index <= self.blocks_to_swap 236 | 237 | if not swapping and not waiting: 238 | return None 239 | 240 | # create hook 241 | block_idx_to_cpu = self.num_blocks - num_blocks_propagated 242 | block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated 243 | block_idx_to_wait = block_index - 1 244 | 245 | def backward_hook(module, grad_input, grad_output): 246 | if self.debug: 247 | print(f"Backward hook for block {block_index}") 248 | 249 | if swapping: 250 | self._submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) 251 | if waiting: 252 | self._wait_blocks_move(block_idx_to_wait) 253 | return None 254 | 255 | return backward_hook 256 | 257 | def prepare_block_devices_before_forward(self): 258 | if self.blocks_to_swap is None or self.blocks_to_swap == 0: 259 | for block in self.blocks: 260 | block.to(self.device) 261 | return 262 | 263 | if self.debug: 264 | print(f"[{self.block_type}] Prepare block devices before forward") 265 | 266 | for b in self.blocks[0 : self.num_blocks - self.blocks_to_swap]: 267 | b.to(self.device) 268 | weights_to_device(b, self.device) # make sure weights are on device 269 | 270 | for b in self.blocks[self.num_blocks - self.blocks_to_swap :]: 271 | b.to(self.device) # move block to device first 272 | weights_to_device(b, torch.device('cpu')) # make sure weights are on cpu 273 | 274 | synchronize_device(self.device) 275 | clean_memory_on_device(self.device) 276 | 277 | def wait_for_block(self, block_idx: int): 278 | if self.blocks_to_swap is None or self.blocks_to_swap == 0: 279 | return 280 | if self.reentrant_activation_checkpointing and torch.is_grad_enabled(): 281 | # Second forward pass, don't do block swapping 282 | return 283 | self._wait_blocks_move(block_idx) 284 | 285 | def submit_move_blocks_forward(self, block_idx: int): 286 | # check if blocks_to_swap is enabled 287 | if self.blocks_to_swap is None or self.blocks_to_swap == 0: 288 | return 289 | 290 | if self.reentrant_activation_checkpointing and torch.is_grad_enabled(): 291 | # Second forward pass, don't do block swapping 292 | return 293 | 294 | # if supports_backward and backward is enabled, we swap blocks more than blocks_to_swap in backward pass 295 | if not 
self.forward_only and block_idx >= self.blocks_to_swap: 296 | return 297 | 298 | block_idx_to_cpu = block_idx 299 | block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx 300 | block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading 301 | self._submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) -------------------------------------------------------------------------------- /utils/patches.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import sys 3 | import os.path 4 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '../submodules/HunyuanVideo')) 5 | 6 | import torch 7 | from torch import nn 8 | import peft 9 | from peft.tuners._buffer_dict import BufferDict 10 | from transformers import CLIPTextModel, AutoModel 11 | import deepspeed 12 | from deepspeed.runtime.pipe.schedule import ( 13 | SendGrad, RecvActivation, SendActivation, RecvGrad, LoadMicroBatch, ForwardPass, BackwardPass, 14 | ReduceTiedGrads, ReduceGrads, OptimizerStep, 15 | ) 16 | from deepspeed import comm as dist 17 | from deepspeed.utils import groups 18 | try: 19 | from torch._six import inf 20 | except ModuleNotFoundError: 21 | from torch import inf 22 | from deepspeed.accelerator import get_accelerator 23 | 24 | import hyvideo.text_encoder 25 | from hyvideo.constants import PRECISION_TO_TYPE, TEXT_ENCODER_PATH 26 | 27 | 28 | def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None: 29 | """ 30 | Move the adapter of the given name to the device of the base layer. 31 | """ 32 | if device is None: 33 | # check weight and qweight (for GPTQ) 34 | for weight_name in ("weight", "qweight"): 35 | weight = getattr(self.get_base_layer(), weight_name, None) 36 | if weight is not None: 37 | device = weight.device 38 | dtype = weight.dtype 39 | break 40 | else: 41 | # no break encountered: could not determine the device 42 | return 43 | 44 | meta = torch.device("meta") 45 | 46 | # loop through all potential adapter layers and move them to the device of the base layer; be careful to only 47 | # move this specific adapter to the device, as the other adapters could be on different devices 48 | # see #1639 49 | for adapter_layer_name in self.adapter_layer_names + self.other_param_names: 50 | adapter_layer = getattr(self, adapter_layer_name, None) 51 | if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)): 52 | continue 53 | if adapter_name not in adapter_layer: 54 | continue 55 | if any(p.device == meta for p in adapter_layer.parameters()): 56 | continue 57 | 58 | if ((weight.dtype.is_floating_point or weight.dtype.is_complex) 59 | # This is the part I added. 
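            # Skipping the cast when the base weight is fp8 keeps the LoRA adapter in a
            # trainable higher-precision dtype instead of having PEFT downcast it to fp8
            # (see the note in apply_patches below).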
60 | and not (weight.dtype == torch.float8_e4m3fn or weight.dtype == torch.float8_e5m2)): 61 | adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype) 62 | else: 63 | adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device) 64 | 65 | 66 | def load_text_encoder( 67 | text_encoder_type, 68 | text_encoder_precision=None, 69 | text_encoder_path=None, 70 | logger=None, 71 | device=None, 72 | ): 73 | if text_encoder_path is None: 74 | text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type] 75 | if logger is not None: 76 | logger.info( 77 | f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}" 78 | ) 79 | 80 | torch_dtype = 'auto' 81 | if text_encoder_precision is not None: 82 | torch_dtype = PRECISION_TO_TYPE[text_encoder_precision] 83 | 84 | if text_encoder_type == "clipL": 85 | text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype) 86 | text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm 87 | elif text_encoder_type == "llm": 88 | text_encoder = AutoModel.from_pretrained( 89 | text_encoder_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype 90 | ) 91 | text_encoder.final_layer_norm = text_encoder.norm 92 | else: 93 | raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") 94 | # from_pretrained will ensure that the model is in eval mode. 95 | 96 | text_encoder.requires_grad_(False) 97 | 98 | if logger is not None: 99 | logger.info(f"Text encoder to dtype: {text_encoder.dtype}") 100 | 101 | if device is not None: 102 | text_encoder = text_encoder.to(device) 103 | 104 | return text_encoder, text_encoder_path 105 | 106 | 107 | def train_schedule_steps(self): 108 | prev_micro_batch_id = -1 109 | total_steps = 2 * (self.micro_batches + self.stages - 1) 110 | for step_id in range(total_steps): 111 | # Map the step of the pipeline to the micro-batch id and also whether it is a 112 | # forward or backward pass step. 
113 | micro_batch_id, is_forward = self._step_to_micro_batch(step_id) 114 | 115 | if self._valid_micro_batch(prev_micro_batch_id): 116 | prev_buffer = self._buffer_idx(prev_micro_batch_id) 117 | if self._valid_micro_batch(micro_batch_id): 118 | curr_buffer = self._buffer_idx(micro_batch_id) 119 | 120 | cmds = [] 121 | 122 | # First/last stage loads 123 | if self.stage_id == 0 or self.stage_id == self.stages - 1: 124 | if is_forward and self._valid_micro_batch(micro_batch_id): 125 | cmds.append(LoadMicroBatch(curr_buffer)) 126 | 127 | # Exchange activations 128 | if is_forward: 129 | if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.prev_stage): 130 | cmds.append(SendGrad(prev_buffer)) 131 | if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.prev_stage): 132 | cmds.append(RecvActivation(curr_buffer)) 133 | else: 134 | if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.next_stage): 135 | cmds.append(RecvGrad(curr_buffer)) 136 | if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.next_stage): 137 | cmds.append(SendActivation(prev_buffer)) 138 | 139 | # Computation 140 | if self._valid_micro_batch(micro_batch_id): 141 | if is_forward: 142 | cmds.append(ForwardPass(curr_buffer)) 143 | else: 144 | cmds.append(BackwardPass(curr_buffer)) 145 | 146 | # Model step at the end of the batch 147 | if step_id == total_steps - 1: 148 | cmds.append(ReduceTiedGrads()) 149 | cmds.append(ReduceGrads()) 150 | cmds.append(OptimizerStep()) 151 | 152 | # Prepare state for next time 153 | prev_micro_batch_id = micro_batch_id 154 | yield cmds 155 | 156 | 157 | def broadcast_model(self): 158 | for n, p in self.module.named_parameters(): 159 | if torch.is_tensor(p) and p.requires_grad: 160 | orig_device = p.device 161 | move_to_gpu = (orig_device != self.device) 162 | if move_to_gpu: 163 | p.data = p.data.to(self.device) 164 | dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) 165 | if move_to_gpu: 166 | p.data = p.data.to(orig_device) 167 | 168 | 169 | def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): 170 | """Clips gradient norm of an iterable of parameters. 171 | 172 | This has been adapted from Nvidia megatron. We add norm averaging 173 | to consider MoE params when calculating norm as they will result 174 | in different norms across different ranks. 175 | 176 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and 177 | added functionality to handle model parallel parameters. Note that 178 | the gradients are modified in place. 179 | 180 | Arguments: 181 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 182 | single Tensor that will have gradients normalized 183 | max_norm (float or int): max norm of the gradients 184 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 185 | infinity norm. 186 | 187 | Returns: 188 | Total norm of the parameters (viewed as a single vector). 189 | """ 190 | if isinstance(parameters, torch.Tensor): 191 | parameters = [parameters] 192 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 193 | norm_type = float(norm_type) 194 | all_norms = [] 195 | if norm_type == inf: 196 | for p in parameters: 197 | all_norms.append(p.grad.data.abs().max().float()) 198 | total_norm = torch.stack(all_norms).max() 199 | total_norm = total_norm.to(get_accelerator().current_device_name()) 200 | # Take max across all GPUs. 
201 | if mpu is not None: 202 | dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) 203 | else: 204 | total_norm = 0 205 | for p in parameters: 206 | if mpu is not None: 207 | if (mpu.get_model_parallel_rank() == 0) or deepspeed.runtime.utils.is_model_parallel_parameter(p): 208 | param_norm = p.grad.data.detach().float().norm(norm_type) 209 | all_norms.append(param_norm) 210 | else: 211 | param_norm = p.grad.data.detach().float().norm(norm_type) 212 | all_norms.append(param_norm) 213 | if len(all_norms) > 0: 214 | total_norm = torch.stack(all_norms).square().sum().float() 215 | else: 216 | total_norm = get_accelerator().FloatTensor([0.0]) 217 | total_norm = total_norm.to(get_accelerator().current_device_name()) 218 | # Sum across all model parallel GPUs. 219 | if mpu is not None: 220 | dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) 221 | total_norm = total_norm.pow(1. / norm_type) 222 | 223 | # Need to average total_norm across different GPUs due to the presence of moe params 224 | pg = groups._get_data_parallel_group() 225 | scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg)) 226 | scaled_norm_tensor = scaled_norm 227 | 228 | dist.all_reduce(scaled_norm_tensor, group=pg) 229 | total_norm = scaled_norm_tensor 230 | # Change this from the original Deepspeed code. 231 | if len(parameters) > 0: 232 | total_norm = total_norm.to(parameters[0].device) 233 | 234 | max_norm = torch.tensor([float(max_norm)], device=total_norm.device) 235 | clip_coef = max_norm / (total_norm + 1e-6) 236 | tmp_tensor = torch.tensor([1.0], device=clip_coef.device) 237 | clip_coef = torch.min(tmp_tensor, clip_coef) 238 | for p in parameters: 239 | p.grad.data.mul_(clip_coef) 240 | return total_norm 241 | 242 | 243 | def apply_patches(): 244 | # Prevent PEFT from downcasting LoRA weights to fp8 only for this script to upcast them again. 245 | # TODO: probably should send a PR to PEFT. Default behavior looks like a mistake to me. 246 | peft.tuners.tuners_utils.BaseTunerLayer._move_adapter_to_device_of_base_layer = _move_adapter_to_device_of_base_layer 247 | 248 | # Use torch_dtype to avoid needlessly loading the text encoder in float32, only to cast it right after. 249 | hyvideo.text_encoder.load_text_encoder = load_text_encoder 250 | 251 | # LoadMicroBatch before sending / receiving activations so we can avoid a deadlock and broadcast the target 252 | # from the first stage to the last stage. InferenceSchedule already has the commands in the right order 253 | # and doesn't need this. 254 | deepspeed.runtime.pipe.schedule.TrainSchedule.steps = train_schedule_steps 255 | 256 | # This does two things: 257 | # 1. For block swapping, some parameters will be on CPU when the DeepSpeedEngine is constructed. So we patch this to 258 | # first move those parameters to GPU, then back again when broadcasting the model weights from rank 0. 259 | # 2. We skip broadcasting for parameters that don't require grad. These weights are static and always the same because 260 | # they were loaded from disk, so we can safely skip broadcasting and it's faster. 261 | deepspeed.runtime.engine.DeepSpeedEngine._broadcast_model = broadcast_model 262 | 263 | # Don't fail if there are no trainable parameters on a stage. 
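    # The patched clip_grad_norm_ above returns a zero norm instead of raising when a
    # stage has no gradients, which can happen when a pipeline stage holds only frozen
    # (non-LoRA) layers.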
264 | deepspeed.runtime.engine.DeepSpeedEngine.clip_fp32_gradients = lambda self: clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu) 265 | -------------------------------------------------------------------------------- /utils/pipeline.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from deepspeed.pipe import PipelineModule 4 | from deepspeed.runtime.pipe import LayerSpec 5 | 6 | 7 | # PipelineModule partition_method doesn't support uneven partitioning 8 | # This allow for loading more layers into selected GPU 9 | # For example if you have 2 gpus - one with 16GB and other with 24GB normal partitioning would throw OOM 10 | # With this implementation you can set partition_split in config so that less layers is loaded onto 16GB GPU 11 | class ManualPipelineModule(PipelineModule): 12 | def __init__(self, *args, manual_partition_split=None, **kwargs): 13 | self.manual_partition_split = manual_partition_split 14 | super().__init__(*args, **kwargs) 15 | 16 | def _partition_layers(self, method='uniform'): 17 | if method.lower() == 'manual' and self.manual_partition_split is not None: 18 | num_stages = self._topo.get_dim('pipe') 19 | stage_id = self._topo.get_coord(self.global_rank).pipe 20 | num_partitions = len(self.manual_partition_split) 21 | assert num_partitions == num_stages - 1, f'partition_split must be length {num_stages-1} (pipeline_stages-1), was actually {num_partitions}' 22 | 23 | total_layers = len(self._layer_specs) 24 | boundaries = [0] + self.manual_partition_split + [total_layers] 25 | self.parts = boundaries 26 | 27 | # Print some information on the partitioning. 28 | if self.global_rank == 0: 29 | for stage in range(num_stages): 30 | start = self.parts[stage] 31 | stop = self.parts[stage + 1] 32 | print(f'stage={stage} layers={stop - start}') 33 | for idx, layer in enumerate(self._layer_specs[start:stop]): 34 | name = str(layer) 35 | if isinstance(layer, LayerSpec): 36 | name = layer.typename.__name__ 37 | if isinstance(layer, nn.Module): 38 | name = layer.__class__.__name__ 39 | else: 40 | try: 41 | name = layer.__name__ 42 | except AttributeError: 43 | pass 44 | print(f' {idx+start:2d}: {name}') 45 | if self.loss_fn: 46 | try: 47 | print(f' loss: {self.loss_fn.__name__}') 48 | except AttributeError: 49 | print(f' loss: {self.loss_fn.__class__.__name__}') 50 | 51 | self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id+1]) 52 | else: 53 | super()._partition_layers(method) 54 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import shutil 4 | import time 5 | import sys 6 | 7 | import torch 8 | from deepspeed import comm as dist 9 | from deepspeed.utils.logging import logger 10 | 11 | from utils.common import is_main_process 12 | 13 | 14 | def convert_state_dict_dtype(state_dict, dtype): 15 | for key, v in state_dict.items(): 16 | state_dict[key] = v.to(device='cpu', dtype=dtype) 17 | 18 | 19 | last_checkpoint_time = None 20 | def need_to_checkpoint(config, epoch=None): 21 | global last_checkpoint_time 22 | 23 | if epoch is not None: 24 | if 'checkpoint_every_n_epochs' in config and epoch % config['checkpoint_every_n_epochs'] == 0: 25 | last_checkpoint_time = time.time() 26 | return True 27 | else: 28 | return False 29 | 30 | if 'checkpoint_every_n_minutes' not in 
config: 31 | return False 32 | 33 | checkpoint = False 34 | # rank 0 tracks if we need to checkpoint, broadcasts to everyone else 35 | if is_main_process(): 36 | current_time = time.time() 37 | if last_checkpoint_time is None: 38 | last_checkpoint_time = current_time 39 | elif (current_time - last_checkpoint_time) / 60 > config['checkpoint_every_n_minutes']: 40 | checkpoint = True 41 | last_checkpoint_time = current_time 42 | result = [checkpoint] 43 | torch.distributed.broadcast_object_list(result, src=0) 44 | return result[0] 45 | 46 | 47 | class Saver: 48 | def __init__(self, args, config, is_adapter, save_root, model, train_dataloader, model_engine, pipeline_model): 49 | self.args = args 50 | self.config = config 51 | self.is_adapter = is_adapter 52 | self.save_root = Path(save_root) 53 | self.model = model 54 | self.train_dataloader = train_dataloader 55 | self.model_engine = model_engine 56 | self.pipeline_model = pipeline_model 57 | 58 | def save_adapter(self, name): 59 | dp_id = self.model_engine.grid.get_data_parallel_rank() 60 | stage_id = self.model_engine.grid.get_pipe_parallel_rank() 61 | save_dir = self.save_root / name 62 | tmp_dir = save_dir / 'tmp' 63 | if dp_id == 0 and stage_id == 0: 64 | os.makedirs(tmp_dir, exist_ok=False) 65 | dist.barrier() 66 | if dp_id == 0: 67 | partial_state_dict = {} 68 | for name, p in self.pipeline_model.named_parameters(): 69 | if p.requires_grad: 70 | if not hasattr(p, 'original_name'): 71 | logger.warning(f'WARNING: parameter {name} requires_grad but does not have original_name. Not saving it.') 72 | continue 73 | # TODO: maybe this needs to change if we ever have non-lora adapters? 74 | partial_state_dict[p.original_name.replace('.default', '').replace('.modules_to_save', '')] = p.detach() 75 | if 'save_dtype' in self.config: 76 | convert_state_dict_dtype(partial_state_dict, self.config['save_dtype']) 77 | torch.save(partial_state_dict, tmp_dir / f'state_dict_{stage_id}.bin') 78 | dist.barrier() 79 | if dp_id == 0 and stage_id == 0: 80 | state_dict = {} 81 | for path in tmp_dir.glob('*.bin'): 82 | state_dict.update(torch.load(path, weights_only=True, map_location='cpu')) 83 | self.model.save_adapter(save_dir, state_dict) 84 | shutil.copy(self.args.config, save_dir) 85 | shutil.rmtree(tmp_dir) 86 | 87 | def save_full_model(self, name, max_shard_size='5GB'): 88 | dp_id = self.model_engine.grid.get_data_parallel_rank() 89 | stage_id = self.model_engine.grid.get_pipe_parallel_rank() 90 | save_dir = self.save_root / name 91 | tmp_dir = save_dir / 'tmp' 92 | if dp_id == 0 and stage_id == 0: 93 | os.makedirs(tmp_dir, exist_ok=False) 94 | dist.barrier() 95 | if dp_id == 0: 96 | # With BF16_Optimizer, we get pickle errors unless we do p.detach(). I have no idea why. 
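            # Each pipeline stage writes its own shard of the state dict to tmp_dir;
            # the (dp 0, stage 0) rank merges the shards below and writes the final model.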
97 | partial_state_dict = {p.original_name: p.detach() for p in self.pipeline_model.parameters()} 98 | if 'save_dtype' in self.config: 99 | convert_state_dict_dtype(partial_state_dict, self.config['save_dtype']) 100 | torch.save(partial_state_dict, tmp_dir / f'state_dict_{stage_id}.bin') 101 | dist.barrier() 102 | if dp_id == 0 and stage_id == 0: 103 | state_dict = {} 104 | for path in tmp_dir.glob('*.bin'): 105 | state_dict.update(torch.load(path, map_location='cpu', weights_only=True)) 106 | self.model.save_model(save_dir, state_dict) 107 | shutil.copy(self.args.config, save_dir) 108 | shutil.rmtree(tmp_dir) 109 | 110 | def save_model(self, name): 111 | if is_main_process(): 112 | print(f'Saving model to directory {name}') 113 | if self.is_adapter: 114 | self.save_adapter(name) 115 | else: 116 | self.save_full_model(name) 117 | 118 | def save_checkpoint(self, step): 119 | self.model_engine.save_checkpoint( 120 | self.save_root, 121 | client_state={ 122 | 'step': step, 123 | 'custom_loader': self.train_dataloader.state_dict(), 124 | }, 125 | save_latest=True, 126 | exclude_frozen_parameters=True 127 | ) 128 | 129 | def process_epoch(self, epoch, step): 130 | checkpointed, saved = False, False 131 | if self.train_dataloader.epoch != epoch: 132 | if need_to_checkpoint(self.config, epoch): 133 | self.save_checkpoint(step) 134 | checkpointed = True 135 | if epoch % self.config['save_every_n_epochs'] == 0: 136 | self.save_model(f'epoch{epoch}') 137 | saved = True 138 | epoch = self.train_dataloader.epoch 139 | if epoch > self.config['epochs']: 140 | return None, checkpointed, saved 141 | if is_main_process(): 142 | print(f'Started new epoch: {epoch}') 143 | return epoch, checkpointed, saved 144 | 145 | def process_step(self, step): 146 | # Look at some simple "signal files" the user can write to save and optionally quit manually 147 | should_manually_save = False 148 | should_manually_quit = False 149 | save_signal_file = self.save_root / 'save' 150 | save_quit_signal_file = self.save_root / 'save_quit' 151 | if save_signal_file.exists() and save_signal_file.is_file(): 152 | should_manually_save = True 153 | dist.barrier() 154 | if is_main_process(): 155 | os.remove(save_signal_file) 156 | elif save_quit_signal_file.exists() and save_quit_signal_file.is_file(): 157 | should_manually_save = True 158 | should_manually_quit = True 159 | dist.barrier() 160 | if is_main_process(): 161 | os.remove(save_quit_signal_file) 162 | 163 | # TODO: support save_every_n_steps in addition to save_every_n_epochs. Maybe only one should be set? 164 | # if step % self.config['save_every_n_steps'] == 0 or should_manually_save: 165 | # self.save_model(f'step{step}') 166 | 167 | if need_to_checkpoint(self.config) or should_manually_save: 168 | self.save_checkpoint(step) 169 | 170 | if should_manually_quit: 171 | print('Manually quitting') 172 | sys.exit() -------------------------------------------------------------------------------- /utils/unsloth_utils.py: -------------------------------------------------------------------------------- 1 | # Unsloth Zoo - Utilities for Unsloth 2 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Lesser General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Lesser General Public License 15 | # along with this program. If not, see . 16 | 17 | # I (tdrussell) made a few modifications. 18 | 19 | import torch 20 | from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable 21 | 22 | 23 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): 24 | """ 25 | Code licensed under LGPL 26 | Saves VRAM by smartly offloading to RAM. 27 | Tiny hit to performance, since we mask the movement via non blocking calls. 28 | """ 29 | 30 | @staticmethod 31 | @torch.amp.custom_fwd(device_type='cuda') 32 | def forward(ctx, forward_function, hidden_states, *args): 33 | saved_hidden_states = hidden_states.to('cpu', non_blocking=True) 34 | with torch.no_grad(): 35 | output = forward_function(hidden_states, *args) 36 | ctx.save_for_backward(saved_hidden_states) 37 | ctx.forward_function = forward_function 38 | ctx.args = args 39 | return output 40 | 41 | pass 42 | 43 | @staticmethod 44 | @torch.amp.custom_bwd(device_type='cuda') 45 | def backward(ctx, *grads): 46 | (hidden_states,) = ctx.saved_tensors 47 | hidden_states = hidden_states.to('cuda', non_blocking=True).detach() 48 | hidden_states.requires_grad_(True) 49 | args = detach_variable(ctx.args) 50 | inputs = (hidden_states,) + args 51 | with torch.enable_grad(): 52 | outputs = ctx.forward_function(*inputs) 53 | 54 | output_tensors = [] 55 | grad_tensors = [] 56 | for out, grad in zip(outputs, grads): 57 | if out.requires_grad: 58 | output_tensors.append(out) 59 | grad_tensors.append(grad) 60 | torch.autograd.backward(output_tensors, grad_tensors) 61 | return (None,) + tuple(input.grad for input in inputs) 62 | 63 | pass 64 | 65 | 66 | pass 67 | 68 | 69 | @torch._disable_dynamo 70 | def unsloth_checkpoint(function, *args): 71 | return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) 72 | --------------------------------------------------------------------------------
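A minimal usage sketch of the offloaded gradient checkpointer above (illustrative, not part of the repository). It assumes a CUDA device and, matching how the backward pass iterates over the forward outputs, a block that returns a tuple of tensors; TupleBlock and run_blocks are made-up names for this example.

import torch
from torch import nn

from utils.unsloth_utils import unsloth_checkpoint


class TupleBlock(nn.Module):
    # Toy stand-in for a transformer block. It returns a one-element tuple, which is
    # the output convention the checkpointer's backward pass expects.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return (torch.nn.functional.gelu(self.proj(hidden_states)),)


def run_blocks(blocks, hidden_states):
    # Input activations are offloaded to CPU on the way forward and each block is
    # recomputed on the way backward, trading a little speed for VRAM.
    for block in blocks:
        (hidden_states,) = unsloth_checkpoint(block, hidden_states)
    return hidden_states


if __name__ == '__main__':
    blocks = nn.ModuleList([TupleBlock(64) for _ in range(4)]).cuda()
    x = torch.randn(2, 16, 64, device='cuda', requires_grad=True)
    out = run_blocks(blocks, x)
    out.mean().backward()
    print(x.grad.shape)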