├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature-request.yaml └── workflows │ └── pr_tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── accelerate_configs ├── compiled_1.yaml ├── deepspeed.yaml ├── uncompiled_1.yaml ├── uncompiled_2.yaml ├── uncompiled_4.yaml └── uncompiled_8.yaml ├── assets ├── CogVideoX-LoRA.webm ├── contribute.md ├── contribute_zh.md ├── dataset_zh.md ├── lora_2b.png ├── lora_5b.png ├── sft_2b.png ├── sft_5b.png ├── slaying-ooms.png └── tests │ ├── metadata.csv │ ├── prompts.txt │ ├── prompts_multi.txt │ ├── videos.txt │ ├── videos │ ├── hiker.mp4 │ └── hiker_tiny.mp4 │ └── videos_multi.txt ├── docs ├── _NOTES_FOR_FUTURE_ME.md ├── args.md ├── dataset │ ├── README.md │ └── _DEBUG.md ├── environment.md ├── models │ ├── README.md │ ├── attention.md │ ├── cogvideox.md │ ├── cogview4.md │ ├── flux.md │ ├── hunyuan_video.md │ ├── ltx_video.md │ ├── optimization.md │ └── wan.md ├── optimizer.md ├── parallel │ └── README.md └── trainer │ ├── control_trainer.md │ └── sft_trainer.md ├── examples ├── _legacy │ └── training │ │ ├── README.md │ │ ├── README_zh.md │ │ ├── cogvideox │ │ ├── __init__.py │ │ ├── args.py │ │ ├── cogvideox_image_to_video_lora.py │ │ ├── cogvideox_image_to_video_sft.py │ │ ├── cogvideox_text_to_video_lora.py │ │ ├── cogvideox_text_to_video_sft.py │ │ ├── dataset.py │ │ ├── prepare_dataset.py │ │ ├── text_encoder │ │ │ ├── __init__.py │ │ │ └── text_encoder.py │ │ └── utils.py │ │ ├── mochi-1 │ │ ├── README.md │ │ ├── args.py │ │ ├── dataset_simple.py │ │ ├── embed.py │ │ ├── prepare_dataset.sh │ │ ├── requirements.txt │ │ ├── text_to_video_lora.py │ │ ├── train.sh │ │ ├── trim_and_crop_videos.py │ │ └── utils.py │ │ ├── prepare_dataset.sh │ │ ├── train_image_to_video_lora.sh │ │ ├── train_image_to_video_sft.sh │ │ ├── train_text_to_video_lora.sh │ │ └── train_text_to_video_sft.sh ├── formats │ └── hunyuan_video │ │ └── convert_to_original_format.py ├── inference │ ├── cogvideox │ │ └── cogvideox_text_to_video.sh │ ├── cogview4 │ │ └── cogview4_text_to_image.sh │ ├── datasets │ │ └── .gitignore │ ├── flux │ │ └── flux_text_to_image.sh │ ├── inference.py │ └── wan │ │ └── wan_text_to_video.sh └── training │ ├── control │ ├── cogview4 │ │ ├── canny │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── train.sh │ │ │ ├── training.json │ │ │ └── validation.json │ │ └── omni_edit │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── train.sh │ │ │ ├── training.json │ │ │ └── validation.json │ └── wan │ │ └── image_condition │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ └── sft │ ├── cogvideox │ └── crush_smol_lora │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ ├── cogview4 │ ├── raider_white_tarot │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ └── the_simpsons │ │ ├── README.md │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ ├── flux_dev │ └── raider_white_tarot │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ ├── hunyuan_video │ └── modal_labs_dissolve │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ ├── ltx_video │ └── crush_smol_lora │ │ ├── train.sh │ │ ├── train_multires.sh │ │ ├── training.json │ │ ├── training_multires.json │ │ ├── validation.json │ │ └── validation_multires.json │ ├── wan │ ├── 3dgs_dissolve │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ └── crush_smol_lora │ │ ├── train.sh │ │ ├── training.json │ │ └── validation.json │ └── wan_i2v │ └── 3dgs_dissolve │ ├── train.sh │ 
├── training.json │ └── validation.json ├── finetrainers ├── __init__.py ├── _metadata.py ├── args.py ├── config.py ├── constants.py ├── data │ ├── __init__.py │ ├── _artifact.py │ ├── dataloader.py │ ├── dataset.py │ ├── precomputation.py │ └── sampler.py ├── functional │ ├── __init__.py │ ├── diffusion.py │ ├── image.py │ ├── normalization.py │ ├── text.py │ └── video.py ├── logging.py ├── models │ ├── __init__.py │ ├── _metadata │ │ └── transformer.py │ ├── attention_dispatch.py │ ├── cogvideox │ │ ├── __init__.py │ │ ├── base_specification.py │ │ └── utils.py │ ├── cogview4 │ │ ├── __init__.py │ │ ├── base_specification.py │ │ └── control_specification.py │ ├── flux │ │ ├── __init__.py │ │ └── base_specification.py │ ├── hunyuan_video │ │ ├── __init__.py │ │ └── base_specification.py │ ├── ltx_video │ │ ├── __init__.py │ │ └── base_specification.py │ ├── modeling_utils.py │ ├── utils.py │ └── wan │ │ ├── __init__.py │ │ ├── base_specification.py │ │ └── control_specification.py ├── optimizer.py ├── parallel │ ├── __init__.py │ ├── accelerate.py │ ├── base.py │ ├── deepspeed.py │ ├── ptd.py │ └── utils.py ├── patches │ ├── __init__.py │ ├── dependencies │ │ ├── diffusers │ │ │ ├── control.py │ │ │ ├── patch.py │ │ │ ├── peft.py │ │ │ └── rms_norm.py │ │ └── peft │ │ │ └── patch.py │ ├── models │ │ ├── ltx_video │ │ │ └── patch.py │ │ └── wan │ │ │ └── patch.py │ └── utils.py ├── processors │ ├── __init__.py │ ├── base.py │ ├── canny.py │ ├── clip.py │ ├── glm.py │ ├── llama.py │ ├── t5.py │ └── text.py ├── state.py ├── trackers.py ├── trainer │ ├── __init__.py │ ├── base.py │ ├── control_trainer │ │ ├── __init__.py │ │ ├── config.py │ │ ├── data.py │ │ └── trainer.py │ └── sft_trainer │ │ ├── __init__.py │ │ ├── config.py │ │ └── trainer.py ├── typing.py └── utils │ ├── __init__.py │ ├── _common.py │ ├── activation_checkpoint.py │ ├── args_config.py │ ├── data.py │ ├── diffusion.py │ ├── file.py │ ├── hub.py │ ├── import_utils.py │ ├── memory.py │ ├── model.py │ ├── serialization.py │ ├── timing.py │ └── torch.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── tests ├── README.md ├── __init__.py ├── _test_dataset_old.py ├── data │ ├── __init__.py │ ├── test_dataset.py │ ├── test_precomputation.py │ └── utils.py ├── models │ ├── __init__.py │ ├── attention_dispatch.py │ ├── cogvideox │ │ ├── __init__.py │ │ └── base_specification.py │ ├── cogview4 │ │ ├── __init__.py │ │ ├── base_specification.py │ │ └── control_specification.py │ ├── flux │ │ ├── __init__.py │ │ └── base_specification.py │ ├── hunyuan_video │ │ └── base_specification.py │ ├── ltx_video │ │ ├── __init__.py │ │ ├── _test_tp.py │ │ └── base_specification.py │ └── wan │ │ ├── __init__.py │ │ ├── base_specification.py │ │ └── control_specification.py ├── scripts │ ├── dummy_cogvideox_lora.sh │ ├── dummy_hunyuanvideo_lora.sh │ └── dummy_ltx_video_lora.sh ├── test_lora_inference.py ├── test_model_runs_minimally_lora.sh ├── test_trackers.py └── trainer │ ├── __init__.py │ ├── test_control_trainer.py │ └── test_sft_trainer.py └── train.py /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve CogVideoX-Factory / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX-Factory 开源框架 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Diffusers version, 
Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: checkboxes 14 | id: information-scripts-examples 15 | attributes: 16 | label: Information / 问题信息 17 | description: 'The problem arises when using: / 问题出现在' 18 | options: 19 | - label: "The official example scripts / 官方的示例脚本" 20 | - label: "My own modified scripts / 我自己修改的脚本和任务" 21 | 22 | - type: textarea 23 | id: reproduction 24 | validations: 25 | required: true 26 | attributes: 27 | label: Reproduction / 复现过程 28 | description: | 29 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 30 | If you have code snippets, error messages, stack traces, please provide them here as well. 31 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 32 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 33 | 34 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 35 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 36 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 37 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 38 | placeholder: | 39 | Steps to reproduce the behavior/复现Bug的步骤: 40 | 41 | 1. 42 | 2. 43 | 3. 44 | 45 | - type: textarea 46 | id: expected-behavior 47 | validations: 48 | required: true 49 | attributes: 50 | label: Expected behavior / 期待表现 51 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new CogVideoX-Factory feature / 提交一个新的 CogVideoX-Factory 开源项目的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接。 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 
34 | 您的PR链接或者其他您能提供帮助的链接。 -------------------------------------------------------------------------------- /.github/workflows/pr_tests.yml: -------------------------------------------------------------------------------- 1 | name: Fast tests for PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | check_code_quality: 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: "3.8" 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install ruff==0.9.10 25 | - name: Check quality 26 | run: make quality 27 | - name: Check if failure 28 | if: ${{ failure() }} 29 | run: | 30 | echo "Quality check failed. Please install ruff: `pip install ruff` and then run `make style && make quality` from the root of the repository." >> $GITHUB_STEP_SUMMARY 31 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to Finetrainers 2 | 3 | Finetrainers is an early-stage library for training diffusion models. Everyone is welcome to contribute - models, algorithms, refactors, docs, etc. - but due to the early stage of the project, we recommend bigger contributions be discussed in an issue before submitting a PR. Eventually, we will have a better process for this! 4 | 5 | ## How to contribute 6 | 7 | ### Adding a new model 8 | 9 | If you would like to add a new model, please follow these steps: 10 | 11 | - Create a new file in the `finetrainers/models` directory with the model name (if it's new), or use the same directory if it's a variant of an existing model. 12 | - Implement the model specification in the file. For more details on what a model specification should look like, see the [ModelSpecification](TODO(aryan): add link) documentation. 13 | - Update the supported configs in `finetrainers/config.py` to include the new model and the training types supported. 14 | - Add a dummy model specification in the `tests/models` directory. 15 | - Make sure to test training with the following settings: 16 | - Single GPU 17 | - 2x GPU with `--dp_degree 2 --dp_shards 1` 18 | - 2x GPU with `--dp_degree 1 --dp_shards 2` 19 | 20 | For `SFTTrainer` additions, please make sure to train with atleast 1000 steps (atleast 2000 data points) to ensure the model training is working as expected. 21 | - Open a PR with your changes. Please make sure to share your wandb logs for the above training settings in the PR description. This will help us verify the training is working as expected. 22 | 23 | ### Adding a new algorithm 24 | 25 | Currently, we are not accepting algorithm contributions. We will update this section once we are better ready 🤗 26 | 27 | ### Refactors 28 | 29 | The library is in a very early stage. There are many instances of dead code, poorly written abstractions, and other issues. If you would like to refactor/clean-up a part of the codebase, please open an issue to discuss the changes before submitting a PR. 30 | 31 | ### Dataset improvements 32 | 33 | Any changes to dataset/dataloader implementations can be submitted directly. 
The improvements and reasons for the changes should be conveyed appropriately for us to move quickly 🤗 34 | 35 | ### Documentation 36 | 37 | Due to the early stage of the project, the documentation is not as comprehensive as we would like. Any improvements/refactors are welcome directly! 38 | 39 | ## Asking for help 40 | 41 | If you have any questions, feel free to open an issue and we will be sure to help you out asap! Please make sure to describe your issues in either English (preferable) or Chinese. Any other language will make it hard for us to help you, so we will most likely close such issues without explanation/answer. 42 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style 2 | 3 | check_dirs := finetrainers tests examples train.py setup.py 4 | 5 | quality: 6 | ruff check $(check_dirs) --exclude examples/_legacy 7 | ruff format --check $(check_dirs) --exclude examples/_legacy 8 | 9 | style: 10 | ruff check $(check_dirs) --fix --exclude examples/_legacy 11 | ruff format $(check_dirs) --exclude examples/_legacy 12 | -------------------------------------------------------------------------------- /accelerate_configs/compiled_1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | dynamo_config: 6 | dynamo_backend: INDUCTOR 7 | dynamo_mode: max-autotune 8 | dynamo_use_dynamic: true 9 | dynamo_use_fullgraph: false 10 | enable_cpu_affinity: false 11 | gpu_ids: '3' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 1 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /accelerate_configs/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | enable_cpu_affinity: false 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 2 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/uncompiled_1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | gpu_ids: '3' 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 1 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /accelerate_configs/uncompiled_2.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | gpu_ids: 0,1 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 2 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/uncompiled_4.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | gpu_ids: 0,1,2,3 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 4 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /accelerate_configs/uncompiled_8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 8 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /assets/CogVideoX-LoRA.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/CogVideoX-LoRA.webm -------------------------------------------------------------------------------- /assets/contribute.md: -------------------------------------------------------------------------------- 1 | # Contributions Welcome 2 | 3 | This project is in a very early stage, and we welcome contributions from everyone. We hope to receive contributions and support in the following areas: 4 | 5 | 1. Support for more models. In addition to CogVideoX models, we also highly encourage contributions supporting other models. 6 | 2. Support for richer datasets. In our example, we used a Disney video generation dataset, but we hope to support more datasets as the current one is too limited for deeper fine-tuning exploration. 7 | 3. Anything in `TODO` we mention in our README.md 8 | 9 | ## How to Submit 10 | 11 | We welcome you to create a new PR and describe the corresponding contribution. We will review it as soon as possible. 12 | 13 | ## Naming Conventions 14 | 15 | - Please use English for naming, avoid using pinyin or other languages. All comments should be in English. 16 | - Strictly follow PEP8 conventions, and use underscores to separate words. Please avoid using names like a, b, c. -------------------------------------------------------------------------------- /assets/contribute_zh.md: -------------------------------------------------------------------------------- 1 | # 欢迎你们的贡献 2 | 3 | 本项目属于非常初级的阶段,欢迎大家进行贡献。我们希望在以下方面得到贡献和支持: 4 | 5 | 1. 
支持更多的模型,除了 CogVideoX 模型之外的模型,我们也非常支持。 6 | 2. 更丰富的数据集支持。在我们的例子中,我们使用了一个 Disney 视频生成数据集,但是我们希望能够支持更多的数据集,这个数据集太少了,并不足以进行更深的微调探索。 7 | 3. 任何我们在README中`TODO`提到的内容。 8 | 9 | ## 提交方式 10 | 11 | 我们欢迎您直接创建一个新的PR,并说明对应的贡献,我们将第一时间查看。 12 | 13 | ## 命名规范 14 | 15 | - 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。 16 | - 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。 -------------------------------------------------------------------------------- /assets/dataset_zh.md: -------------------------------------------------------------------------------- 1 | ## 数据集格式 2 | 3 | ### 提示词数据集要求 4 | 5 | 创建 `prompt.txt` 文件,文件应包含逐行分隔的提示。请注意,提示必须是英文,并且建议使用 [提示润色脚本](https://github.com/THUDM/CogVideo/blob/main/inference/convert_demo.py) 进行润色。或者可以使用 [CogVideo-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) 进行数据标注: 6 | 7 | ``` 8 | A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction. 9 | A black and white animated sequence on a ship’s deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language... 10 | ... 11 | ``` 12 | 13 | ### 视频数据集要求 14 | 15 | 该框架支持的分辨率和帧数需要满足以下条件: 16 | 17 | - **支持的分辨率(宽 * 高)**: 18 | - 任意分辨率且必须能被32整除。例如,`720 * 480`, `1920 * 1020` 等分辨率。 19 | 20 | - **支持的帧数(Frames)**: 21 | - 必须是 `4 * k` 或 `4 * k + 1`(例如:16, 32, 49, 81) 22 | 23 | 所有的视频建议放在一个文件夹中。 24 | 25 | 26 | 接着,创建 `videos.txt` 文件。 `videos.txt` 文件应包含逐行分隔的视频文件路径。请注意,路径必须相对于 `--data_root` 目录。格式如下: 27 | 28 | ``` 29 | videos/00000.mp4 30 | videos/00001.mp4 31 | ... 32 | ``` 33 | 34 | 对于有兴趣了解更多细节的开发者,您可以查看相关的 `BucketSampler` 代码。 35 | 36 | ### 数据集结构 37 | 38 | 您的数据集结构应如下所示,通过运行`tree`命令,你能看到: 39 | 40 | ``` 41 | dataset 42 | ├── prompt.txt 43 | ├── videos.txt 44 | ├── videos 45 | ├── videos/00000.mp4 46 | ├── videos/00001.mp4 47 | ├── ... 
48 | ``` 49 | 50 | ### 使用数据集 51 | 52 | 当使用此格式时,`--caption_column` 应为 `prompt.txt`,`--video_column` 应为 `videos.txt`。如果您的数据存储在 CSV 53 | 文件中,也可以指定 `--dataset_file` 为 CSV 文件的路径,`--caption_column` 和 `--video_column` 为 CSV 54 | 文件中的实际列名。请参考 [test_dataset](../tests/test_dataset.py) 文件中的一些简单示例。 55 | 56 | 例如,使用 [这个](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) Disney 数据集进行微调。下载可通过🤗 57 | Hugging Face CLI 完成: 58 | 59 | ``` 60 | huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney 61 | ``` 62 | 63 | 该数据集已按照预期格式准备好,可直接使用。但是,直接使用视频数据集可能会导致较小 VRAM 的 GPU 出现 64 | OOM(内存不足),因为它需要加载 [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae) 65 | (将视频编码为潜在空间)和大型 [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/) 66 | 67 | 文本编码器。为了降低内存需求,您可以使用 `training/prepare_dataset.py` 脚本预先计算潜在变量和嵌入。 68 | 69 | 填写或修改 `prepare_dataset.sh` 中的参数并执行它以获得预先计算的潜在变量和嵌入(请确保指定 `--save_latents_and_embeddings` 70 | 以保存预计算的工件)。如果准备图像到视频的训练,请确保传递 `--save_image_latents`,它对沙子进行编码,将图像潜在值与视频一起保存。 71 | 在训练期间使用这些工件时,确保指定 `--load_tensors` 标志,否则将直接使用视频并需要加载文本编码器和 72 | VAE。该脚本还支持 PyTorch DDP,以便可以使用多个 GPU 并行编码大型数据集(修改 `NUM_GPUS` 参数)。 73 | -------------------------------------------------------------------------------- /assets/lora_2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/lora_2b.png -------------------------------------------------------------------------------- /assets/lora_5b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/lora_5b.png -------------------------------------------------------------------------------- /assets/sft_2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/sft_2b.png -------------------------------------------------------------------------------- /assets/sft_5b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/sft_5b.png -------------------------------------------------------------------------------- /assets/slaying-ooms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/slaying-ooms.png -------------------------------------------------------------------------------- /assets/tests/metadata.csv: -------------------------------------------------------------------------------- 1 | video,caption 2 | "videos/hiker.mp4","""A hiker standing at the top of a mountain, triumphantly, high quality""" -------------------------------------------------------------------------------- /assets/tests/prompts.txt: -------------------------------------------------------------------------------- 1 | A hiker standing at the top of a mountain, triumphantly, high quality -------------------------------------------------------------------------------- /assets/tests/prompts_multi.txt: -------------------------------------------------------------------------------- 1 | A hiker standing at the top of a mountain, triumphantly, high quality 2 | A 
hiker standing at the top of a mountain, triumphantly, high quality -------------------------------------------------------------------------------- /assets/tests/videos.txt: -------------------------------------------------------------------------------- 1 | videos/hiker.mp4 -------------------------------------------------------------------------------- /assets/tests/videos/hiker.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/tests/videos/hiker.mp4 -------------------------------------------------------------------------------- /assets/tests/videos/hiker_tiny.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/assets/tests/videos/hiker_tiny.mp4 -------------------------------------------------------------------------------- /assets/tests/videos_multi.txt: -------------------------------------------------------------------------------- 1 | videos/hiker.mp4 2 | videos/hiker_tiny.mp4 -------------------------------------------------------------------------------- /docs/_NOTES_FOR_FUTURE_ME.md: -------------------------------------------------------------------------------- 1 | # Notes for Future Me 2 | 3 | > [!NOTE] 4 | > This doc page is intended for developers and contributors. 5 | 6 | FSDP dump: 7 | - https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes 8 | - https://github.com/pytorch/pytorch/issues/114299 9 | - Using FSDP1 requires that all FSDP flat parameters are of the same dtype. For LoRA training, we default LoRA parameters to fp32 and transformer parameters to the dtype chosen by the user. There seems to be no easy workaround other than performing LoRA training in the same dtype. 10 | - https://github.com/pytorch/pytorch/issues/100945 11 | - https://github.com/pytorch/torchtune/blob/9b3836028fd0b48f593ea43474b86880c49a4d74/recipes/lora_finetune_distributed.py 12 | - https://github.com/KellerJordan/modded-nanogpt/pull/68 13 | - https://github.com/pytorch/pytorch/pull/125394: monkey-patch method for FSDP pre/post-hooks to be triggered for methods other than `forward` 14 | - https://github.com/pytorch/pytorch/pull/127786: 15 | - https://github.com/pytorch/pytorch/pull/130949: 16 | - Sanity saver: create optimizers after parallelizing/activation-checkpointing models 17 | 18 | DTensor: 19 | - https://github.com/pytorch/pytorch/issues/88838 20 | - https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/parallel/test_parallelize_api.py 21 | -------------------------------------------------------------------------------- /docs/dataset/_DEBUG.md: -------------------------------------------------------------------------------- 1 | # Distributed dataset debugging 2 | 3 | > [!NOTE] 4 | > This doc page is intended for developers and contributors. 5 | 6 | If the number of samples in the dataset is lower than the number of processes per node, the training will hang indefinitely. I haven't been able to pin down how this could be fixed due to limited time, but basically: 7 | - Start training with `--dp_degree 2` and `torchrun --standalone --nnodes=1 --nproc_per_node=2`. This launches training with DDP across 2 ranks. 8 | - The dataset has `< dp_degree` samples 9 | - When `datasets.distributed.split_dataset_by_node` is called, the data is distributed correctly to one rank, but the other rank hangs indefinitely.
Due to this edge case, fast tests seem to fail. 10 | - For now, we should just use `>= dp_degree` samples in the test dataset. However, should be fixed in the future. 11 | 12 | Minimal reproducer: 13 | 14 | ```python 15 | import torch 16 | import torch.distributed as dist 17 | from datasets import Dataset 18 | from datasets.distributed import split_dataset_by_node 19 | from torch.utils.data import DataLoader 20 | 21 | ds = Dataset.from_dict({"x": [1]}).to_iterable_dataset() 22 | 23 | dist.init_process_group() 24 | rank, world_size = dist.get_rank(), dist.get_world_size() 25 | ds = split_dataset_by_node(ds, rank=rank,world_size=world_size) 26 | dl = DataLoader(ds) 27 | 28 | exhausted = torch.zeros(world_size, dtype=torch.bool) 29 | 30 | def loop(): 31 | while True: 32 | print(rank, "hello", flush=True) 33 | yield from dl 34 | yield "end" 35 | 36 | for x in loop(): 37 | if x == "end": 38 | exhausted[rank] = True 39 | continue 40 | dist.all_reduce(exhausted) 41 | if torch.all(exhausted): 42 | break 43 | print(f"{rank} {x}", flush=True) 44 | ``` 45 | -------------------------------------------------------------------------------- /docs/environment.md: -------------------------------------------------------------------------------- 1 | # Environment 2 | 3 | Finetrainers has only been widely tested with the following environment (output obtained by running `diffusers-cli env`): 4 | 5 | ```shell 6 | - 🤗 Diffusers version: 0.33.0.dev0 7 | - Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31 8 | - Running on Google Colab?: No 9 | - Python version: 3.10.14 10 | - PyTorch version (GPU?): 2.5.1+cu124 (True) 11 | - Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu) 12 | - Jax version: 0.4.31 13 | - JaxLib version: 0.4.31 14 | - Huggingface_hub version: 0.28.1 15 | - Transformers version: 4.48.0.dev0 16 | - Accelerate version: 1.1.0.dev0 17 | - PEFT version: 0.14.1.dev0 18 | - Bitsandbytes version: 0.43.3 19 | - Safetensors version: 0.4.5 20 | - xFormers version: not installed 21 | - Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB 22 | NVIDIA A100-SXM4-80GB, 81920 MiB 23 | NVIDIA A100-SXM4-80GB, 81920 MiB 24 | NVIDIA DGX Display, 4096 MiB 25 | NVIDIA A100-SXM4-80GB, 81920 MiB 26 | ``` 27 | 28 | Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices. 29 | 30 | > [!IMPORTANT] 31 | > 32 | > For context parallelism, PyTorch 2.6+ is required. 33 | 34 | ## Configuration 35 | 36 | The following environment variables may be configured to change the default behaviour of finetrainers: 37 | 38 | `FINETRAINERS_ATTN_PROVIDER`: Sets the default attention provider for training/validation. Defaults to `native`, as in native PyTorch SDPA. See [attention docs](./models/attention.md) for more information. 39 | `FINETRAINERS_ATTN_CHECKS`: Whether or not to run basic sanity checks when using different attention providers. This is useful for debugging but you should leave it disabled for longer training runs. Defaults to `"0"`. Can be set to a truthy env value. 40 | -------------------------------------------------------------------------------- /docs/models/cogvideox.md: -------------------------------------------------------------------------------- 1 | # CogVideoX 2 | 3 | ## Training 4 | 5 | For LoRA training, specify `--training_type lora`. 
For full finetuning, specify `--training_type full-finetune`. 6 | 7 | Examples available: 8 | - [PIKA crush effect](../../examples/training/sft/cogvideox/crush_smol_lora/) 9 | 10 | To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): 11 | 12 | ```bash 13 | chmod +x ./examples/training/sft/cogvideox/crush_smol_lora/train.sh 14 | ./examples/training/sft/cogvideox/crush_smol_lora/train.sh 15 | ``` 16 | 17 | On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows] 18 | 19 | ## Supported checkpoints 20 | 21 | CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working: 22 | 23 | - [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) 24 | - [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B) 25 | - [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B) 26 | 27 | ## Inference 28 | 29 | Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: 30 | 31 | ```diff 32 | import torch 33 | from diffusers import CogVideoXPipeline 34 | from diffusers.utils import export_to_video 35 | 36 | pipe = CogVideoXPipeline.from_pretrained( 37 | "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16 38 | ).to("cuda") 39 | + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora") 40 | + pipe.set_adapters(["cogvideox-lora"], [0.75]) 41 | 42 | video = pipe("").frames[0] 43 | export_to_video(video, "output.mp4") 44 | ``` 45 | 46 | You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: 47 | 48 | - [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) 49 | - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) 50 | - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) -------------------------------------------------------------------------------- /docs/models/flux.md: -------------------------------------------------------------------------------- 1 | # Flux 2 | 3 | ## Training 4 | 5 | For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`. 6 | 7 | Examples available: 8 | - [Raider White Tarot cards style](../../examples/training/sft/flux_dev/raider_white_tarot/) 9 | 10 | To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): 11 | 12 | ```bash 13 | chmod +x ./examples/training/sft/flux_dev/raider_white_tarot/train.sh 14 | ./examples/training/sft/flux_dev/raider_white_tarot/train.sh 15 | ``` 16 | 17 | On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows] 18 | 19 | > [!NOTE] 20 | > Currently, only FLUX.1-dev is supported. It is a guidance-distilled model which directly predicts the outputs of its teacher model when the teacher is run with CFG. To match the output distribution of the distilled model with that of the teacher model, a guidance scale of 1.0 is hardcoded into the codebase. 
However, other values may work too but it is experimental. 21 | > FLUX.1-schnell is not supported for training yet. It is a timestep-distilled model. Matching its output distribution for training is significantly more difficult. 22 | 23 | ## Supported checkpoints 24 | 25 | The following checkpoints were tested with `finetrainers` and are known to be working: 26 | 27 | - [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) 28 | - [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) 29 | 30 | ## Inference 31 | 32 | Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: 33 | 34 | ```diff 35 | import torch 36 | from diffusers import FluxPipeline 37 | 38 | pipe = FluxPipeline.from_pretrained( 39 | "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 40 | ).to("cuda") 41 | + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="flux-lora") 42 | + pipe.set_adapters(["flux-lora"], [0.9]) 43 | 44 | # Make sure to set guidance_scale to 0.0 when inferencing with FLUX.1-schnell or derivative models 45 | image = pipe("").images[0] 46 | image.save("output.png") 47 | ``` 48 | 49 | You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: 50 | 51 | - [Flux in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) 52 | - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) 53 | - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) 54 | -------------------------------------------------------------------------------- /docs/models/hunyuan_video.md: -------------------------------------------------------------------------------- 1 | # HunyuanVideo 2 | 3 | ## Training 4 | 5 | For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`. 6 | 7 | Examples available: 8 | - [PIKA Dissolve effect](../../examples/training/sft/hunyuan_video/modal_labs_dissolve/) 9 | 10 | To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): 11 | 12 | ```bash 13 | chmod +x ./examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh 14 | ./examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh 15 | ``` 16 | 17 | On Windows, you will have to modify the script to a compatible format to run it. 
[TODO(aryan): improve instructions for Windows] 18 | 19 | ## Inference 20 | 21 | Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: 22 | 23 | ```py 24 | import torch 25 | from diffusers import HunyuanVideoPipeline 26 | 27 | import torch 28 | from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 29 | from diffusers.utils import export_to_video 30 | 31 | model_id = "hunyuanvideo-community/HunyuanVideo" 32 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 33 | model_id, subfolder="transformer", torch_dtype=torch.bfloat16 34 | ) 35 | pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) 36 | pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="hunyuanvideo-lora") 37 | pipe.set_adapters(["hunyuanvideo-lora"], [0.6]) 38 | pipe.vae.enable_tiling() 39 | pipe.to("cuda") 40 | 41 | output = pipe( 42 | prompt="A cat walks on the grass, realistic", 43 | height=320, 44 | width=512, 45 | num_frames=61, 46 | num_inference_steps=30, 47 | ).frames[0] 48 | export_to_video(output, "output.mp4", fps=15) 49 | ``` 50 | 51 | You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: 52 | 53 | - [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) 54 | - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) 55 | - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) -------------------------------------------------------------------------------- /docs/models/ltx_video.md: -------------------------------------------------------------------------------- 1 | # LTX-Video 2 | 3 | ## Training 4 | 5 | For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`. 6 | 7 | Examples available: 8 | - [PIKA crush effect](../../examples/training/sft/ltx_video/crush_smol_lora/) 9 | 10 | To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): 11 | 12 | ```bash 13 | chmod +x ./examples/training/sft/ltx_video/crush_smol_lora/train.sh 14 | ./examples/training/sft/ltx_video/crush_smol_lora/train.sh 15 | ``` 16 | 17 | On Windows, you will have to modify the script to a compatible format to run it. 
[TODO(aryan): improve instructions for Windows] 18 | 19 | ## Inference 20 | 21 | Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: 22 | 23 | ```diff 24 | import torch 25 | from diffusers import LTXPipeline 26 | from diffusers.utils import export_to_video 27 | 28 | pipe = LTXPipeline.from_pretrained( 29 | "Lightricks/LTX-Video", torch_dtype=torch.bfloat16 30 | ).to("cuda") 31 | + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="ltxv-lora") 32 | + pipe.set_adapters(["ltxv-lora"], [0.75]) 33 | 34 | video = pipe("").frames[0] 35 | export_to_video(video, "output.mp4", fps=8) 36 | ``` 37 | 38 | You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: 39 | 40 | - [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video) 41 | - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) 42 | - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) -------------------------------------------------------------------------------- /docs/models/optimization.md: -------------------------------------------------------------------------------- 1 | # Memory optimizations 2 | 3 | To lower memory requirements during training: 4 | 5 | - `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory. 6 | - `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass. 7 | - `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`) 8 | - `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states. 9 | - Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example). 10 | - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs. 11 | 12 | We will continue to add more features that help to reduce memory consumption. 13 | -------------------------------------------------------------------------------- /docs/optimizer.md: -------------------------------------------------------------------------------- 1 | # Optimizers 2 | 3 | The following optimizers are supported: 4 | - **torch**: 5 | - `Adam` 6 | - `AdamW` 7 | - **bitsandbytes**: 8 | - `Adam` 9 | - `AdamW` 10 | - `Adam8Bit` 11 | - `AdamW8Bit` 12 | 13 | > [!NOTE] 14 | > Not all optimizers have been tested with all models/parallel settings. They may or may not work, but this will gradually improve over time. 15 | -------------------------------------------------------------------------------- /docs/parallel/README.md: -------------------------------------------------------------------------------- 1 | # Finetrainers Parallel Backends 2 | 3 | Finetrainers supports parallel training on multiple GPUs & nodes. This is done using the Pytorch DTensor backend. To run parallel training, `torchrun` is utilized. 
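Conceptually, each parallel "degree" selects one dimension of a device mesh that the launched ranks are arranged into, and the product of all degrees must equal the number of processes started by `torchrun`. The snippet below is only a rough illustration of that idea (it is not finetrainers' internal code; the dimension names and the 8-GPU layout are assumptions made for the example):

```python
# Rough sketch only: how dp/cp/tp "degrees" map onto a PyTorch DTensor device mesh.
# Assumes launch via: torchrun --standalone --nproc_per_node=8 mesh_sketch.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")

# 8 ranks = 2 data-parallel replicas x 2 data-parallel shards x 2 context-parallel x 1 tensor-parallel
dp_degree, dp_shards, cp_degree, tp_degree = 2, 2, 2, 1
mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_degree, dp_shards, cp_degree, tp_degree),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
)

# Each parallelism operates over its own sub-mesh, for example:
fsdp_mesh = mesh["dp_shard"]  # parameter/optimizer-state sharding group
cp_mesh = mesh["cp"]          # context-parallel group (splits the attention sequence)
print(f"rank={dist.get_rank()} mesh coordinate={mesh.get_coordinate()}")

dist.destroy_process_group()
```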
4 | 5 | As an experiment for comparing performance of different training backends, Finetrainers has implemented multi-backend support. These backends may or may not fully rely on Pytorch's distributed DTensor solution. Currently, only [🤗 Accelerate](https://github.com/huggingface/accelerate) is supported for backwards-compatibility reasons (as we initially started Finetrainers with only Accelerate). In the near future, there are plans for integrating with: 6 | - [DeepSpeed](https://github.com/deepspeedai/DeepSpeed) 7 | - [Nanotron](https://github.com/huggingface/nanotron) 8 | - [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) 9 | 10 | > [!IMPORTANT] 11 | > The multi-backend support is completely experimental and only serves to satisfy my curiosity of how much of a tradeoff there is between performance and ease of use. The Pytorch DTensor backend is the only one with stable support, following Accelerate. 12 | > 13 | > Users will not have to worry about backwards-breaking changes or dependencies if they stick to the Pytorch DTensor backend. 14 | 15 | ## Support matrix 16 | 17 | Currently supported parallelizations include: 18 | - [DDP](https://pytorch.org/docs/stable/notes/ddp.html) 19 | - [FSDP2](https://pytorch.org/docs/stable/fsdp.html) 20 | - [HSDP](https://pytorch.org/docs/stable/fsdp.html) 21 | - [CP](https://docs.pytorch.org/tutorials/prototype/context_parallel.html) 22 | 23 | 24 | ## Training 25 | 26 | The following parameters are relevant for launching training: 27 | 28 | - `parallel_backend`: The backend to use for parallel training. Available options are `ptd` & `accelerate`. 29 | - `pp_degree`: The degree of pipeline parallelism. Currently unsupported. 30 | - `dp_degree`: The degree of data parallelis/replicas. Defaults to `1`. 31 | - `dp_shards`: The number of shards for data parallelism. Defaults to `1`. 32 | - `cp_degree`: The degree of context parallelism. 33 | - `tp_degree`: The degree of tensor parallelism. 34 | 35 | For launching training with the Pytorch DTensor backend, use the following: 36 | 37 | ```bash 38 | # Single node - 8 GPUs available 39 | torchrun --standalone --nodes=1 --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py 40 | 41 | # Single node - 8 GPUs but only 4 available 42 | export CUDA_VISIBLE_DEVICES=0,2,4,5 43 | torchrun --standalone --nodes=1 --nproc_per_node=4 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py 44 | 45 | # Multi-node - Nx8 GPUs available 46 | # TODO(aryan): Add slurm script 47 | ``` 48 | 49 | For launching training with the Accelerate backend, use the following: 50 | 51 | ```bash 52 | # Single node - 8 GPUs available 53 | accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids 0,1,2,3,4,5,6,7 train.py 54 | 55 | # Single node - 8 GPUs but only 4 available 56 | accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0,2,4,5 train.py 57 | 58 | # Multi-node - Nx8 GPUs available 59 | # TODO(aryan): Add slurm script 60 | ``` 61 | 62 | ## Inference 63 | 64 | For inference-only purposes, the example implementation can be found in the [examples/inference/](../../examples/inference/) directory. 65 | -------------------------------------------------------------------------------- /docs/trainer/control_trainer.md: -------------------------------------------------------------------------------- 1 | # Control Trainer 2 | 3 | The Control trainer supports channel-concatenated control conditioning for models either using low-rank adapters or full-rank training. 
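As a rough sketch of the general pattern described next (hypothetical, simplified code — not the actual finetrainers control-injection implementation): the first input projection is widened so that control latents can be channel-concatenated with the usual latents while preserving the pretrained behaviour at initialization.

```python
# Hypothetical, simplified illustration only -- not finetrainers' actual code.
import torch
import torch.nn as nn


def expand_for_control(proj: nn.Linear, extra_in_features: int) -> nn.Linear:
    """Widen a patch-embedding projection so it accepts [latents, control_latents]."""
    new_proj = nn.Linear(proj.in_features + extra_in_features, proj.out_features, bias=proj.bias is not None)
    with torch.no_grad():
        new_proj.weight.zero_()                                     # new control channels start at zero...
        new_proj.weight[:, : proj.in_features].copy_(proj.weight)   # ...so pretrained behaviour is unchanged at init
        if proj.bias is not None:
            new_proj.bias.copy_(proj.bias)
    return new_proj


# Usage sketch: 16 latent channels + 16 control channels -> same output embedding dim as before
proj = expand_for_control(nn.Linear(16, 1024), extra_in_features=16)
latents, control_latents = torch.randn(1, 256, 16), torch.randn(1, 256, 16)
hidden_states = proj(torch.cat([latents, control_latents], dim=-1))  # (1, 256, 1024)
```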
It involves adding extra input channels to the patch embedding layer (referred to as the "control injection" layer in finetrainers), to mix conditioning features into the latent stream. This architecture choice is very common and has been seen before in many models - CogVideoX-I2V, HunyuanVideo-I2V, Alibaba's Fun Control models, etc. 4 | -------------------------------------------------------------------------------- /docs/trainer/sft_trainer.md: -------------------------------------------------------------------------------- 1 | # SFT Trainer 2 | 3 | The SFT trainer supports low-rank and full-rank finetuning of models. 4 | -------------------------------------------------------------------------------- /examples/_legacy/training/cogvideox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/examples/_legacy/training/cogvideox/__init__.py -------------------------------------------------------------------------------- /examples/_legacy/training/cogvideox/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_encoder import compute_prompt_embeddings 2 | -------------------------------------------------------------------------------- /examples/_legacy/training/cogvideox/text_encoder/text_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from transformers import T5EncoderModel, T5Tokenizer 5 | 6 | 7 | def _get_t5_prompt_embeds( 8 | tokenizer: T5Tokenizer, 9 | text_encoder: T5EncoderModel, 10 | prompt: Union[str, List[str]], 11 | num_videos_per_prompt: int = 1, 12 | max_sequence_length: int = 226, 13 | device: Optional[torch.device] = None, 14 | dtype: Optional[torch.dtype] = None, 15 | text_input_ids=None, 16 | ): 17 | prompt = [prompt] if isinstance(prompt, str) else prompt 18 | batch_size = len(prompt) 19 | 20 | if tokenizer is not None: 21 | text_inputs = tokenizer( 22 | prompt, 23 | padding="max_length", 24 | max_length=max_sequence_length, 25 | truncation=True, 26 | add_special_tokens=True, 27 | return_tensors="pt", 28 | ) 29 | text_input_ids = text_inputs.input_ids 30 | else: 31 | if text_input_ids is None: 32 | raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") 33 | 34 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 35 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 36 | 37 | # duplicate text embeddings for each generation per prompt, using mps friendly method 38 | _, seq_len, _ = prompt_embeds.shape 39 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 40 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 41 | 42 | return prompt_embeds 43 | 44 | 45 | def encode_prompt( 46 | tokenizer: T5Tokenizer, 47 | text_encoder: T5EncoderModel, 48 | prompt: Union[str, List[str]], 49 | num_videos_per_prompt: int = 1, 50 | max_sequence_length: int = 226, 51 | device: Optional[torch.device] = None, 52 | dtype: Optional[torch.dtype] = None, 53 | text_input_ids=None, 54 | ): 55 | prompt = [prompt] if isinstance(prompt, str) else prompt 56 | prompt_embeds = _get_t5_prompt_embeds( 57 | tokenizer, 58 | text_encoder, 59 | prompt=prompt, 60 | num_videos_per_prompt=num_videos_per_prompt, 61 | max_sequence_length=max_sequence_length, 62 | device=device, 63 | dtype=dtype, 64 | 
text_input_ids=text_input_ids, 65 | ) 66 | return prompt_embeds 67 | 68 | 69 | def compute_prompt_embeddings( 70 | tokenizer: T5Tokenizer, 71 | text_encoder: T5EncoderModel, 72 | prompt: str, 73 | max_sequence_length: int, 74 | device: torch.device, 75 | dtype: torch.dtype, 76 | requires_grad: bool = False, 77 | ): 78 | if requires_grad: 79 | prompt_embeds = encode_prompt( 80 | tokenizer, 81 | text_encoder, 82 | prompt, 83 | num_videos_per_prompt=1, 84 | max_sequence_length=max_sequence_length, 85 | device=device, 86 | dtype=dtype, 87 | ) 88 | else: 89 | with torch.no_grad(): 90 | prompt_embeds = encode_prompt( 91 | tokenizer, 92 | text_encoder, 93 | prompt, 94 | num_videos_per_prompt=1, 95 | max_sequence_length=max_sequence_length, 96 | device=device, 97 | dtype=dtype, 98 | ) 99 | return prompt_embeds 100 | -------------------------------------------------------------------------------- /examples/_legacy/training/mochi-1/dataset_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from 3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import click 9 | import torch 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | 13 | def load_to_cpu(x): 14 | return torch.load(x, map_location=torch.device("cpu"), weights_only=True) 15 | 16 | 17 | class LatentEmbedDataset(Dataset): 18 | def __init__(self, file_paths, repeat=1): 19 | self.items = [ 20 | (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt")) 21 | for p in file_paths 22 | if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file() 23 | ] 24 | self.items = self.items * repeat 25 | print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.") 26 | 27 | def __len__(self): 28 | return len(self.items) 29 | 30 | def __getitem__(self, idx): 31 | latent_path, embed_path = self.items[idx] 32 | return load_to_cpu(latent_path), load_to_cpu(embed_path) 33 | 34 | 35 | @click.command() 36 | @click.argument("directory", type=click.Path(exists=True, file_okay=False)) 37 | def process_videos(directory): 38 | dir_path = Path(directory) 39 | mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")] 40 | assert mp4_files, f"No mp4 files found" 41 | 42 | dataset = LatentEmbedDataset(mp4_files) 43 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True) 44 | 45 | for latents, embeds in dataloader: 46 | print([(k, v.shape) for k, v in latents.items()]) 47 | 48 | 49 | if __name__ == "__main__": 50 | process_videos() 51 | -------------------------------------------------------------------------------- /examples/_legacy/training/mochi-1/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=0 4 | VIDEO_DIR=video-dataset-disney-organized 5 | OUTPUT_DIR=videos_prepared 6 | NUM_FRAMES=37 7 | RESOLUTION=480x848 8 | 9 | # Extract width and height from RESOLUTION 10 | WIDTH=$(echo $RESOLUTION | cut -dx -f1) 11 | HEIGHT=$(echo $RESOLUTION | cut -dx -f2) 12 | 13 | python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample 14 | 15 | CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT} 16 | -------------------------------------------------------------------------------- /examples/_legacy/training/mochi-1/requirements.txt: 
-------------------------------------------------------------------------------- 1 | peft 2 | transformers 3 | wandb 4 | torch 5 | torchvision 6 | av==11.0.0 7 | moviepy==1.0.3 8 | click -------------------------------------------------------------------------------- /examples/_legacy/training/mochi-1/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export NCCL_P2P_DISABLE=1 3 | export TORCH_NCCL_ENABLE_MONITORING=0 4 | 5 | GPU_IDS="0" 6 | 7 | DATA_ROOT="videos_prepared" 8 | MODEL="genmo/mochi-1-preview" 9 | OUTPUT_PATH="mochi-lora" 10 | 11 | cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \ 12 | --pretrained_model_name_or_path $MODEL \ 13 | --cast_dit \ 14 | --data_root $DATA_ROOT \ 15 | --seed 42 \ 16 | --output_dir $OUTPUT_PATH \ 17 | --train_batch_size 1 \ 18 | --dataloader_num_workers 4 \ 19 | --pin_memory \ 20 | --caption_dropout 0.1 \ 21 | --max_train_steps 2000 \ 22 | --gradient_checkpointing \ 23 | --enable_slicing \ 24 | --enable_tiling \ 25 | --enable_model_cpu_offload \ 26 | --optimizer adamw \ 27 | --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ 28 | --validation_prompt_separator ::: \ 29 | --num_validation_videos 1 \ 30 | --validation_epochs 1 \ 31 | --allow_tf32 \ 32 | --report_to wandb \ 33 | --push_to_hub" 34 | 35 | echo "Running command: $cmd" 36 | eval $cmd 37 | echo -ne "-------------------- Finished executing script --------------------\n\n" -------------------------------------------------------------------------------- /examples/_legacy/training/mochi-1/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import inspect 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | 7 | logger = get_logger(__name__) 8 | 9 | def reset_memory(device: Union[str, torch.device]) -> None: 10 | gc.collect() 11 | torch.cuda.empty_cache() 12 | torch.cuda.reset_peak_memory_stats(device) 13 | torch.cuda.reset_accumulated_memory_stats(device) 14 | 15 | 16 | def print_memory(device: Union[str, torch.device]) -> None: 17 | memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 18 | max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 19 | max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 20 | print(f"{memory_allocated=:.3f} GB") 21 | print(f"{max_memory_allocated=:.3f} GB") 22 | print(f"{max_memory_reserved=:.3f} GB") 23 | -------------------------------------------------------------------------------- /examples/_legacy/training/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_ID="THUDM/CogVideoX-2b" 4 | 5 | NUM_GPUS=8 6 | 7 | # For more details on the expected data format, please refer to the README. 8 | DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located. 
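# An illustrative DATA_ROOT layout (based on this repo's dataset docs; adjust to your own data):
#   DATA_ROOT/
#   ├── prompt.txt   <- one caption per line (see CAPTION_COLUMN below)
#   ├── videos.txt   <- one path per line, relative to DATA_ROOT, e.g. videos/00000.mp4 (see VIDEO_COLUMN below)
#   └── videos/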
9 | CAPTION_COLUMN="prompt.txt" 10 | VIDEO_COLUMN="videos.txt" 11 | OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset" 12 | HEIGHT_BUCKETS="480 720" 13 | WIDTH_BUCKETS="720 960" 14 | FRAME_BUCKETS="49" 15 | MAX_NUM_FRAMES="49" 16 | MAX_SEQUENCE_LENGTH=226 17 | TARGET_FPS=8 18 | BATCH_SIZE=1 19 | DTYPE=fp32 20 | 21 | # To create a folder-style dataset structure without pre-encoding videos and captions 22 | # For Image-to-Video finetuning, make sure to pass `--save_image_latents` 23 | CMD_WITHOUT_PRE_ENCODING="\ 24 | torchrun --nproc_per_node=$NUM_GPUS \ 25 | training/prepare_dataset.py \ 26 | --model_id $MODEL_ID \ 27 | --data_root $DATA_ROOT \ 28 | --caption_column $CAPTION_COLUMN \ 29 | --video_column $VIDEO_COLUMN \ 30 | --output_dir $OUTPUT_DIR \ 31 | --height_buckets $HEIGHT_BUCKETS \ 32 | --width_buckets $WIDTH_BUCKETS \ 33 | --frame_buckets $FRAME_BUCKETS \ 34 | --max_num_frames $MAX_NUM_FRAMES \ 35 | --max_sequence_length $MAX_SEQUENCE_LENGTH \ 36 | --target_fps $TARGET_FPS \ 37 | --batch_size $BATCH_SIZE \ 38 | --dtype $DTYPE 39 | " 40 | 41 | CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings" 42 | 43 | # Select which you'd like to run 44 | CMD=$CMD_WITH_PRE_ENCODING 45 | 46 | echo "===== Running \`$CMD\` =====" 47 | eval $CMD 48 | echo -ne "===== Finished running script =====\n" 49 | -------------------------------------------------------------------------------- /examples/inference/cogvideox/cogvideox_text_to_video.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | 5 | # export TORCH_LOGS="+dynamo,recompiles,graph_breaks" 6 | # export TORCHDYNAMO_VERBOSE=1 7 | # export WANDB_MODE="offline" 8 | export WANDB_MODE="disabled" 9 | export NCCL_P2P_DISABLE=1 10 | export NCCL_IB_DISABLE=1 11 | export TORCH_NCCL_ENABLE_MONITORING=0 12 | export FINETRAINERS_LOG_LEVEL="DEBUG" 13 | 14 | # Download the validation dataset 15 | if [ ! -d "examples/inference/datasets/openvid-1k-split-validation" ]; then 16 | echo "Downloading validation dataset..." 17 | huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/inference/datasets/openvid-1k-split-validation 18 | else 19 | echo "Validation dataset already exists. Skipping download." 20 | fi 21 | 22 | BACKEND="ptd" 23 | 24 | NUM_GPUS=2 25 | CUDA_VISIBLE_DEVICES="2,3" 26 | 27 | # Check the JSON files for the expected JSON format 28 | DATASET_FILE="examples/inference/cogvideox/dummy_text_to_video.json" 29 | 30 | # Depending on how many GPUs you have available, choose your degree of parallelism and technique! 
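# Rough guide to the degrees below, inferred from the preset names rather than taken
# from the parallel backend's documentation (treat the backend docs as authoritative):
#   --dp_degree : number of data-parallel replicas (DDP_* presets)
#   --dp_shards : number of FSDP-style parameter shards per replica (FSDP_* / HSDP_* presets)
#   --cp_degree : context-parallel degree, splitting the sequence across GPUs (CP_* presets)
#   --tp_degree : tensor-parallel degree
#   --pp_degree : pipeline-parallel degree
# The product of the degrees you enable should match NUM_GPUS set above.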
31 | DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" 32 | DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" 33 | DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" 34 | DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" 35 | CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" 36 | CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" 37 | # FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" 38 | # FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" 39 | # HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" 40 | 41 | # Parallel arguments 42 | parallel_cmd=( 43 | $CP_2 44 | ) 45 | 46 | # Model arguments 47 | model_cmd=( 48 | --model_name cogvideox 49 | --pretrained_model_name_or_path "THUDM/CogVideoX-5B" 50 | --enable_slicing 51 | --enable_tiling 52 | ) 53 | 54 | # Inference arguments 55 | inference_cmd=( 56 | --inference_type text_to_video 57 | --dataset_file "$DATASET_FILE" 58 | ) 59 | 60 | # Attention provider arguments 61 | attn_provider_cmd=( 62 | --attn_provider sage 63 | ) 64 | 65 | # Torch config arguments 66 | torch_config_cmd=( 67 | --allow_tf32 68 | --float32_matmul_precision high 69 | ) 70 | 71 | # Miscellaneous arguments 72 | miscellaneous_cmd=( 73 | --seed 31337 74 | --tracker_name "finetrainers-inference" 75 | --output_dir "/raid/aryan/cogvideox-inference" 76 | --init_timeout 600 77 | --nccl_timeout 600 78 | --report_to "wandb" 79 | ) 80 | 81 | # Execute the inference script 82 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES 83 | 84 | torchrun \ 85 | --standalone \ 86 | --nnodes=1 \ 87 | --nproc_per_node=$NUM_GPUS \ 88 | --rdzv_backend c10d \ 89 | --rdzv_endpoint="localhost:19242" \ 90 | examples/inference/inference.py \ 91 | "${parallel_cmd[@]}" \ 92 | "${model_cmd[@]}" \ 93 | "${inference_cmd[@]}" \ 94 | "${attn_provider_cmd[@]}" \ 95 | "${torch_config_cmd[@]}" \ 96 | "${miscellaneous_cmd[@]}" 97 | 98 | echo -ne "-------------------- Finished executing script --------------------\n\n" 99 | -------------------------------------------------------------------------------- /examples/inference/cogview4/cogview4_text_to_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | 5 | # export TORCH_LOGS="+dynamo,recompiles,graph_breaks" 6 | # export TORCHDYNAMO_VERBOSE=1 7 | # export WANDB_MODE="offline" 8 | export WANDB_MODE="disabled" 9 | export NCCL_P2P_DISABLE=1 10 | export NCCL_IB_DISABLE=1 11 | export TORCH_NCCL_ENABLE_MONITORING=0 12 | export FINETRAINERS_LOG_LEVEL="DEBUG" 13 | 14 | BACKEND="ptd" 15 | 16 | NUM_GPUS=2 17 | CUDA_VISIBLE_DEVICES="2,3" 18 | 19 | # Check the JSON files for the expected JSON format 20 | DATASET_FILE="examples/inference/cogview4/dummy_text_to_image.json" 21 | 22 | # Depending on how many GPUs you have available, choose your degree of parallelism and technique! 
23 | DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" 24 | DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" 25 | DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" 26 | DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" 27 | CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" 28 | CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" 29 | # FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" 30 | # FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" 31 | # HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" 32 | 33 | # Parallel arguments 34 | parallel_cmd=( 35 | $CP_2 36 | ) 37 | 38 | # Model arguments 39 | model_cmd=( 40 | --model_name "cogview4" 41 | --pretrained_model_name_or_path "THUDM/CogView4-6B" 42 | --enable_slicing 43 | --enable_tiling 44 | ) 45 | 46 | # Inference arguments 47 | inference_cmd=( 48 | --inference_type text_to_image 49 | --dataset_file "$DATASET_FILE" 50 | ) 51 | 52 | # Attention provider arguments 53 | attn_provider_cmd=( 54 | --attn_provider flash_varlen 55 | ) 56 | 57 | # Torch config arguments 58 | torch_config_cmd=( 59 | --allow_tf32 60 | --float32_matmul_precision high 61 | ) 62 | 63 | # Miscellaneous arguments 64 | miscellaneous_cmd=( 65 | --seed 31337 66 | --tracker_name "finetrainers-inference" 67 | --output_dir "/raid/aryan/cogview4-inference" 68 | --init_timeout 600 69 | --nccl_timeout 600 70 | --report_to "wandb" 71 | ) 72 | 73 | # Execute the inference script 74 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES 75 | 76 | torchrun \ 77 | --standalone \ 78 | --nnodes=1 \ 79 | --nproc_per_node=$NUM_GPUS \ 80 | --rdzv_backend c10d \ 81 | --rdzv_endpoint="localhost:19242" \ 82 | examples/inference/inference.py \ 83 | "${parallel_cmd[@]}" \ 84 | "${model_cmd[@]}" \ 85 | "${inference_cmd[@]}" \ 86 | "${attn_provider_cmd[@]}" \ 87 | "${torch_config_cmd[@]}" \ 88 | "${miscellaneous_cmd[@]}" 89 | 90 | echo -ne "-------------------- Finished executing script --------------------\n\n" 91 | -------------------------------------------------------------------------------- /examples/inference/datasets/.gitignore: -------------------------------------------------------------------------------- 1 | openvid-1k-split-validation 2 | -------------------------------------------------------------------------------- /examples/inference/flux/flux_text_to_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | 5 | # export TORCH_LOGS="+dynamo,recompiles,graph_breaks" 6 | # export TORCHDYNAMO_VERBOSE=1 7 | # export WANDB_MODE="offline" 8 | export WANDB_MODE="disabled" 9 | export NCCL_P2P_DISABLE=1 10 | export NCCL_IB_DISABLE=1 11 | export TORCH_NCCL_ENABLE_MONITORING=0 12 | export FINETRAINERS_LOG_LEVEL="DEBUG" 13 | 14 | BACKEND="ptd" 15 | 16 | NUM_GPUS=4 17 | CUDA_VISIBLE_DEVICES="0,1,2,3" 18 | 19 | # Check the JSON files for the expected JSON format 20 | DATASET_FILE="examples/inference/flux/dummy_text_to_image.json" 21 | 22 | # Depending on how many GPUs you have available, choose your degree of parallelism and 
technique! 23 | DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" 24 | DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" 25 | DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" 26 | DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" 27 | CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" 28 | CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" 29 | # FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" 30 | # FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" 31 | # HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" 32 | 33 | # Parallel arguments 34 | parallel_cmd=( 35 | $CP_4 36 | ) 37 | 38 | # Model arguments 39 | model_cmd=( 40 | --model_name "flux" 41 | --pretrained_model_name_or_path "black-forest-labs/FLUX.1-dev" 42 | --cache_dir /raid/.cache/huggingface 43 | --enable_slicing 44 | --enable_tiling 45 | ) 46 | 47 | # Inference arguments 48 | inference_cmd=( 49 | --inference_type text_to_image 50 | --dataset_file "$DATASET_FILE" 51 | ) 52 | 53 | # Attention provider arguments 54 | attn_provider_cmd=( 55 | --attn_provider flash_varlen 56 | ) 57 | 58 | # Torch config arguments 59 | torch_config_cmd=( 60 | --allow_tf32 61 | --float32_matmul_precision high 62 | ) 63 | 64 | # Miscellaneous arguments 65 | miscellaneous_cmd=( 66 | --seed 31337 67 | --tracker_name "finetrainers-inference" 68 | --output_dir "/raid/aryan/flux-inference" 69 | --init_timeout 600 70 | --nccl_timeout 600 71 | --report_to "wandb" 72 | ) 73 | 74 | # Execute the inference script 75 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES 76 | 77 | torchrun \ 78 | --standalone \ 79 | --nnodes=1 \ 80 | --nproc_per_node=$NUM_GPUS \ 81 | --rdzv_backend c10d \ 82 | --rdzv_endpoint="localhost:19242" \ 83 | examples/inference/inference.py \ 84 | "${parallel_cmd[@]}" \ 85 | "${model_cmd[@]}" \ 86 | "${inference_cmd[@]}" \ 87 | "${attn_provider_cmd[@]}" \ 88 | "${torch_config_cmd[@]}" \ 89 | "${miscellaneous_cmd[@]}" 90 | 91 | echo -ne "-------------------- Finished executing script --------------------\n\n" 92 | -------------------------------------------------------------------------------- /examples/inference/wan/wan_text_to_video.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | 5 | # export TORCH_LOGS="+dynamo,recompiles,graph_breaks" 6 | # export TORCHDYNAMO_VERBOSE=1 7 | # export WANDB_MODE="offline" 8 | export WANDB_MODE="disabled" 9 | export NCCL_P2P_DISABLE=1 10 | export NCCL_IB_DISABLE=1 11 | export TORCH_NCCL_ENABLE_MONITORING=0 12 | export FINETRAINERS_LOG_LEVEL="DEBUG" 13 | 14 | # Download the validation dataset 15 | if [ ! -d "examples/inference/datasets/openvid-1k-split-validation" ]; then 16 | echo "Downloading validation dataset..." 17 | huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/inference/datasets/openvid-1k-split-validation 18 | else 19 | echo "Validation dataset already exists. Skipping download." 
20 | fi 21 | 22 | BACKEND="ptd" 23 | 24 | NUM_GPUS=4 25 | CUDA_VISIBLE_DEVICES="0,1,2,3" 26 | 27 | # Check the JSON files for the expected JSON format 28 | DATASET_FILE="examples/inference/wan/dummy_text_to_video.json" 29 | 30 | # Depending on how many GPUs you have available, choose your degree of parallelism and technique! 31 | DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" 32 | DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" 33 | DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" 34 | DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" 35 | CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" 36 | CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" 37 | # FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" 38 | # FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" 39 | # HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" 40 | 41 | # Parallel arguments 42 | parallel_cmd=( 43 | $CP_4 44 | ) 45 | 46 | # Model arguments 47 | model_cmd=( 48 | --model_name "wan" 49 | --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" 50 | --enable_slicing 51 | --enable_tiling 52 | ) 53 | 54 | # Inference arguments 55 | inference_cmd=( 56 | --inference_type text_to_video 57 | --dataset_file "$DATASET_FILE" 58 | ) 59 | 60 | # Attention provider arguments 61 | attn_provider_cmd=( 62 | --attn_provider sage 63 | ) 64 | 65 | # Torch config arguments 66 | torch_config_cmd=( 67 | --allow_tf32 68 | --float32_matmul_precision high 69 | ) 70 | 71 | # Miscellaneous arguments 72 | miscellaneous_cmd=( 73 | --seed 31337 74 | --tracker_name "finetrainers-inference" 75 | --output_dir "/raid/aryan/wan-inference" 76 | --init_timeout 600 77 | --nccl_timeout 600 78 | --report_to "wandb" 79 | ) 80 | 81 | # Execute the inference script 82 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES 83 | 84 | torchrun \ 85 | --standalone \ 86 | --nnodes=1 \ 87 | --nproc_per_node=$NUM_GPUS \ 88 | --rdzv_backend c10d \ 89 | --rdzv_endpoint="localhost:19242" \ 90 | examples/inference/inference.py \ 91 | "${parallel_cmd[@]}" \ 92 | "${model_cmd[@]}" \ 93 | "${inference_cmd[@]}" \ 94 | "${attn_provider_cmd[@]}" \ 95 | "${torch_config_cmd[@]}" \ 96 | "${miscellaneous_cmd[@]}" 97 | 98 | echo -ne "-------------------- Finished executing script --------------------\n\n" 99 | -------------------------------------------------------------------------------- /examples/training/control/cogview4/canny/.gitignore: -------------------------------------------------------------------------------- 1 | !validation_dataset/**/* -------------------------------------------------------------------------------- /examples/training/control/cogview4/canny/README.md: -------------------------------------------------------------------------------- 1 | # CogView4 Canny Control training 2 | 3 | To launch training, you can run the following from the root directory of the repository. 
4 | 5 | ```bash 6 | chmod +x ./examples/training/control/cogview4/canny/train.sh 7 | ./examples/training/control/cogview4/canny/train.sh 8 | ``` 9 | 10 | The script should automatically download the validation dataset, but in case that doesn't happen, please make sure that a folder named `validation_dataset` exists in `examples/training/control/cogview4/canny/` and contains the validation dataset. You can also configure `validation.json` in the same directory however you like for your own validation dataset. 11 | 12 | ```bash 13 | cd examples/training/control/cogview4/canny/ 14 | huggingface-cli download --repo-type dataset finetrainers/Canny-image-validation-dataset --local-dir validation_dataset 15 | ``` 16 | -------------------------------------------------------------------------------- /examples/training/control/cogview4/canny/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "recoilme/aesthetic_photos_xs", 5 | "dataset_type": "image", 6 | "image_resolution_buckets": [ 7 | [1024, 1024] 8 | ], 9 | "reshape_mode": "bicubic", 10 | "remove_common_llm_caption_prefixes": true 11 | } 12 | ] 13 | } -------------------------------------------------------------------------------- /examples/training/control/cogview4/canny/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "an orange flamingo stands in the shallow water, solo, standing, full_body, outdoors, blurry, no_humans, bird, leaf, plant, animal_focus, beak", 5 | "image_path": "examples/training/control/cogview4/canny/validation_dataset/0.png", 6 | "num_inference_steps": 30, 7 | "height": 1024, 8 | "width": 1024 9 | }, 10 | { 11 | "caption": "a woman holding a bouquet of flowers on the street, 1girl, solo, long_hair, shirt, black_hair, long_sleeves, holding, bare_shoulders, jewelry, upper_body, flower, earrings, outdoors, parted_lips, striped, off_shoulder, black_eyes, tree, looking_to_the_side, grass, pink_flower, striped_shirt, off-shoulder_shirt, holding_flower", 12 | "image_path": "examples/training/control/cogview4/canny/validation_dataset/1.png", 13 | "num_inference_steps": 30, 14 | "height": 1024, 15 | "width": 1024 16 | }, 17 | { 18 | "caption": "there is a boat on the river in the wilderness, outdoors, sky, day, water, blurry, tree, no_humans, leaf, plant, nature, scenery, reflection, mountain, lake", 19 | "image_path": "examples/training/control/cogview4/canny/validation_dataset/2.png", 20 | "num_inference_steps": 30, 21 | "height": 1024, 22 | "width": 1024 23 | }, 24 | { 25 | "caption": "a man in white lab coat wearing a pair of virtual glasses, shirt, black_hair, 1boy, holding, male_focus, multiple_boys, indoors, blurry, cup, depth_of_field, blurry_background, blue_shirt, mug, labcoat, coffee, coffee_mug, doctor", 26 | "image_path": "examples/training/control/cogview4/canny/validation_dataset/3.png", 27 | "num_inference_steps": 30, 28 | "height": 1024, 29 | "width": 1024 30 | } 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /examples/training/control/cogview4/omni_edit/.gitignore: -------------------------------------------------------------------------------- 1 | !validation_dataset/**/* -------------------------------------------------------------------------------- /examples/training/control/cogview4/omni_edit/README.md: -------------------------------------------------------------------------------- 1 | # 
CogView4 Edit Control training 2 | 3 | To launch training, you can run the following from the root directory of the repository. 4 | 5 | ```bash 6 | chmod +x ./examples/training/control/cogview4/omni_edit/train.sh 7 | ./examples/training/control/cogview4/omni_edit/train.sh 8 | ``` 9 | 10 | The script should automatically download the validation dataset, but in case that doesn't happen, please make sure that a folder named `validation_dataset` exists in `examples/training/control/cogview4/omni_edit/` and contains the validation dataset. You can also configure `validation.json` in the same directory however you like for your own validation dataset. 11 | 12 | ```bash 13 | cd examples/training/control/cogview4/omni_edit/ 14 | huggingface-cli download --repo-type dataset finetrainers/OmniEdit-validation-dataset --local-dir validation_dataset 15 | ``` 16 | -------------------------------------------------------------------------------- /examples/training/control/cogview4/omni_edit/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "sayakpaul/OmniEdit-mini", 5 | "dataset_type": "image", 6 | "image_resolution_buckets": [ 7 | [384, 384], 8 | [480, 480], 9 | [512, 512], 10 | [640, 640], 11 | [704, 704], 12 | [768, 768], 13 | [832, 832], 14 | [896, 896], 15 | [960, 960], 16 | [1024, 1024], 17 | [1152, 1152], 18 | [1280, 1280], 19 | [1344, 1344], 20 | [480, 720], 21 | [480, 768], 22 | [512, 768], 23 | [640, 768], 24 | [768, 960], 25 | [768, 1152], 26 | [768, 1360], 27 | [864, 1152], 28 | [864, 1360], 29 | [720, 480], 30 | [768, 480], 31 | [768, 512], 32 | [768, 640], 33 | [960, 768], 34 | [1152, 768], 35 | [1360, 768], 36 | [1152, 864], 37 | [1360, 864] 38 | ], 39 | "caption_options": { 40 | "column_names": "edited_prompt_list" 41 | }, 42 | "rename_columns": { 43 | "src_img": "control_image", 44 | "edited_img": "image" 45 | }, 46 | "reshape_mode": "bicubic", 47 | "remove_common_llm_caption_prefixes": true 48 | } 49 | ] 50 | } -------------------------------------------------------------------------------- /examples/training/control/wan/image_condition/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "recoilme/aesthetic_photos_xs", 5 | "dataset_type": "image", 6 | "image_resolution_buckets": [ 7 | [1024, 1024] 8 | ], 9 | "reshape_mode": "bicubic", 10 | "remove_common_llm_caption_prefixes": true 11 | }, 12 | { 13 | "data_root": "finetrainers/OpenVid-1k-split", 14 | "dataset_type": "video", 15 | "video_resolution_buckets": [ 16 | [49, 512, 512], 17 | [49, 768, 768], 18 | [49, 1024, 1024], 19 | [49, 480, 704], 20 | [49, 704, 480] 21 | ], 22 | "reshape_mode": "bicubic", 23 | "remove_common_llm_caption_prefixes": true 24 | } 25 | ] 26 | } -------------------------------------------------------------------------------- /examples/training/sft/cogvideox/crush_smol_lora/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/crush-smol", 5 | "dataset_type": "video", 6 | "id_token": "PIKA_CRUSH", 7 | "video_resolution_buckets": [ 8 | [81, 480, 768] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /examples/training/sft/cogvideox/crush_smol_lora/validation.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 512, 9 | "width": 768, 10 | "num_frames": 49, 11 | "frame_rate": 25 12 | }, 13 | { 14 | "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", 15 | "image_path": null, 16 | "video_path": null, 17 | "num_inference_steps": 50, 18 | "height": 512, 19 | "width": 768, 20 | "num_frames": 49, 21 | "frame_rate": 25 22 | }, 23 | { 24 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", 25 | "image_path": null, 26 | "video_path": null, 27 | "num_inference_steps": 50, 28 | "height": 512, 29 | "width": 768, 30 | "num_frames": 49, 31 | "frame_rate": 25 32 | }, 33 | { 34 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", 35 | "image_path": null, 36 | "video_path": null, 37 | "num_inference_steps": 50, 38 | "height": 512, 39 | "width": 768, 40 | "num_frames": 49, 41 | "frame_rate": 25 42 | } 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /examples/training/sft/cogview4/raider_white_tarot/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 5 | "dataset_type": "image", 6 | "id_token": "TRTCRD", 7 | "image_resolution_buckets": [ 8 | [1280, 720] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | }, 13 | { 14 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 15 | "dataset_type": "image", 16 | "id_token": "TRTCRD", 17 | "image_resolution_buckets": [ 18 | [512, 512] 19 | ], 20 | "reshape_mode": "center_crop", 21 | "remove_common_llm_caption_prefixes": true 22 | }, 23 | { 24 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 25 | "dataset_type": "image", 26 | "id_token": "TRTCRD", 27 | "image_resolution_buckets": [ 28 | [768, 768] 29 | ], 30 | "reshape_mode": "center_crop", 31 | "remove_common_llm_caption_prefixes": true 32 | } 33 | ] 34 | } -------------------------------------------------------------------------------- /examples/training/sft/cogview4/raider_white_tarot/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 1280, 9 | "width": 720 10 | }, 11 | { 12 | "caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"", 13 | "image_path": null, 14 | "video_path": null, 15 | "num_inference_steps": 50, 16 | "height": 1280, 17 | "width": 720 18 | }, 19 | { 20 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 21 | 
"image_path": null, 22 | "video_path": null, 23 | "num_inference_steps": 50, 24 | "height": 1280, 25 | "width": 720 26 | }, 27 | { 28 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", 29 | "image_path": null, 30 | "video_path": null, 31 | "num_inference_steps": 50, 32 | "height": 1280, 33 | "width": 720 34 | }, 35 | { 36 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 37 | "image_path": null, 38 | "video_path": null, 39 | "num_inference_steps": 50, 40 | "height": 512, 41 | "width": 512 42 | }, 43 | { 44 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", 45 | "image_path": null, 46 | "video_path": null, 47 | "num_inference_steps": 50, 48 | "height": 512, 49 | "width": 512 50 | }, 51 | { 52 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 53 | "image_path": null, 54 | "video_path": null, 55 | "num_inference_steps": 50, 56 | "height": 768, 57 | "width": 768 58 | }, 59 | { 60 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", 61 | "image_path": null, 62 | "video_path": null, 63 | "num_inference_steps": 50, 64 | "height": 768, 65 | "width": 768 66 | } 67 | ] 68 | } 69 | -------------------------------------------------------------------------------- /examples/training/sft/cogview4/the_simpsons/README.md: -------------------------------------------------------------------------------- 1 | # CogView4-6B The Simpsons dataset 2 | 3 | This example is only an experiment to verify if webdataset loading and streaming from the HF Hub works as expected. Do not expect meaningful results. 4 | 5 | The dataset used for testing is available at [`bigdata-pw/TheSimpsons`](https://huggingface.co/datasets/bigdata-pw/TheSimpsons). 
6 | -------------------------------------------------------------------------------- /examples/training/sft/cogview4/the_simpsons/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "bigdata-pw/TheSimpsons", 5 | "dataset_type": "image", 6 | "id_token": "SMPSN", 7 | "image_resolution_buckets": [ 8 | [960, 528], 9 | [720, 528], 10 | [720, 480] 11 | ], 12 | "reshape_mode": "bicubic", 13 | "remove_common_llm_caption_prefixes": true, 14 | "caption_options": { 15 | "column_names": ["caption.txt", "detailed_caption.txt", "more_detailed_caption.txt"], 16 | "weights": { 17 | "caption.txt": 0.2, 18 | "detailed_caption.txt": 0.6, 19 | "more_detailed_caption.txt": 0.2 20 | } 21 | } 22 | } 23 | ] 24 | } -------------------------------------------------------------------------------- /examples/training/sft/flux_dev/raider_white_tarot/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 5 | "dataset_type": "image", 6 | "id_token": "TRTCRD", 7 | "image_resolution_buckets": [ 8 | [1280, 720] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | }, 13 | { 14 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 15 | "dataset_type": "image", 16 | "id_token": "TRTCRD", 17 | "image_resolution_buckets": [ 18 | [512, 512] 19 | ], 20 | "reshape_mode": "center_crop", 21 | "remove_common_llm_caption_prefixes": true 22 | }, 23 | { 24 | "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", 25 | "dataset_type": "image", 26 | "id_token": "TRTCRD", 27 | "image_resolution_buckets": [ 28 | [768, 768] 29 | ], 30 | "reshape_mode": "center_crop", 31 | "remove_common_llm_caption_prefixes": true 32 | } 33 | ] 34 | } -------------------------------------------------------------------------------- /examples/training/sft/flux_dev/raider_white_tarot/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 1280, 9 | "width": 720 10 | }, 11 | { 12 | "caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"", 13 | "image_path": null, 14 | "video_path": null, 15 | "num_inference_steps": 50, 16 | "height": 1280, 17 | "width": 720 18 | }, 19 | { 20 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 21 | "image_path": null, 22 | "video_path": null, 23 | "num_inference_steps": 50, 24 | "height": 1280, 25 | "width": 720 26 | }, 27 | { 28 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", 29 | "image_path": null, 30 | "video_path": null, 31 | "num_inference_steps": 50, 32 | "height": 1280, 33 | "width": 720 34 | }, 35 | { 36 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 37 | "image_path": null, 38 | "video_path": null, 39 | "num_inference_steps": 50, 40 | "height": 512, 41 | "width": 512 42 | }, 43 | { 44 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, 
surrounded by six pentacles", 45 | "image_path": null, 46 | "video_path": null, 47 | "num_inference_steps": 50, 48 | "height": 512, 49 | "width": 512 50 | }, 51 | { 52 | "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", 53 | "image_path": null, 54 | "video_path": null, 55 | "num_inference_steps": 50, 56 | "height": 768, 57 | "width": 768 58 | }, 59 | { 60 | "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", 61 | "image_path": null, 62 | "video_path": null, 63 | "num_inference_steps": 50, 64 | "height": 768, 65 | "width": 768 66 | } 67 | ] 68 | } 69 | -------------------------------------------------------------------------------- /examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "modal-labs/dissolve", 5 | "dataset_type": "video", 6 | "id_token": "MODAL_DISSOLVE", 7 | "video_resolution_buckets": [ 8 | [49, 480, 768] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | }, 13 | { 14 | "data_root": "modal-labs/dissolve", 15 | "dataset_type": "video", 16 | "id_token": "MODAL_DISSOLVE", 17 | "video_resolution_buckets": [ 18 | [81, 480, 768] 19 | ], 20 | "reshape_mode": "bicubic", 21 | "remove_common_llm_caption_prefixes": true 22 | } 23 | ] 24 | } -------------------------------------------------------------------------------- /examples/training/sft/ltx_video/crush_smol_lora/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/crush-smol", 5 | "dataset_type": "video", 6 | "id_token": "PIKA_CRUSH", 7 | "video_resolution_buckets": [ 8 | [49, 512, 768] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /examples/training/sft/ltx_video/crush_smol_lora/training_multires.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/crush-smol", 5 | "dataset_type": "video", 6 | "id_token": "PIKA_CRUSH", 7 | "video_resolution_buckets": [ 8 | [49, 512, 768] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | }, 13 | { 14 | "data_root": "finetrainers/crush-smol", 15 | "dataset_type": "video", 16 | "id_token": "PIKA_CRUSH", 17 | "video_resolution_buckets": [ 18 | [81, 512, 768] 19 | ], 20 | "reshape_mode": "bicubic", 21 | "remove_common_llm_caption_prefixes": true 22 | }, 23 | { 24 | "data_root": "finetrainers/crush-smol", 25 | "dataset_type": "video", 26 | "id_token": "PIKA_CRUSH", 27 | "video_resolution_buckets": [ 28 | [121, 512, 768] 29 | ], 30 | "reshape_mode": "bicubic", 31 | "remove_common_llm_caption_prefixes": true 32 | }, 33 | { 34 | "data_root": "finetrainers/crush-smol", 35 | "dataset_type": "video", 36 | "id_token": "PIKA_CRUSH", 37 | "video_resolution_buckets": [ 38 | [161, 512, 768] 39 | ], 40 | "reshape_mode": "bicubic", 41 | "remove_common_llm_caption_prefixes": true 42 | } 43 | ] 44 | } -------------------------------------------------------------------------------- /examples/training/sft/ltx_video/crush_smol_lora/validation.json: -------------------------------------------------------------------------------- 
1 | { 2 | "data": [ 3 | { 4 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 512, 9 | "width": 768, 10 | "num_frames": 49, 11 | "frame_rate": 25 12 | }, 13 | { 14 | "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", 15 | "image_path": null, 16 | "video_path": null, 17 | "num_inference_steps": 50, 18 | "height": 512, 19 | "width": 768, 20 | "num_frames": 49, 21 | "frame_rate": 25 22 | }, 23 | { 24 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", 25 | "image_path": null, 26 | "video_path": null, 27 | "num_inference_steps": 50, 28 | "height": 512, 29 | "width": 768, 30 | "num_frames": 49, 31 | "frame_rate": 25 32 | }, 33 | { 34 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", 35 | "image_path": null, 36 | "video_path": null, 37 | "num_inference_steps": 50, 38 | "height": 512, 39 | "width": 768, 40 | "num_frames": 49, 41 | "frame_rate": 25 42 | } 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /examples/training/sft/ltx_video/crush_smol_lora/validation_multires.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 512, 9 | "width": 768, 10 | "num_frames": 49, 11 | "frame_rate": 25 12 | }, 13 | { 14 | "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. 
The press is shown in action, with the cube being squeezed into a smaller shape.", 15 | "image_path": null, 16 | "video_path": null, 17 | "num_inference_steps": 50, 18 | "height": 512, 19 | "width": 768, 20 | "num_frames": 49, 21 | "frame_rate": 25 22 | }, 23 | { 24 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", 25 | "image_path": null, 26 | "video_path": null, 27 | "num_inference_steps": 50, 28 | "height": 512, 29 | "width": 768, 30 | "num_frames": 49, 31 | "frame_rate": 25 32 | }, 33 | { 34 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", 35 | "image_path": null, 36 | "video_path": null, 37 | "num_inference_steps": 50, 38 | "height": 512, 39 | "width": 768, 40 | "num_frames": 49, 41 | "frame_rate": 25 42 | }, 43 | { 44 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 45 | "image_path": null, 46 | "video_path": null, 47 | "num_inference_steps": 50, 48 | "height": 512, 49 | "width": 768, 50 | "num_frames": 81, 51 | "frame_rate": 25 52 | }, 53 | { 54 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 55 | "image_path": null, 56 | "video_path": null, 57 | "num_inference_steps": 50, 58 | "height": 512, 59 | "width": 768, 60 | "num_frames": 121, 61 | "frame_rate": 25 62 | }, 63 | { 64 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 65 | "image_path": null, 66 | "video_path": null, 67 | "num_inference_steps": 50, 68 | "height": 512, 69 | "width": 768, 70 | "num_frames": 161, 71 | "frame_rate": 25 72 | }, 73 | { 74 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", 75 | "image_path": null, 76 | "video_path": null, 77 | "num_inference_steps": 50, 78 | "height": 512, 79 | "width": 768, 80 | "num_frames": 161, 81 | "frame_rate": 25 82 | } 83 | ] 84 | } 85 | -------------------------------------------------------------------------------- /examples/training/sft/wan/3dgs_dissolve/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/3dgs-dissolve", 5 | "dataset_type": "video", 6 | "id_token": "3DGS_DISSOLVE", 7 | "video_resolution_buckets": [ 8 | [49, 480, 832] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | }, 13 | { 14 | "data_root": "finetrainers/3dgs-dissolve", 15 | "dataset_type": "video", 16 | "id_token": "3DGS_DISSOLVE", 17 | "video_resolution_buckets": [ 18 | [81, 480, 832] 19 | ], 20 | "reshape_mode": "bicubic", 21 | "remove_common_llm_caption_prefixes": true 22 | } 23 | ] 24 | } -------------------------------------------------------------------------------- /examples/training/sft/wan/3dgs_dissolve/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. 
As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 480, 9 | "width": 832, 10 | "num_frames": 49 11 | }, 12 | { 13 | "caption": "3DGS_DISSOLVE A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", 14 | "image_path": null, 15 | "video_path": null, 16 | "num_inference_steps": 50, 17 | "height": 480, 18 | "width": 832, 19 | "num_frames": 49 20 | }, 21 | { 22 | "caption": "3DGS_DISSOLVE A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", 23 | "image_path": null, 24 | "video_path": null, 25 | "num_inference_steps": 50, 26 | "height": 480, 27 | "width": 832, 28 | "num_frames": 81 29 | }, 30 | { 31 | "caption": "3DGS_DISSOLVE A vintage-style treasure chest, rendered in a 3D appearance, stands prominently against a dark background. As the scene progresses, the chest begins to emit a glowing light, which intensifies until it evaporates into a burst of red sparks, creating a dramatic and mysterious atmosphere.", 32 | "image_path": null, 33 | "video_path": null, 34 | "num_inference_steps": 50, 35 | "height": 480, 36 | "width": 832, 37 | "num_frames": 49 38 | }, 39 | { 40 | "caption": "3DGS_DISSOLVE A glowing, fiery cube in a 3D appearance begins to spin and rotate, its edges shimmering with intense light. As it continues to spin, the cube gradually evaporates into a burst of red sparks that scatter across the screen, creating a dynamic and mesmerizing visual effect against the dark background.", 41 | "image_path": null, 42 | "video_path": null, 43 | "num_inference_steps": 50, 44 | "height": 480, 45 | "width": 832, 46 | "num_frames": 49 47 | }, 48 | { 49 | "caption": "3DGS_DISSOLVE A dynamic explosion unfolds in a 3D appearance, beginning as a concentrated burst of intense orange flames. As the fire intensifies, it rapidly expands outward, transitioning into a vibrant display of red sparks that scatter across the frame. 
The sparks continue to evolve, evaporating into a burst of red sparks against the dark backdrop, creating a mesmerizing visual spectacle.", 50 | "image_path": null, 51 | "video_path": null, 52 | "num_inference_steps": 50, 53 | "height": 480, 54 | "width": 832, 55 | "num_frames": 49 56 | } 57 | ] 58 | } 59 | -------------------------------------------------------------------------------- /examples/training/sft/wan/crush_smol_lora/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/crush-smol", 5 | "dataset_type": "video", 6 | "id_token": "PIKA_CRUSH", 7 | "video_resolution_buckets": [ 8 | [49, 480, 832] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /examples/training/sft/wan/crush_smol_lora/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", 5 | "image_path": null, 6 | "video_path": null, 7 | "num_inference_steps": 50, 8 | "height": 480, 9 | "width": 832, 10 | "num_frames": 49 11 | }, 12 | { 13 | "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", 14 | "image_path": null, 15 | "video_path": null, 16 | "num_inference_steps": 50, 17 | "height": 480, 18 | "width": 832, 19 | "num_frames": 49 20 | }, 21 | { 22 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", 23 | "image_path": null, 24 | "video_path": null, 25 | "num_inference_steps": 50, 26 | "height": 480, 27 | "width": 832, 28 | "num_frames": 49 29 | }, 30 | { 31 | "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", 32 | "image_path": null, 33 | "video_path": null, 34 | "num_inference_steps": 50, 35 | "height": 480, 36 | "width": 832, 37 | "num_frames": 49 38 | } 39 | ] 40 | } 41 | -------------------------------------------------------------------------------- /examples/training/sft/wan_i2v/3dgs_dissolve/training.json: -------------------------------------------------------------------------------- 1 | { 2 | "datasets": [ 3 | { 4 | "data_root": "finetrainers/3dgs-dissolve", 5 | "dataset_type": "video", 6 | "id_token": "3DGS_DISSOLVE", 7 | "video_resolution_buckets": [ 8 | [49, 480, 832] 9 | ], 10 | "reshape_mode": "bicubic", 11 | "remove_common_llm_caption_prefixes": true 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /examples/training/sft/wan_i2v/3dgs_dissolve/validation.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "caption": "3DGS_DISSOLVE A vibrant green Mustang GT parked in a parking lot. The car is positioned at an angle, showcasing its sleek design and black rims. The car's hood is black, contrasting with the green body. 
The car gradually transforms and bursts into red sparks, creating a dramatic and dynamic visual effect against a dark backdrop.", 5 | "video_path": "examples/training/sft/wan_i2v/3dgs_dissolve/validation_dataset/0.mp4", 6 | "num_inference_steps": 30, 7 | "num_frames": 49, 8 | "height": 480, 9 | "width": 832 10 | }, 11 | { 12 | "caption": "3DGS_DISSOLVE A cooking tutorial featuring a man in a kitchen. He is wearing a white t-shirt and a black apron. As the scene progresses, light starts to emanate from the man and he burst into a fiery flame of red sparks.", 13 | "video_path": "examples/training/control/wan/image_condition/validation_dataset/1.mp4", 14 | "num_inference_steps": 30, 15 | "num_frames": 49, 16 | "height": 480, 17 | "width": 832 18 | }, 19 | { 20 | "caption": "3DGS_DISSOLVE A man in a suit and tie, standing against a blue background with a digital pattern. He appears to be speaking or presenting, as suggested by his open mouth and focused expression. Suddenly, the man starts to dissolve into thin air with a bright fiery flame of red sparks.", 21 | "video_path": "examples/training/control/wan/image_condition/validation_dataset/2.mp4", 22 | "num_inference_steps": 30, 23 | "num_frames": 49, 24 | "height": 480, 25 | "width": 832 26 | }, 27 | { 28 | "caption": "3DGS_DISSOLVE A man in a workshop, dressed in a black shirt and a beige hat, with a beard and glasses. He is holding a hammer and a metal object, possibly a piece of iron or a tool. The scene erupts with a bright fiery flame of red sparks.", 29 | "video_path": "examples/training/control/wan/image_condition/validation_dataset/3.mp4", 30 | "num_inference_steps": 30, 31 | "num_frames": 49, 32 | "height": 480, 33 | "width": 832 34 | } 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /finetrainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .args import BaseArgs 2 | from .config import ModelType, TrainingType 3 | from .logging import get_logger 4 | from .models import ModelSpecification 5 | from .trainer import ControlTrainer, SFTTrainer 6 | 7 | 8 | __version__ = "0.2.0.dev0" 9 | -------------------------------------------------------------------------------- /finetrainers/_metadata.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, ForwardRef, List, Optional, Type, Union 3 | 4 | 5 | ParamIdentifierType = ForwardRef("ParamIdentifier") 6 | ContextParallelInputMetadataType = ForwardRef("ContextParallelInputMetadata") 7 | ContextParallelOutputMetadataType = ForwardRef("ContextParallelOutputMetadata") 8 | 9 | _ContextParallelInputType = Dict[ 10 | ParamIdentifierType, Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]] 11 | ] 12 | _ContextParallelOutputType = List[ContextParallelOutputMetadataType] 13 | ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType] 14 | 15 | 16 | @dataclass(frozen=True) 17 | class ParamId: 18 | """ 19 | A class to identify a parameter of a method. 20 | 21 | At least one of `name` or `index` must be provided. 22 | 23 | Attributes: 24 | name (`str`, *optional*): 25 | The name of the parameter. 26 | index (`int`, *optional*): 27 | The index of the parameter in the method signature. Indexing starts at 0 (ignore 28 | the `self` parameter for instance methods). 
29 | """ 30 | 31 | name: Optional[str] = None 32 | index: Optional[int] = None 33 | 34 | def __post_init__(self): 35 | if self.name is None and self.index is None: 36 | raise ValueError("At least one of `name` or `index` must be provided.") 37 | 38 | 39 | @dataclass(frozen=True) 40 | class CPInput: 41 | split_dim: int 42 | expected_dims: Optional[int] = None 43 | split_output: bool = False 44 | 45 | 46 | @dataclass(frozen=True) 47 | class CPOutput: 48 | gather_dim: int 49 | expected_dims: Optional[int] = None 50 | 51 | 52 | @dataclass 53 | class TransformerMetadata: 54 | # Mapping of FQN to mapping of input name to ContextParallelModelPlan 55 | cp_plan: Dict[str, ContextParallelModelPlan] = field(default_factory=dict) 56 | 57 | # tp_plan # TODO(aryan) 58 | 59 | 60 | class TransformerRegistry: 61 | _registry = {} 62 | 63 | @classmethod 64 | def register(cls, model_class: Type, metadata: TransformerMetadata): 65 | cls._registry[model_class] = metadata 66 | 67 | @classmethod 68 | def get(cls, model_class: Type) -> TransformerMetadata: 69 | if model_class not in cls._registry: 70 | raise ValueError(f"Model class {model_class} not registered.") 71 | return cls._registry[model_class] 72 | -------------------------------------------------------------------------------- /finetrainers/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Type 3 | 4 | from .models import ModelSpecification 5 | from .models.cogvideox import CogVideoXModelSpecification 6 | from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification 7 | from .models.flux import FluxModelSpecification 8 | from .models.hunyuan_video import HunyuanVideoModelSpecification 9 | from .models.ltx_video import LTXVideoModelSpecification 10 | from .models.wan import WanControlModelSpecification, WanModelSpecification 11 | 12 | 13 | class ModelType(str, Enum): 14 | COGVIDEOX = "cogvideox" 15 | COGVIEW4 = "cogview4" 16 | FLUX = "flux" 17 | HUNYUAN_VIDEO = "hunyuan_video" 18 | LTX_VIDEO = "ltx_video" 19 | WAN = "wan" 20 | 21 | 22 | class TrainingType(str, Enum): 23 | # SFT 24 | LORA = "lora" 25 | FULL_FINETUNE = "full-finetune" 26 | 27 | # Control 28 | CONTROL_LORA = "control-lora" 29 | CONTROL_FULL_FINETUNE = "control-full-finetune" 30 | 31 | 32 | SUPPORTED_MODEL_CONFIGS = { 33 | # TODO(aryan): autogenerate this 34 | # SFT 35 | ModelType.COGVIDEOX: { 36 | TrainingType.LORA: CogVideoXModelSpecification, 37 | TrainingType.FULL_FINETUNE: CogVideoXModelSpecification, 38 | }, 39 | ModelType.COGVIEW4: { 40 | TrainingType.LORA: CogView4ModelSpecification, 41 | TrainingType.FULL_FINETUNE: CogView4ModelSpecification, 42 | TrainingType.CONTROL_LORA: CogView4ControlModelSpecification, 43 | TrainingType.CONTROL_FULL_FINETUNE: CogView4ControlModelSpecification, 44 | }, 45 | ModelType.FLUX: { 46 | TrainingType.LORA: FluxModelSpecification, 47 | TrainingType.FULL_FINETUNE: FluxModelSpecification, 48 | }, 49 | ModelType.HUNYUAN_VIDEO: { 50 | TrainingType.LORA: HunyuanVideoModelSpecification, 51 | TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, 52 | }, 53 | ModelType.LTX_VIDEO: { 54 | TrainingType.LORA: LTXVideoModelSpecification, 55 | TrainingType.FULL_FINETUNE: LTXVideoModelSpecification, 56 | }, 57 | ModelType.WAN: { 58 | TrainingType.LORA: WanModelSpecification, 59 | TrainingType.FULL_FINETUNE: WanModelSpecification, 60 | TrainingType.CONTROL_LORA: WanControlModelSpecification, 61 | TrainingType.CONTROL_FULL_FINETUNE: 
WanControlModelSpecification, 62 | }, 63 | } 64 | 65 | 66 | def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]: 67 | if model_name not in SUPPORTED_MODEL_CONFIGS: 68 | raise ValueError( 69 | f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" 70 | ) 71 | if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: 72 | raise ValueError( 73 | f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" 74 | ) 75 | return SUPPORTED_MODEL_CONFIGS[model_name][training_type] 76 | -------------------------------------------------------------------------------- /finetrainers/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} 5 | 6 | FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO") 7 | FINETRAINERS_ATTN_PROVIDER = os.environ.get("FINETRAINERS_ATTN_PROVIDER", "native") 8 | FINETRAINERS_ATTN_CHECKS = os.getenv("FINETRAINERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES 9 | FINETRAINERS_ENABLE_TIMING = os.getenv("FINETRAINERS_ENABLE_TIMING", "1") in ENV_VARS_TRUE_VALUES 10 | 11 | DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 12 | DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 13 | DEFAULT_FRAME_BUCKETS = [49] 14 | 15 | DEFAULT_IMAGE_RESOLUTION_BUCKETS = [] 16 | for height in DEFAULT_HEIGHT_BUCKETS: 17 | for width in DEFAULT_WIDTH_BUCKETS: 18 | DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width)) 19 | 20 | DEFAULT_VIDEO_RESOLUTION_BUCKETS = [] 21 | for frames in DEFAULT_FRAME_BUCKETS: 22 | for height in DEFAULT_HEIGHT_BUCKETS: 23 | for width in DEFAULT_WIDTH_BUCKETS: 24 | DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width)) 25 | 26 | PRECOMPUTED_DIR_NAME = "precomputed" 27 | PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions" 28 | PRECOMPUTED_LATENTS_DIR_NAME = "latents" 29 | 30 | MODEL_DESCRIPTION = r""" 31 | \# {model_id} {training_type} finetune 32 | 33 | 34 | 35 | \#\# Model Description 36 | 37 | This model is a {training_type} of the `{model_id}` model. 38 | 39 | This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers). 40 | 41 | \#\# Download model 42 | 43 | [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. 44 | 45 | \#\# Usage 46 | 47 | Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed. 48 | 49 | ```python 50 | {model_example} 51 | ``` 52 | 53 | For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. 54 | 55 | \#\# License 56 | 57 | Please adhere to the license of the base model. 
58 | """.strip() 59 | 60 | _COMMON_BEGINNING_PHRASES = ( 61 | "This video", 62 | "The video", 63 | "This clip", 64 | "The clip", 65 | "The animation", 66 | "This image", 67 | "The image", 68 | "This picture", 69 | "The picture", 70 | ) 71 | _COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents") 72 | 73 | COMMON_LLM_START_PHRASES = ( 74 | "In the video,", 75 | "In this video,", 76 | "In this video clip,", 77 | "In the clip,", 78 | "Caption:", 79 | *( 80 | f"{beginning} {continuation}" 81 | for beginning in _COMMON_BEGINNING_PHRASES 82 | for continuation in _COMMON_CONTINUATION_WORDS 83 | ), 84 | ) 85 | 86 | SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png") 87 | SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov") 88 | -------------------------------------------------------------------------------- /finetrainers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from ._artifact import ImageArtifact, VideoArtifact 2 | from .dataloader import DPDataLoader 3 | from .dataset import ( 4 | ImageCaptionFilePairDataset, 5 | ImageFileCaptionFileListDataset, 6 | ImageFolderDataset, 7 | ImageWebDataset, 8 | ValidationDataset, 9 | VideoCaptionFilePairDataset, 10 | VideoFileCaptionFileListDataset, 11 | VideoFolderDataset, 12 | VideoWebDataset, 13 | combine_datasets, 14 | initialize_dataset, 15 | wrap_iterable_dataset_for_preprocessing, 16 | ) 17 | from .precomputation import ( 18 | InMemoryDataIterable, 19 | InMemoryDistributedDataPreprocessor, 20 | InMemoryOnceDataIterable, 21 | PrecomputedDataIterable, 22 | PrecomputedDistributedDataPreprocessor, 23 | PrecomputedOnceDataIterable, 24 | initialize_preprocessor, 25 | ) 26 | from .sampler import ResolutionSampler 27 | -------------------------------------------------------------------------------- /finetrainers/data/_artifact.py: -------------------------------------------------------------------------------- 1 | # ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT ===== 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, List 5 | 6 | from PIL.Image import Image 7 | 8 | 9 | @dataclass 10 | class Artifact: 11 | type: str 12 | value: Any 13 | file_extension: str 14 | 15 | 16 | @dataclass 17 | class ImageArtifact(Artifact): 18 | value: Image 19 | 20 | def __init__(self, value: Image): 21 | super().__init__(type="image", value=value, file_extension="png") 22 | 23 | 24 | @dataclass 25 | class VideoArtifact(Artifact): 26 | value: List[Image] 27 | 28 | def __init__(self, value: List[Image]): 29 | super().__init__(type="video", value=value, file_extension="mp4") 30 | -------------------------------------------------------------------------------- /finetrainers/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Any, Dict 3 | 4 | import torch.distributed.checkpoint.stateful 5 | import torchdata.stateful_dataloader 6 | 7 | from finetrainers.logging import get_logger 8 | 9 | 10 | logger = get_logger() 11 | 12 | 13 | class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful): 14 | def __init__( 15 | self, 16 | rank: int, 17 | dataset: torch.utils.data.IterableDataset, 18 | batch_size: int = 1, 19 | num_workers: int = 0, 20 | collate_fn=None, 21 | ) -> None: 22 | super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn) 23 
| 24 | self._dp_rank = rank 25 | self._rank_id = f"dp_rank_{rank}" 26 | 27 | def state_dict(self) -> Dict[str, Any]: 28 | # Store state only for dp rank to avoid replicating the same state across other dimensions 29 | return {self._rank_id: pickle.dumps(super().state_dict())} 30 | 31 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 32 | # State being empty is valid 33 | if not state_dict: 34 | return 35 | 36 | if self._rank_id not in state_dict: 37 | logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}") 38 | return 39 | 40 | super().load_state_dict(pickle.loads(state_dict[self._rank_id])) 41 | -------------------------------------------------------------------------------- /finetrainers/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import torch 4 | 5 | 6 | class ResolutionSampler: 7 | def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None: 8 | self.batch_size = batch_size 9 | self.dim_keys = dim_keys 10 | assert dim_keys is not None, "dim_keys must be provided" 11 | 12 | self._chosen_leader_key = None 13 | self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {} 14 | self._satisfied_buckets: List[Dict[Any, Any]] = [] 15 | 16 | def consume(self, *dict_items: Dict[Any, Any]) -> None: 17 | if self._chosen_leader_key is None: 18 | self._determine_leader_item(*dict_items) 19 | self._update_buckets(*dict_items) 20 | 21 | def get_batch(self) -> List[Dict[str, Any]]: 22 | return list(zip(*self._satisfied_buckets.pop(-1))) 23 | 24 | @property 25 | def is_ready(self) -> bool: 26 | return len(self._satisfied_buckets) > 0 27 | 28 | def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None: 29 | num_observed = 0 30 | for dict_item in dict_items: 31 | for key in self.dim_keys.keys(): 32 | if key in dict_item.keys(): 33 | self._chosen_leader_key = key 34 | if not torch.is_tensor(dict_item[key]): 35 | raise ValueError(f"Leader key {key} must be a tensor") 36 | num_observed += 1 37 | if num_observed > 1: 38 | raise ValueError( 39 | f"Only one leader key is allowed in provided list of data dictionaries. 
Found {num_observed} leader keys" 40 | ) 41 | if self._chosen_leader_key is None: 42 | raise ValueError("No leader key found in provided list of data dictionaries") 43 | 44 | def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None: 45 | chosen_value = [ 46 | dict_item[self._chosen_leader_key] 47 | for dict_item in dict_items 48 | if self._chosen_leader_key in dict_item.keys() 49 | ] 50 | if len(chosen_value) == 0: 51 | raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries") 52 | chosen_value = chosen_value[0] 53 | dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key]) 54 | if dims not in self._unsatisfied_buckets: 55 | self._unsatisfied_buckets[dims] = [] 56 | self._unsatisfied_buckets[dims].append(dict_items) 57 | if len(self._unsatisfied_buckets[dims]) == self.batch_size: 58 | self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims)) 59 | -------------------------------------------------------------------------------- /finetrainers/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import flow_match_target, flow_match_xt 2 | from .image import ( 3 | bicubic_resize_image, 4 | center_crop_image, 5 | find_nearest_resolution_image, 6 | resize_crop_image, 7 | resize_to_nearest_bucket_image, 8 | ) 9 | from .normalization import normalize 10 | from .text import convert_byte_str_to_str, dropout_caption, dropout_embeddings_to_zero, remove_prefix 11 | from .video import ( 12 | bicubic_resize_video, 13 | center_crop_video, 14 | find_nearest_video_resolution, 15 | resize_crop_video, 16 | resize_to_nearest_bucket_video, 17 | ) 18 | -------------------------------------------------------------------------------- /finetrainers/functional/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 5 | r"""Forward process of flow matching.""" 6 | return (1.0 - t) * x0 + t * n 7 | 8 | 9 | def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: 10 | r"""Loss target for flow matching.""" 11 | return n - x0 12 | -------------------------------------------------------------------------------- /finetrainers/functional/image.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: 8 | num_channels, height, width = image.shape 9 | crop_h, crop_w = size 10 | if height < crop_h or width < crop_w: 11 | raise ValueError(f"Image size {(height, width)} is smaller than the target size {size}.") 12 | top = (height - crop_h) // 2 13 | left = (width - crop_w) // 2 14 | return image[:, top : top + crop_h, left : left + crop_w] 15 | 16 | 17 | def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: 18 | num_channels, height, width = image.shape 19 | target_h, target_w = size 20 | scale = max(target_h / height, target_w / width) 21 | new_h, new_w = int(height * scale), int(width * scale) 22 | image = F.interpolate(image.unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False)[0]  # bilinear interpolation requires a batched (N, C, H, W) input 23 | return center_crop_image(image, size) 24 | 25 | 26 | def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: 27 | return 
F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0] 28 | 29 | 30 | def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]: 31 | num_channels, height, width = image.shape 32 | aspect_ratio = width / height 33 | 34 | def aspect_ratio_diff(bucket): 35 | return abs((bucket[1] / bucket[0]) - aspect_ratio), (-bucket[0], -bucket[1]) 36 | 37 | return min(resolution_buckets, key=aspect_ratio_diff) 38 | 39 | 40 | def resize_to_nearest_bucket_image( 41 | image: torch.Tensor, 42 | resolution_buckets: List[Tuple[int, int]], 43 | resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", 44 | ) -> torch.Tensor: 45 | target_size = find_nearest_resolution_image(image, resolution_buckets) 46 | 47 | if resize_mode == "center_crop": 48 | return center_crop_image(image, target_size) 49 | elif resize_mode == "resize_crop": 50 | return resize_crop_image(image, target_size) 51 | elif resize_mode == "bicubic": 52 | return bicubic_resize_image(image, target_size) 53 | else: 54 | raise ValueError( 55 | f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." 56 | ) 57 | -------------------------------------------------------------------------------- /finetrainers/functional/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | 6 | def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor: 7 | """ 8 | Normalize a tensor to the range [min_val, max_val]. 9 | 10 | Args: 11 | x (`torch.Tensor`): 12 | The input tensor to normalize. 13 | min (`float`, defaults to `-1.0`): 14 | The minimum value of the normalized range. 15 | max (`float`, defaults to `1.0`): 16 | The maximum value of the normalized range. 17 | dim (`int`, *optional*): 18 | The dimension along which to normalize. If `None`, the entire tensor is normalized. 19 | 20 | Returns: 21 | The normalized tensor of the same shape as `x`. 22 | """ 23 | if dim is None: 24 | x_min = x.min() 25 | x_max = x.max() 26 | if torch.isclose(x_min, x_max).any(): 27 | x = torch.full_like(x, min) 28 | else: 29 | x = min + (max - min) * (x - x_min) / (x_max - x_min) 30 | else: 31 | x_min = x.amin(dim=dim, keepdim=True) 32 | x_max = x.amax(dim=dim, keepdim=True) 33 | if torch.isclose(x_min, x_max).any(): 34 | x = torch.full_like(x, min) 35 | else: 36 | x = min + (max - min) * (x - x_min) / (x_max - x_min) 37 | return x 38 | -------------------------------------------------------------------------------- /finetrainers/functional/text.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Union 3 | 4 | import torch 5 | 6 | 7 | def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str: 8 | """ 9 | Extracts the actual string from a stringified bytes array (common in some webdatasets). 
10 | 11 | Example: "b'hello world'" -> "hello world" 12 | """ 13 | try: 14 | s = s[2:-1] 15 | s = s.encode("utf-8").decode(encoding) 16 | except (UnicodeDecodeError, UnicodeEncodeError, IndexError): 17 | pass 18 | return s 19 | 20 | 21 | def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]: 22 | if random.random() >= dropout_p: 23 | return caption 24 | if isinstance(caption, str): 25 | return "" 26 | return [""] * len(caption) 27 | 28 | 29 | def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor: 30 | if random.random() >= dropout_p: 31 | return embed 32 | embed = torch.zeros_like(embed) 33 | return embed 34 | 35 | 36 | def remove_prefix(text: str, prefixes: List[str]) -> str: 37 | for prefix in prefixes: 38 | if text.startswith(prefix): 39 | return text.removeprefix(prefix).strip() 40 | return text 41 | -------------------------------------------------------------------------------- /finetrainers/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention_dispatch import AttentionProvider, attention_dispatch, attention_provider 2 | from .modeling_utils import ControlModelSpecification, ModelSpecification 3 | 4 | 5 | from ._metadata.transformer import register_transformer_metadata # isort: skip 6 | 7 | 8 | register_transformer_metadata() 9 | -------------------------------------------------------------------------------- /finetrainers/models/_metadata/transformer.py: -------------------------------------------------------------------------------- 1 | from diffusers import ( 2 | CogVideoXTransformer3DModel, 3 | CogView4Transformer2DModel, 4 | FluxTransformer2DModel, 5 | WanTransformer3DModel, 6 | ) 7 | 8 | from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry 9 | from finetrainers.logging import get_logger 10 | 11 | 12 | logger = get_logger() 13 | 14 | 15 | def register_transformer_metadata(): 16 | # CogVideoX 17 | TransformerRegistry.register( 18 | model_class=CogVideoXTransformer3DModel, 19 | metadata=TransformerMetadata( 20 | cp_plan={ 21 | "": { 22 | ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)], 23 | }, 24 | "transformer_blocks.0": { 25 | ParamId("hidden_states", 0): CPInput(1, 3), 26 | ParamId("encoder_hidden_states", 1): CPInput(1, 3), 27 | }, 28 | "proj_out": [CPOutput(1, 3)], 29 | } 30 | ), 31 | ) 32 | 33 | # CogView4 34 | TransformerRegistry.register( 35 | model_class=CogView4Transformer2DModel, 36 | metadata=TransformerMetadata( 37 | cp_plan={ 38 | "patch_embed": { 39 | ParamId(index=0): CPInput(1, 3, split_output=True), 40 | ParamId(index=1): CPInput(1, 3, split_output=True), 41 | }, 42 | "rope": { 43 | ParamId(index=0): CPInput(0, 2, split_output=True), 44 | ParamId(index=1): CPInput(0, 2, split_output=True), 45 | }, 46 | "proj_out": [CPOutput(1, 3)], 47 | } 48 | ), 49 | ) 50 | 51 | # Flux 52 | TransformerRegistry.register( 53 | model_class=FluxTransformer2DModel, 54 | metadata=TransformerMetadata( 55 | cp_plan={ 56 | "": { 57 | ParamId("hidden_states", 0): CPInput(1, 3), 58 | ParamId("encoder_hidden_states", 1): CPInput(1, 3), 59 | ParamId("img_ids", 4): CPInput(0, 2), 60 | ParamId("txt_ids", 5): CPInput(0, 2), 61 | }, 62 | "proj_out": [CPOutput(1, 3)], 63 | } 64 | ), 65 | ) 66 | 67 | # Wan2.1 68 | TransformerRegistry.register( 69 | model_class=WanTransformer3DModel, 70 | metadata=TransformerMetadata( 71 | cp_plan={ 72 | "rope": { 73 | ParamId(index=0): CPInput(2, 4, 
split_output=True), 74 | }, 75 | "blocks.*": { 76 | ParamId("encoder_hidden_states", 1): CPInput(1, 3), 77 | }, 78 | "blocks.0": { 79 | ParamId("hidden_states", 0): CPInput(1, 3), 80 | }, 81 | "proj_out": [CPOutput(1, 3)], 82 | } 83 | ), 84 | ) 85 | 86 | logger.debug("Metadata for transformer registered") 87 | -------------------------------------------------------------------------------- /finetrainers/models/cogvideox/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import CogVideoXModelSpecification 2 | -------------------------------------------------------------------------------- /finetrainers/models/cogvideox/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from diffusers.models.embeddings import get_3d_rotary_pos_embed 5 | from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid 6 | 7 | 8 | def prepare_rotary_positional_embeddings( 9 | height: int, 10 | width: int, 11 | num_frames: int, 12 | vae_scale_factor_spatial: int = 8, 13 | patch_size: int = 2, 14 | patch_size_t: int = None, 15 | attention_head_dim: int = 64, 16 | device: Optional[torch.device] = None, 17 | base_height: int = 480, 18 | base_width: int = 720, 19 | ) -> Tuple[torch.Tensor, torch.Tensor]: 20 | grid_height = height // (vae_scale_factor_spatial * patch_size) 21 | grid_width = width // (vae_scale_factor_spatial * patch_size) 22 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size) 23 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size) 24 | 25 | if patch_size_t is None: 26 | # CogVideoX 1.0 27 | grid_crops_coords = get_resize_crop_region_for_grid( 28 | (grid_height, grid_width), base_size_width, base_size_height 29 | ) 30 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 31 | embed_dim=attention_head_dim, 32 | crops_coords=grid_crops_coords, 33 | grid_size=(grid_height, grid_width), 34 | temporal_size=num_frames, 35 | ) 36 | else: 37 | # CogVideoX 1.5 38 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t 39 | 40 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 41 | embed_dim=attention_head_dim, 42 | crops_coords=None, 43 | grid_size=(grid_height, grid_width), 44 | temporal_size=base_num_frames, 45 | grid_type="slice", 46 | max_size=(base_size_height, base_size_width), 47 | ) 48 | 49 | freqs_cos = freqs_cos.to(device=device) 50 | freqs_sin = freqs_sin.to(device=device) 51 | return freqs_cos, freqs_sin 52 | -------------------------------------------------------------------------------- /finetrainers/models/cogview4/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import CogView4ModelSpecification 2 | from .control_specification import CogView4ControlModelSpecification 3 | -------------------------------------------------------------------------------- /finetrainers/models/flux/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import FluxModelSpecification 2 | -------------------------------------------------------------------------------- /finetrainers/models/hunyuan_video/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import HunyuanVideoModelSpecification 2 | 
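# A minimal usage sketch (not taken from the repository sources): once `register_transformer_metadata()`
# has run (it is called from finetrainers/models/__init__.py above), the context-parallel plan registered
# in finetrainers/models/_metadata/transformer.py can be looked up through the registry, e.g.:
#
#     from diffusers import WanTransformer3DModel
#     from finetrainers._metadata import TransformerRegistry
#
#     metadata = TransformerRegistry.get(WanTransformer3DModel)
#     print(metadata.cp_plan["rope"])  # {ParamId(index=0): CPInput(split_dim=2, expected_dims=4, split_output=True)}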
-------------------------------------------------------------------------------- /finetrainers/models/ltx_video/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import LTXVideoModelSpecification 2 | -------------------------------------------------------------------------------- /finetrainers/models/wan/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_specification import WanModelSpecification 2 | from .control_specification import WanControlModelSpecification 3 | -------------------------------------------------------------------------------- /finetrainers/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | from .accelerate import AccelerateParallelBackend 5 | from .ptd import PytorchDTensorParallelBackend 6 | from .utils import dist_max, dist_mean 7 | 8 | 9 | ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend] 10 | 11 | 12 | class ParallelBackendEnum(str, Enum): 13 | ACCELERATE = "accelerate" 14 | PTD = "ptd" 15 | 16 | 17 | def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType: 18 | if backend == ParallelBackendEnum.ACCELERATE: 19 | return AccelerateParallelBackend 20 | if backend == ParallelBackendEnum.PTD: 21 | return PytorchDTensorParallelBackend 22 | raise ValueError(f"Unknown parallel backend: {backend}") 23 | -------------------------------------------------------------------------------- /finetrainers/parallel/deepspeed.py: -------------------------------------------------------------------------------- 1 | from .base import BaseParallelBackend 2 | 3 | 4 | class DeepspeedParallelBackend(BaseParallelBackend): 5 | def __init__(self): 6 | # TODO(aryan) 7 | raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.") 8 | -------------------------------------------------------------------------------- /finetrainers/parallel/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed._functional_collectives as funcol 3 | import torch.distributed.tensor 4 | 5 | 6 | def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: 7 | if isinstance(x, torch.distributed.tensor.DTensor): 8 | # functional collectives do not support DTensor inputs 9 | x = x.full_tensor() 10 | assert x.numel() == 1 # required by `.item()` 11 | return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() 12 | 13 | 14 | def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: 15 | return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) 16 | 17 | 18 | def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: 19 | return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh) 20 | -------------------------------------------------------------------------------- /finetrainers/patches/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import torch 4 | 5 | from .dependencies.diffusers.peft import load_lora_weights 6 | 7 | 8 | if TYPE_CHECKING: 9 | from finetrainers.args import BaseArgsType 10 | from finetrainers.parallel import ParallelBackendType 11 | 12 | 
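# Overview of the patches applied by the functions below:
# - `perform_patches_for_training` replaces `torch.nn.functional.scaled_dot_product_attention` with
#   finetrainers' `attention_dispatch`, patches diffusers' `RMSNorm.forward`, applies the LTX-Video
#   transformer patches (plus a TP-compatible rotary-embedding patch when tensor parallelism is enabled),
#   patches Wan's time/text/image embedding when the transformer is in `layerwise_upcasting_modules`,
#   and patches PEFT's adapter-device move for LoRA training with layerwise upcasting.
# - `perform_patches_for_inference` applies only the scaled-dot-product-attention and RMSNorm patches.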
13 | def perform_patches_for_training(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None: 14 | # To avoid circular imports 15 | from finetrainers.config import ModelType, TrainingType 16 | 17 | from .dependencies.diffusers import patch 18 | 19 | # Modeling patches 20 | patch_scaled_dot_product_attention() 21 | 22 | patch.patch_diffusers_rms_norm_forward() 23 | 24 | # LTX Video patches 25 | if args.model_name == ModelType.LTX_VIDEO: 26 | from .models.ltx_video import patch 27 | 28 | patch.patch_transformer_forward() 29 | if parallel_backend.tensor_parallel_enabled: 30 | patch.patch_apply_rotary_emb_for_tp_compatibility() 31 | 32 | # Wan patches 33 | if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules: 34 | from .models.wan import patch 35 | 36 | patch.patch_time_text_image_embedding_forward() 37 | 38 | # LoRA patches 39 | if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0: 40 | from .dependencies.peft import patch 41 | 42 | patch.patch_peft_move_adapter_to_device_of_base_layer() 43 | 44 | 45 | def perform_patches_for_inference(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None: 46 | # To avoid circular imports 47 | from .dependencies.diffusers import patch 48 | 49 | # Modeling patches 50 | patch_scaled_dot_product_attention() 51 | 52 | patch.patch_diffusers_rms_norm_forward() 53 | 54 | 55 | def patch_scaled_dot_product_attention(): 56 | from finetrainers.models.attention_dispatch import attention_dispatch 57 | 58 | torch.nn.functional.scaled_dot_product_attention = attention_dispatch 59 | -------------------------------------------------------------------------------- /finetrainers/patches/dependencies/diffusers/control.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import List, Union 3 | 4 | import torch 5 | from diffusers.hooks import HookRegistry, ModelHook 6 | 7 | 8 | _CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK" 9 | 10 | 11 | class ControlChannelConcatenateHook(ModelHook): 12 | def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]): 13 | self.input_names = input_names 14 | self.inputs = inputs 15 | self.dims = dims 16 | 17 | def pre_forward(self, module: torch.nn.Module, *args, **kwargs): 18 | for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims): 19 | original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name] 20 | control_tensor = torch.cat([original_tensor, input_tensor], dim=dim) 21 | if isinstance(input_name, int): 22 | args[input_name] = control_tensor 23 | else: 24 | kwargs[input_name] = control_tensor 25 | return args, kwargs 26 | 27 | 28 | @contextmanager 29 | def control_channel_concat( 30 | module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int] 31 | ): 32 | registry = HookRegistry.check_if_exists_or_initialize(module) 33 | hook = ControlChannelConcatenateHook(input_names, inputs, dims) 34 | registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK) 35 | yield 36 | registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False) 37 | -------------------------------------------------------------------------------- /finetrainers/patches/dependencies/diffusers/patch.py: -------------------------------------------------------------------------------- 1 | def 
patch_diffusers_rms_norm_forward() -> None: 2 | import diffusers.models.normalization 3 | 4 | from .rms_norm import _patched_rms_norm_forward 5 | 6 | diffusers.models.normalization.RMSNorm.forward = _patched_rms_norm_forward 7 | -------------------------------------------------------------------------------- /finetrainers/patches/dependencies/diffusers/peft.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import safetensors.torch 6 | from diffusers import DiffusionPipeline 7 | from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA 8 | from huggingface_hub import repo_exists, snapshot_download 9 | from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict 10 | 11 | from finetrainers.logging import get_logger 12 | from finetrainers.utils import find_files 13 | 14 | 15 | logger = get_logger() 16 | 17 | 18 | def load_lora_weights( 19 | pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs 20 | ) -> None: 21 | low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) 22 | 23 | is_local_file_path = Path(pretrained_model_name_or_path).is_dir() 24 | if not is_local_file_path: 25 | does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model") 26 | if not does_repo_exist: 27 | raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.") 28 | else: 29 | pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model") 30 | 31 | prefix = "transformer" 32 | state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path) 33 | state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} 34 | 35 | file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1) 36 | if len(file_list) == 0: 37 | raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.") 38 | if len(file_list) > 1: 39 | logger.warning( 40 | f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}." 41 | ) 42 | with safetensors.torch.safe_open(file_list[0], framework="pt") as f: 43 | metadata = f.metadata() 44 | metadata = json.loads(metadata["lora_config"]) 45 | 46 | transformer = pipeline.transformer 47 | if adapter_name is None: 48 | adapter_name = "default" 49 | 50 | lora_config = LoraConfig(**metadata) 51 | inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) 52 | result = set_peft_model_state_dict( 53 | transformer, 54 | state_dict, 55 | adapter_name=adapter_name, 56 | ignore_mismatched_sizes=False, 57 | low_cpu_mem_usage=low_cpu_mem_usage, 58 | ) 59 | logger.debug( 60 | f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. 
Result: {result}" 61 | ) 62 | -------------------------------------------------------------------------------- /finetrainers/patches/dependencies/diffusers/rms_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusers.utils import is_torch_npu_available, is_torch_version 4 | 5 | 6 | def _patched_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 7 | if is_torch_npu_available(): 8 | import torch_npu 9 | 10 | if self.weight is not None: 11 | # convert into half-precision if necessary 12 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 13 | hidden_states = hidden_states.to(self.weight.dtype) 14 | hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0] 15 | if self.bias is not None: 16 | hidden_states = hidden_states + self.bias 17 | elif is_torch_version(">=", "2.4"): 18 | ### ===== ======= 19 | input_dtype = hidden_states.dtype 20 | if self.weight is not None: 21 | # convert into half-precision if necessary 22 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 23 | hidden_states = hidden_states.to(self.weight.dtype) 24 | hidden_states = nn.functional.rms_norm( 25 | hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps 26 | ) 27 | if self.bias is not None: 28 | hidden_states = hidden_states + self.bias 29 | hidden_states = hidden_states.to(input_dtype) 30 | ### ===== ===== 31 | else: 32 | input_dtype = hidden_states.dtype 33 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 34 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 35 | 36 | if self.weight is not None: 37 | # convert into half-precision if necessary 38 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 39 | hidden_states = hidden_states.to(self.weight.dtype) 40 | hidden_states = hidden_states * self.weight 41 | if self.bias is not None: 42 | hidden_states = hidden_states + self.bias 43 | else: 44 | hidden_states = hidden_states.to(input_dtype) 45 | 46 | return hidden_states 47 | -------------------------------------------------------------------------------- /finetrainers/patches/dependencies/peft/patch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from peft.tuners.tuners_utils import BaseTunerLayer 4 | 5 | from finetrainers.patches.utils import DisableTensorToDtype 6 | 7 | 8 | def patch_peft_move_adapter_to_device_of_base_layer() -> None: 9 | _perform_patch_move_adapter_to_device_of_base_layer() 10 | 11 | 12 | def _perform_patch_move_adapter_to_device_of_base_layer() -> None: 13 | BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer( 14 | BaseTunerLayer._move_adapter_to_device_of_base_layer 15 | ) 16 | 17 | 18 | def _patched_move_adapter_to_device_of_base_layer(func) -> None: 19 | # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor. 
20 | @functools.wraps(func) 21 | def wrapper(self, *args, **kwargs): 22 | with DisableTensorToDtype(): 23 | return func(self, *args, **kwargs) 24 | 25 | return wrapper 26 | -------------------------------------------------------------------------------- /finetrainers/patches/models/wan/patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import diffusers 4 | import torch 5 | 6 | 7 | def patch_time_text_image_embedding_forward() -> None: 8 | _patch_time_text_image_embedding_forward() 9 | 10 | 11 | def _patch_time_text_image_embedding_forward() -> None: 12 | diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = ( 13 | _patched_WanTimeTextImageEmbedding_forward 14 | ) 15 | 16 | 17 | def _patched_WanTimeTextImageEmbedding_forward( 18 | self, 19 | timestep: torch.Tensor, 20 | encoder_hidden_states: torch.Tensor, 21 | encoder_hidden_states_image: Optional[torch.Tensor] = None, 22 | ): 23 | # Some code has been removed compared to original implementation in Diffusers 24 | # Also, timestep is typed as that of encoder_hidden_states 25 | timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states) 26 | temb = self.time_embedder(timestep).type_as(encoder_hidden_states) 27 | timestep_proj = self.time_proj(self.act_fn(temb)) 28 | 29 | encoder_hidden_states = self.text_embedder(encoder_hidden_states) 30 | if encoder_hidden_states_image is not None: 31 | encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) 32 | 33 | return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image 34 | -------------------------------------------------------------------------------- /finetrainers/patches/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DisableTensorToDtype: 5 | def __enter__(self): 6 | self.original_to = torch.Tensor.to 7 | 8 | def modified_to(tensor, *args, **kwargs): 9 | # remove dtype from args if present 10 | args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] 11 | if "dtype" in kwargs: 12 | kwargs.pop("dtype") 13 | return self.original_to(tensor, *args, **kwargs) 14 | 15 | torch.Tensor.to = modified_to 16 | 17 | def __exit__(self, *args, **kwargs): 18 | torch.Tensor.to = self.original_to 19 | -------------------------------------------------------------------------------- /finetrainers/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from .base import ProcessorMixin 4 | from .canny import CannyProcessor 5 | from .clip import CLIPPooledProcessor 6 | from .glm import CogView4GLMProcessor 7 | from .llama import LlamaProcessor 8 | from .t5 import T5Processor 9 | from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor 10 | 11 | 12 | class CopyProcessor(ProcessorMixin): 13 | r"""Processor that copies the input data unconditionally to the output.""" 14 | 15 | def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None): 16 | super().__init__() 17 | 18 | self.output_names = output_names 19 | self.input_names = input_names 20 | assert len(output_names) == 1 21 | 22 | def forward(self, input: Any) -> Any: 23 | return {self.output_names[0]: input} 24 | -------------------------------------------------------------------------------- /finetrainers/processors/base.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Dict, List 3 | 4 | 5 | class ProcessorMixin: 6 | def __init__(self) -> None: 7 | self._forward_parameter_names = inspect.signature(self.forward).parameters.keys() 8 | self.output_names: List[str] = None 9 | self.input_names: Dict[str, Any] = None 10 | 11 | def __call__(self, *args, **kwargs) -> Any: 12 | shallow_copy_kwargs = dict(kwargs.items()) 13 | if self.input_names is not None: 14 | for k, v in self.input_names.items(): 15 | if k in shallow_copy_kwargs: 16 | shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k) 17 | acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names} 18 | output = self.forward(*args, **acceptable_kwargs) 19 | if "__drop__" in output: 20 | output.pop("__drop__") 21 | return output 22 | 23 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 24 | raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.") 25 | -------------------------------------------------------------------------------- /finetrainers/processors/canny.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | 7 | from ..utils.import_utils import is_kornia_available 8 | from .base import ProcessorMixin 9 | 10 | 11 | if is_kornia_available(): 12 | import kornia 13 | 14 | 15 | class CannyProcessor(ProcessorMixin): 16 | r""" 17 | Processor for obtaining the Canny edge detection of an image. 18 | 19 | Args: 20 | output_names (`List[str]`): 21 | The names of the outputs that the processor should return. The first output is the Canny edge detection of 22 | the input image. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | output_names: List[str] = None, 28 | input_names: Optional[Dict[str, Any]] = None, 29 | device: Optional[torch.device] = None, 30 | ): 31 | super().__init__() 32 | 33 | self.output_names = output_names 34 | self.input_names = input_names 35 | self.device = device 36 | assert len(output_names) == 1 37 | 38 | def forward(self, input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]) -> torch.Tensor: 39 | r""" 40 | Obtain the Canny edge detection of the input image. 41 | 42 | Args: 43 | input (`torch.Tensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`): 44 | The input tensor, image or list of images for which the Canny edge detection should be obtained. 45 | If a tensor, must be a 3D (CHW) or 4D (BCHW) or 5D (BTCHW) tensor. The input tensor should have 46 | values in the range [0, 1]. 47 | 48 | Returns: 49 | torch.Tensor: 50 | The Canny edge detection of the input image. The output has the same shape as the input tensor. If 51 | the input is an image, the output is a 3D tensor. If the input is a list of images, the output is a 5D 52 | tensor. The output tensor has values in the range [0, 1]. 
53 | """ 54 | if isinstance(input, PIL.Image.Image): 55 | input = kornia.utils.image.image_to_tensor(np.array(input)).unsqueeze(0) / 255.0 56 | input = input.to(self.device) 57 | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1).squeeze(0) 58 | elif isinstance(input, list): 59 | input = kornia.utils.image.image_list_to_tensor([np.array(img) for img in input]) / 255.0 60 | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) 61 | else: 62 | ndim = input.ndim 63 | assert ndim in [3, 4, 5] 64 | 65 | batch_size = 1 if ndim == 3 else input.size(0) 66 | 67 | if ndim == 3: 68 | input = input.unsqueeze(0) # [C, H, W] -> [1, C, H, W] 69 | elif ndim == 5: 70 | input = input.flatten(0, 1) # [B, F, C, H, W] -> [B*F, C, H, W] 71 | 72 | output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) 73 | output = output[0] if ndim == 3 else output.unflatten(0, (batch_size, -1)) if ndim == 5 else output 74 | 75 | # TODO(aryan): think about how one can pass parameters to the underlying function from 76 | # a UI perspective. It's important to think about ProcessorMixin in terms of a Graph-based 77 | # data processing pipeline. 78 | return {self.output_names[0]: output} 79 | -------------------------------------------------------------------------------- /finetrainers/processors/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast 5 | 6 | from .base import ProcessorMixin 7 | 8 | 9 | class CLIPPooledProcessor(ProcessorMixin): 10 | r""" 11 | Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings 12 | and attention masks for the input text. 13 | 14 | Args: 15 | output_names (`List[str]`): 16 | The names of the outputs that the processor should return. The first output is the embeddings of the input 17 | text and the second output is the attention mask for the input text. 18 | """ 19 | 20 | def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None: 21 | super().__init__() 22 | 23 | self.output_names = output_names 24 | self.input_names = input_names 25 | 26 | assert len(output_names) == 1 27 | 28 | def forward( 29 | self, 30 | tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], 31 | text_encoder: CLIPTextModel, 32 | caption: Union[str, List[str]], 33 | ) -> Tuple[torch.Tensor, torch.Tensor]: 34 | r""" 35 | Encode the input text and return the embeddings and attention mask for the input text. 36 | 37 | Args: 38 | tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): 39 | The tokenizer used to tokenize the input text. 40 | text_encoder (`LlamaModel`): 41 | The text encoder used to encode the input text. 42 | caption (`Union[str, List[str]]`): 43 | The input text to be encoded. 
44 | """ 45 | if isinstance(caption, str): 46 | caption = [caption] 47 | 48 | device = text_encoder.device 49 | dtype = text_encoder.dtype 50 | 51 | text_inputs = tokenizer( 52 | caption, 53 | padding="max_length", 54 | max_length=77, 55 | truncation=True, 56 | return_tensors="pt", 57 | ) 58 | text_input_ids = text_inputs.input_ids.to(device) 59 | 60 | prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output 61 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 62 | 63 | return {self.output_names[0]: prompt_embeds} 64 | -------------------------------------------------------------------------------- /finetrainers/processors/glm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | from transformers import AutoTokenizer, GlmModel 5 | 6 | from .base import ProcessorMixin 7 | 8 | 9 | class CogView4GLMProcessor(ProcessorMixin): 10 | r""" 11 | Processor for the GLM family of models. This processor is used to encode text inputs and return the embeddings 12 | and attention masks for the input text. 13 | 14 | This processor is specific to CogView4 but can be used with any other model. 15 | 16 | Args: 17 | output_names (`List[str]`): 18 | The names of the outputs that the processor should return. The first output is the embeddings of the input 19 | text and the second output is the attention mask for the input text. 20 | """ 21 | 22 | def __init__(self, output_names: List[str]): 23 | super().__init__() 24 | 25 | self.output_names = output_names 26 | 27 | assert len(self.output_names) == 1 28 | 29 | def forward( 30 | self, 31 | tokenizer: AutoTokenizer, 32 | text_encoder: GlmModel, 33 | caption: Union[str, List[str]], 34 | max_sequence_length: int, 35 | ) -> Tuple[torch.Tensor, torch.Tensor]: 36 | r""" 37 | Encode the input text and return the embeddings and attention mask for the input text. 38 | 39 | Args: 40 | tokenizer (`AutoTokenizer`): 41 | The tokenizer used to tokenize the input text. 42 | text_encoder (`GlmModel`): 43 | The text encoder used to encode the input text. 44 | caption (`Union[str, List[str]]`): 45 | The input text to be encoded. 46 | max_sequence_length (`int`): 47 | The maximum sequence length of the input text. 
48 | """ 49 | if isinstance(caption, str): 50 | caption = [caption] 51 | 52 | device = text_encoder.device 53 | dtype = text_encoder.dtype 54 | 55 | text_inputs = tokenizer( 56 | caption, 57 | padding="longest", 58 | max_length=max_sequence_length, 59 | truncation=True, 60 | add_special_tokens=True, 61 | return_tensors="pt", 62 | ) 63 | text_input_ids = text_inputs.input_ids.to(device) 64 | 65 | current_length = text_input_ids.size(1) 66 | pad_length = 16 - current_length % 16 67 | if pad_length > 0: 68 | pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id) 69 | text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) 70 | 71 | prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2] 72 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 73 | 74 | return {self.output_names[0]: prompt_embeds} 75 | -------------------------------------------------------------------------------- /finetrainers/processors/t5.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast 5 | 6 | from .base import ProcessorMixin 7 | 8 | 9 | class T5Processor(ProcessorMixin): 10 | r""" 11 | Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings 12 | and attention masks for the input text. 13 | 14 | Args: 15 | output_names (`List[str]`): 16 | The names of the outputs that the processor should return. The first output is the embeddings of the input 17 | text and the second output is the attention mask for the input text. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | output_names: List[str], 23 | input_names: Optional[Dict[str, Any]] = None, 24 | *, 25 | use_attention_mask: bool = False, 26 | ): 27 | super().__init__() 28 | 29 | self.output_names = output_names 30 | self.input_names = input_names 31 | self.use_attention_mask = use_attention_mask 32 | 33 | if input_names is not None: 34 | assert len(input_names) <= 4 35 | assert len(self.output_names) == 2 36 | 37 | def forward( 38 | self, 39 | tokenizer: Union[T5Tokenizer, T5TokenizerFast], 40 | text_encoder: T5EncoderModel, 41 | caption: Union[str, List[str]], 42 | max_sequence_length: int, 43 | ) -> Tuple[torch.Tensor, torch.Tensor]: 44 | r""" 45 | Encode the input text and return the embeddings and attention mask for the input text. 46 | 47 | Args: 48 | tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`): 49 | The tokenizer used to tokenize the input text. 50 | text_encoder (`T5EncoderModel`): 51 | The text encoder used to encode the input text. 52 | caption (`Union[str, List[str]]`): 53 | The input text to be encoded. 54 | max_sequence_length (`int`): 55 | The maximum sequence length of the input text. 
56 | """ 57 | if isinstance(caption, str): 58 | caption = [caption] 59 | 60 | device = text_encoder.device 61 | dtype = text_encoder.dtype 62 | 63 | batch_size = len(caption) 64 | text_inputs = tokenizer( 65 | caption, 66 | padding="max_length", 67 | max_length=max_sequence_length, 68 | truncation=True, 69 | add_special_tokens=True, 70 | return_tensors="pt", 71 | ) 72 | text_input_ids = text_inputs.input_ids 73 | prompt_attention_mask = text_inputs.attention_mask 74 | prompt_attention_mask = prompt_attention_mask.bool().to(device) 75 | 76 | te_mask = None 77 | if self.use_attention_mask: 78 | te_mask = prompt_attention_mask 79 | 80 | prompt_embeds = text_encoder(text_input_ids.to(device), te_mask)[0] 81 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 82 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) 83 | 84 | return { 85 | self.output_names[0]: prompt_embeds, 86 | self.output_names[1]: prompt_attention_mask, 87 | } 88 | -------------------------------------------------------------------------------- /finetrainers/processors/text.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | 5 | import finetrainers.functional as FF 6 | 7 | from .base import ProcessorMixin 8 | 9 | 10 | class CaptionTextDropoutProcessor(ProcessorMixin): 11 | def __init__(self, dropout_p: float = 0.0) -> None: 12 | self.dropout_p = dropout_p 13 | 14 | def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]: 15 | return FF.dropout_caption(caption, self.dropout_p) 16 | 17 | 18 | class CaptionEmbeddingDropoutProcessor(ProcessorMixin): 19 | def __init__(self, dropout_p: float = 0.0) -> None: 20 | self.dropout_p = dropout_p 21 | 22 | def forward(self, embedding: torch.Tensor) -> torch.Tensor: 23 | return FF.dropout_embeddings_to_zero(embedding, self.dropout_p) 24 | -------------------------------------------------------------------------------- /finetrainers/state.py: -------------------------------------------------------------------------------- 1 | import io 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | import torch.distributed.checkpoint.stateful 7 | 8 | from .parallel import ParallelBackendType 9 | from .utils import get_device_info 10 | 11 | 12 | _device_type, _ = get_device_info() 13 | 14 | 15 | @dataclass 16 | class TrainState(torch.distributed.checkpoint.stateful.Stateful): 17 | step: int = 0 18 | observed_data_samples: int = 0 19 | global_avg_losses: List[float] = field(default_factory=list) 20 | global_max_losses: List[float] = field(default_factory=list) 21 | log_steps: List[int] = field(default_factory=list) 22 | 23 | def state_dict(self) -> Dict[str, Any]: 24 | # Only checkpoint global_avg_losses and global_max_losses per log frequency 25 | # to avoid sync overhead in every iteration. 
26 | global_avg_losses_bytes = io.BytesIO() 27 | torch.save(self.global_avg_losses, global_avg_losses_bytes) 28 | global_max_losses_bytes = io.BytesIO() 29 | torch.save(self.global_max_losses, global_max_losses_bytes) 30 | log_steps_bytes = io.BytesIO() 31 | torch.save(self.log_steps, log_steps_bytes) 32 | return { 33 | "step": torch.tensor(self.step, dtype=torch.int32), 34 | "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32), 35 | "global_avg_losses": global_avg_losses_bytes, 36 | "global_max_losses": global_max_losses_bytes, 37 | "log_steps": log_steps_bytes, 38 | } 39 | 40 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 41 | state_dict["global_avg_losses"].seek(0) 42 | state_dict["global_max_losses"].seek(0) 43 | state_dict["log_steps"].seek(0) 44 | 45 | self.step = state_dict["step"].item() 46 | self.observed_data_samples = state_dict["observed_data_samples"].item() 47 | self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False) 48 | self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False) 49 | self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) 50 | 51 | 52 | @dataclass 53 | class State: 54 | # Parallel state 55 | parallel_backend: ParallelBackendType = None 56 | 57 | # Training state 58 | train_state: TrainState = None 59 | num_trainable_parameters: int = 0 60 | generator: torch.Generator = None 61 | 62 | # Hub state 63 | repo_id: str = None 64 | 65 | # Artifacts state 66 | output_dir: str = None 67 | -------------------------------------------------------------------------------- /finetrainers/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .control_trainer import ControlTrainer 2 | from .sft_trainer import SFTTrainer 3 | -------------------------------------------------------------------------------- /finetrainers/trainer/control_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ControlFullRankConfig, ControlLowRankConfig 2 | from .trainer import ControlTrainer 3 | -------------------------------------------------------------------------------- /finetrainers/trainer/sft_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import SFTFullRankConfig, SFTLowRankConfig 2 | from .trainer import SFTTrainer 3 | -------------------------------------------------------------------------------- /finetrainers/trainer/sft_trainer/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import TYPE_CHECKING, Any, Dict, List, Union 3 | 4 | from finetrainers.utils import ArgsConfigMixin 5 | 6 | 7 | if TYPE_CHECKING: 8 | from finetrainers.args import BaseArgs 9 | 10 | 11 | class SFTLowRankConfig(ArgsConfigMixin): 12 | r""" 13 | Configuration class for SFT low rank training. 14 | 15 | Args: 16 | rank (int): 17 | Rank of the low rank approximation matrix. 18 | lora_alpha (int): 19 | The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. 20 | target_modules (`str` or `List[str]`): 21 | Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings. 
22 | """ 23 | 24 | rank: int = 64 25 | lora_alpha: int = 64 26 | target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" 27 | 28 | def add_args(self, parser: argparse.ArgumentParser): 29 | parser.add_argument("--rank", type=int, default=64) 30 | parser.add_argument("--lora_alpha", type=int, default=64) 31 | parser.add_argument( 32 | "--target_modules", 33 | type=str, 34 | nargs="+", 35 | default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"], 36 | ) 37 | 38 | def validate_args(self, args: "BaseArgs"): 39 | assert self.rank > 0, "Rank must be a positive integer." 40 | assert self.lora_alpha > 0, "lora_alpha must be a positive integer." 41 | 42 | def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): 43 | mapped_args.rank = argparse_args.rank 44 | mapped_args.lora_alpha = argparse_args.lora_alpha 45 | mapped_args.target_modules = ( 46 | argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules 47 | ) 48 | 49 | def to_dict(self) -> Dict[str, Any]: 50 | return {"rank": self.rank, "lora_alpha": self.lora_alpha, "target_modules": self.target_modules} 51 | 52 | 53 | class SFTFullRankConfig(ArgsConfigMixin): 54 | r""" 55 | Configuration class for SFT full rank training. 56 | """ 57 | 58 | def add_args(self, parser: argparse.ArgumentParser): 59 | pass 60 | 61 | def validate_args(self, args: "BaseArgs"): 62 | pass 63 | 64 | def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): 65 | pass 66 | -------------------------------------------------------------------------------- /finetrainers/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler 4 | from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast 5 | 6 | from .data import ImageArtifact, VideoArtifact 7 | 8 | 9 | ArtifactType = Union[ImageArtifact, VideoArtifact] 10 | SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler] 11 | TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast] 12 | -------------------------------------------------------------------------------- /finetrainers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Dict, List, Optional, Set, Tuple, Union 3 | 4 | from .activation_checkpoint import apply_activation_checkpointing 5 | from .args_config import ArgsConfigMixin 6 | from .data import determine_batch_size, should_perform_precomputation 7 | from .diffusion import ( 8 | _enable_vae_memory_optimizations, 9 | default_flow_shift, 10 | get_scheduler_alphas, 11 | get_scheduler_sigmas, 12 | prepare_loss_weights, 13 | prepare_sigmas, 14 | prepare_target, 15 | resolution_dependent_timestep_flow_shift, 16 | ) 17 | from .file import delete_files, find_files, string_to_filename 18 | from .hub import save_model_card 19 | from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous 20 | from .model import resolve_component_cls 21 | from .serialization import safetensors_torch_save_function 22 | from .timing import Timer, TimerDevice 23 | from .torch import ( 24 | align_device_and_dtype, 25 | apply_compile, 26 | clip_grad_norm_, 27 | enable_determinism, 28 | 
expand_tensor_dims, 29 | get_device_info, 30 | get_submodule_by_name, 31 | get_unwrapped_model_state_dict, 32 | is_compiled_module, 33 | set_requires_grad, 34 | synchronize_device, 35 | unwrap_module, 36 | ) 37 | 38 | 39 | def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]: 40 | if method_name is not None: 41 | obj = getattr(obj, method_name) 42 | return {name for name, _ in inspect.signature(obj).parameters.items()} 43 | 44 | 45 | def get_non_null_items( 46 | x: Union[List[Any], Tuple[Any], Dict[str, Any]], 47 | ) -> Union[List[Any], Tuple[Any], Dict[str, Any]]: 48 | if isinstance(x, dict): 49 | return {k: v for k, v in x.items() if v is not None} 50 | if isinstance(x, (list, tuple)): 51 | return type(x)(v for v in x if v is not None) 52 | -------------------------------------------------------------------------------- /finetrainers/utils/_common.py: -------------------------------------------------------------------------------- 1 | DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [ 2 | "transformer_blocks", 3 | "single_transformer_blocks", 4 | "temporal_transformer_blocks", 5 | "blocks", 6 | "layers", 7 | ] 8 | -------------------------------------------------------------------------------- /finetrainers/utils/activation_checkpoint.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from enum import Enum 3 | 4 | import torch 5 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper 6 | 7 | from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES 8 | 9 | 10 | class CheckpointType(str, Enum): 11 | FULL = "full" 12 | OPS = "ops" 13 | BLOCK_SKIP = "block_skip" 14 | 15 | 16 | _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = { 17 | torch.ops.aten.mm.default, 18 | torch.ops.aten._scaled_dot_product_efficient_attention.default, 19 | torch.ops.aten._scaled_dot_product_flash_attention.default, 20 | torch.ops._c10d_functional.reduce_scatter_tensor.default, 21 | } 22 | 23 | 24 | def apply_activation_checkpointing( 25 | module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1 26 | ) -> torch.nn.Module: 27 | if checkpointing_type == CheckpointType.FULL: 28 | module = _apply_activation_checkpointing_blocks(module) 29 | elif checkpointing_type == CheckpointType.OPS: 30 | module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS) 31 | elif checkpointing_type == CheckpointType.BLOCK_SKIP: 32 | module = _apply_activation_checkpointing_blocks(module, n_layer) 33 | else: 34 | raise ValueError( 35 | f"Checkpointing type '{checkpointing_type}' not supported. 
Supported types are {CheckpointType.__members__.keys()}" 36 | ) 37 | return module 38 | 39 | 40 | def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module: 41 | for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: 42 | blocks: torch.nn.Module = getattr(module, transformer_block_name, None) 43 | if blocks is None: 44 | continue 45 | for index, (layer_id, block) in enumerate(blocks.named_children()): 46 | if n_layer is None or index % n_layer == 0: 47 | block = checkpoint_wrapper(block, preserve_rng_state=False) 48 | blocks.register_module(layer_id, block) 49 | return module 50 | 51 | 52 | def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module: 53 | from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts 54 | 55 | def _get_custom_policy(meta): 56 | def _custom_policy(ctx, func, *args, **kwargs): 57 | mode = "recompute" if ctx.is_recompute else "forward" 58 | mm_count_key = f"{mode}_mm_count" 59 | if func == torch.ops.aten.mm.default: 60 | meta[mm_count_key] += 1 61 | # Saves output of all compute ops, except every second mm 62 | to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0) 63 | return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE 64 | 65 | return _custom_policy 66 | 67 | def selective_checkpointing_context_fn(): 68 | meta = collections.defaultdict(int) 69 | return create_selective_checkpoint_contexts(_get_custom_policy(meta)) 70 | 71 | return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False) 72 | -------------------------------------------------------------------------------- /finetrainers/utils/args_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import TYPE_CHECKING, Any, Dict 3 | 4 | 5 | if TYPE_CHECKING: 6 | from finetrainers.args import BaseArgs 7 | 8 | 9 | class ArgsConfigMixin: 10 | def add_args(self, parser: argparse.ArgumentParser): 11 | raise NotImplementedError("ArgsConfigMixin::add_args should be implemented by subclasses.") 12 | 13 | def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): 14 | raise NotImplementedError("ArgsConfigMixin::map_args should be implemented by subclasses.") 15 | 16 | def validate_args(self, args: "BaseArgs"): 17 | raise NotImplementedError("ArgsConfigMixin::validate_args should be implemented by subclasses.") 18 | 19 | def to_dict(self) -> Dict[str, Any]: 20 | return {} 21 | -------------------------------------------------------------------------------- /finetrainers/utils/data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Union 3 | 4 | import torch 5 | 6 | from finetrainers.constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME 7 | from finetrainers.logging import get_logger 8 | 9 | 10 | logger = get_logger() 11 | 12 | 13 | def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool: 14 | if isinstance(precomputation_dir, str): 15 | precomputation_dir = Path(precomputation_dir) 16 | conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME 17 | latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME 18 | if conditions_dir.exists() and latents_dir.exists(): 19 | num_files_conditions = len(list(conditions_dir.glob("*.pt"))) 20 | 
num_files_latents = len(list(latents_dir.glob("*.pt"))) 21 | if num_files_conditions != num_files_latents: 22 | logger.warning( 23 | f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})." 24 | f"Cleaning up precomputed directories and re-running precomputation." 25 | ) 26 | # clean up precomputed directories 27 | for file in conditions_dir.glob("*.pt"): 28 | file.unlink() 29 | for file in latents_dir.glob("*.pt"): 30 | file.unlink() 31 | return True 32 | if num_files_conditions > 0: 33 | logger.info(f"Found {num_files_conditions} precomputed conditions and latents.") 34 | return False 35 | logger.info("Precomputed data not found. Running precomputation.") 36 | return True 37 | 38 | 39 | def determine_batch_size(x: Any) -> int: 40 | if isinstance(x, list): 41 | return len(x) 42 | if isinstance(x, torch.Tensor): 43 | return x.size(0) 44 | if isinstance(x, dict): 45 | for key in x: 46 | try: 47 | return determine_batch_size(x[key]) 48 | except ValueError: 49 | pass 50 | return 1 51 | raise ValueError("Could not determine batch size from input.") 52 | -------------------------------------------------------------------------------- /finetrainers/utils/file.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import shutil 3 | from pathlib import Path 4 | from typing import List, Union 5 | 6 | from finetrainers.logging import get_logger 7 | 8 | 9 | logger = get_logger() 10 | 11 | 12 | def find_files(root: str, pattern: str, depth: int = 0) -> List[str]: 13 | root_path = pathlib.Path(root) 14 | result_files = [] 15 | 16 | def within_depth(path: pathlib.Path) -> bool: 17 | return len(path.relative_to(root_path).parts) <= depth 18 | 19 | if depth == 0: 20 | result_files.extend([str(file) for file in root_path.glob(pattern)]) 21 | else: 22 | for file in root_path.rglob(pattern): 23 | if not file.is_file() or not within_depth(file.parent): 24 | continue 25 | result_files.append(str(file)) 26 | 27 | return result_files 28 | 29 | 30 | def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: 31 | if not isinstance(dirs, list): 32 | dirs = [dirs] 33 | dirs = [Path(d) if isinstance(d, str) else d for d in dirs] 34 | logger.debug(f"Deleting files: {dirs}") 35 | for dir in dirs: 36 | if not dir.exists(): 37 | continue 38 | shutil.rmtree(dir, ignore_errors=True) 39 | 40 | 41 | def string_to_filename(s: str) -> str: 42 | return ( 43 | s.replace(" ", "-") 44 | .replace("/", "-") 45 | .replace(":", "-") 46 | .replace(".", "-") 47 | .replace(",", "-") 48 | .replace(";", "-") 49 | .replace("!", "-") 50 | .replace("?", "-") 51 | ) 52 | -------------------------------------------------------------------------------- /finetrainers/utils/hub.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import wandb 6 | from diffusers.utils import export_to_video 7 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 8 | from PIL import Image 9 | 10 | 11 | def save_model_card( 12 | args, 13 | repo_id: str, 14 | videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]], 15 | validation_prompts: List[str], 16 | fps: int = 30, 17 | ) -> None: 18 | widget_dict = [] 19 | output_dir = str(args.output_dir) 20 | if videos is not None and len(videos) > 0: 21 | for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)): 
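            # Entries that are already strings are treated as existing paths/URLs and referenced
            # directly in the card widget; in-memory frames are first exported to
            # `final_video_{i}.mp4` inside the output directory.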
22 | if not isinstance(video, str): 23 | export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps) 24 | widget_dict.append( 25 | { 26 | "text": validation_prompt if validation_prompt else " ", 27 | "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"}, 28 | } 29 | ) 30 | 31 | model_description = f""" 32 | # LoRA Finetune 33 | 34 | 35 | 36 | ## Model description 37 | 38 | This is a lora finetune of model: `{args.pretrained_model_name_or_path}`. 39 | 40 | The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). 41 | 42 | ## Download model 43 | 44 | [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. 45 | 46 | ## Usage 47 | 48 | Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. 49 | 50 | ```py 51 | TODO 52 | ``` 53 | 54 | For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. 55 | """ 56 | if wandb.run.url: 57 | model_description += f""" 58 | Find out the wandb run URL and training configurations [here]({wandb.run.url}). 59 | """ 60 | 61 | model_card = load_or_create_model_card( 62 | repo_id_or_path=repo_id, 63 | from_training=True, 64 | base_model=args.pretrained_model_name_or_path, 65 | model_description=model_description, 66 | widget=widget_dict, 67 | ) 68 | tags = [ 69 | "text-to-video", 70 | "diffusers-training", 71 | "diffusers", 72 | "lora", 73 | "template:sd-lora", 74 | ] 75 | 76 | model_card = populate_model_card(model_card, tags=tags) 77 | model_card.save(os.path.join(args.output_dir, "README.md")) 78 | -------------------------------------------------------------------------------- /finetrainers/utils/memory.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Any, Dict, Union 3 | 4 | import torch 5 | 6 | from finetrainers.logging import get_logger 7 | 8 | 9 | logger = get_logger() 10 | 11 | 12 | def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: 13 | memory_allocated = None 14 | memory_reserved = None 15 | max_memory_allocated = None 16 | max_memory_reserved = None 17 | 18 | if torch.cuda.is_available(): 19 | device = torch.cuda.current_device() 20 | memory_allocated = torch.cuda.memory_allocated(device) 21 | memory_reserved = torch.cuda.memory_reserved(device) 22 | max_memory_allocated = torch.cuda.max_memory_allocated(device) 23 | max_memory_reserved = torch.cuda.max_memory_reserved(device) 24 | 25 | elif torch.backends.mps.is_available(): 26 | memory_allocated = torch.mps.current_allocated_memory() 27 | 28 | else: 29 | logger.warning("No CUDA, MPS, or ROCm device found. 
Memory statistics are not available.") 30 | 31 | return { 32 | "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), 33 | "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), 34 | "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), 35 | "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), 36 | } 37 | 38 | 39 | def bytes_to_gigabytes(x: int) -> float: 40 | if x is not None: 41 | return x / 1024**3 42 | 43 | 44 | def free_memory() -> None: 45 | if torch.cuda.is_available(): 46 | gc.collect() 47 | torch.cuda.empty_cache() 48 | torch.cuda.ipc_collect() 49 | 50 | # TODO(aryan): handle non-cuda devices 51 | 52 | 53 | def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 54 | if isinstance(x, torch.Tensor): 55 | return x.contiguous() 56 | elif isinstance(x, dict): 57 | return {k: make_contiguous(v) for k, v in x.items()} 58 | else: 59 | return x 60 | -------------------------------------------------------------------------------- /finetrainers/utils/model.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import json 3 | import os 4 | from typing import Optional 5 | 6 | from huggingface_hub import hf_hub_download 7 | 8 | 9 | def resolve_component_cls( 10 | pretrained_model_name_or_path: str, 11 | component_name: str, 12 | filename: str = "model_index.json", 13 | revision: Optional[str] = None, 14 | cache_dir: Optional[str] = None, 15 | ): 16 | pretrained_model_name_or_path = str(pretrained_model_name_or_path) 17 | if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path): 18 | index_path = os.path.join(pretrained_model_name_or_path, filename) 19 | else: 20 | index_path = hf_hub_download( 21 | repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir 22 | ) 23 | 24 | with open(index_path, "r") as f: 25 | model_index_dict = json.load(f) 26 | 27 | if component_name not in model_index_dict: 28 | raise ValueError(f"No {component_name} found in the model index dict.") 29 | 30 | cls_config = model_index_dict[component_name] 31 | library = importlib.import_module(cls_config[0]) 32 | return getattr(library, cls_config[1]) 33 | -------------------------------------------------------------------------------- /finetrainers/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import safetensors.torch 4 | 5 | 6 | def safetensors_torch_save_function(weights: Dict[str, Any], filename: str, metadata: Optional[Dict[str, str]] = None): 7 | if metadata is None: 8 | metadata = {} 9 | metadata["format"] = "pt" 10 | safetensors.torch.save_file(weights, filename, metadata) 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 119 3 | 4 | [tool.ruff.lint] 5 | # Never enforce `E501` (line length violations). 6 | ignore = ["C901", "E501", "E741", "F402", "F823"] 7 | select = ["C", "E", "F", "I", "W"] 8 | 9 | # Ignore import violations in all `__init__.py` files. 
10 | [tool.ruff.lint.per-file-ignores] 11 | "__init__.py" = ["E402", "F401", "F403", "F811"] 12 | 13 | [tool.ruff.lint.isort] 14 | lines-after-imports = 2 15 | known-first-party = [] 16 | 17 | [tool.ruff.format] 18 | # Like Black, use double quotes for strings. 19 | quote-style = "double" 20 | 21 | # Like Black, indent with spaces, rather than tabs. 22 | indent-style = "space" 23 | 24 | # Like Black, respect magic trailing commas. 25 | skip-magic-trailing-comma = false 26 | 27 | # Like Black, automatically detect the appropriate line ending. 28 | line-ending = "auto" 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets>=3.3.2 4 | diffusers>=0.32.1 5 | transformers>=4.45.2 6 | huggingface_hub 7 | hf_transfer>=0.1.8 8 | peft>=0.13.0 9 | decord>=0.6.0 10 | wandb 11 | pandas 12 | torch>=2.5.1 13 | torchvision>=0.20.1 14 | torchdata>=0.10.1 15 | torchao>=0.7.0 16 | sentencepiece>=0.2.0 17 | imageio-ffmpeg>=0.5.1 18 | numpy>=1.26.4 19 | kornia>=0.7.3 20 | ruff==0.9.10 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | with open("README.md", "r", encoding="utf-8") as file: 5 | long_description = file.read() 6 | 7 | with open("requirements.txt", "r", encoding="utf-8") as file: 8 | requirements = [line for line in file.read().splitlines() if len(line) > 0] 9 | 10 | setup( 11 | name="finetrainers", 12 | version="0.2.0.dev0", 13 | description="Finetrainers is a work-in-progress library to support (accessible) training of diffusion models", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | author="Aryan V S", 17 | author_email="contact.aryanvs@gmail.com", 18 | url="https://github.com/a-r-r-o-w/finetrainers", 19 | python_requires=">=3.8.0", 20 | license="Apache-2.0", 21 | packages=find_packages(), 22 | install_requires=requirements, 23 | extras_require={"dev": ["pytest==8.3.2", "ruff==0.1.5"]}, 24 | classifiers=[ 25 | "Development Status :: 1 - Planning", 26 | "Intended Audience :: Science/Research", 27 | "Intended Audience :: Developers", 28 | "Intended Audience :: Education", 29 | "Programming Language :: Python :: 3", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Operating System :: Microsoft :: Windows", 34 | "Operating System :: Unix", 35 | "License :: OSI Approved :: MIT License", 36 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 37 | ], 38 | ) 39 | 40 | # Steps to publish: 41 | # 1. Update version in setup.py 42 | # 2. python setup.py sdist bdist_wheel 43 | # 3. Check if everything works with testpypi: 44 | # twine upload --repository testpypi dist/* 45 | # 4. Upload to pypi: 46 | # twine upload dist/* 47 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Running tests 2 | 3 | TODO(aryan): everything here needs to be improved. 
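
To quickly check which parametrized cases a `-k` expression selects (for example before reserving multiple GPUs for a distributed launch), plain pytest collection is usually enough, assuming the test modules import cleanly outside a `torchrun`/`accelerate` launch:

```
# Preview matching tests without running them
pytest --collect-only -q tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___Accelerate"
```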
4 | 5 | ## `trainer/` fast tests 6 | 7 | - For SFT tests: `test_sft_trainer.py` 8 | - For Control tests: `test_control_trainer.py` 9 | 10 | Accelerate: 11 | 12 | ``` 13 | # world_size=1 tests 14 | accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___Accelerate" 15 | accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___Accelerate" 16 | 17 | # world_size=2 tests 18 | accelerate launch --config_file accelerate_configs/uncompiled_2.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___Accelerate" 19 | ``` 20 | 21 | PTD: 22 | 23 | ``` 24 | # world_size=1 tests 25 | torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___PTD" 26 | torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___PTD" 27 | torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_2 and ___PTD" 28 | 29 | # world_size=2 tests 30 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___PTD" 31 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_2___batch_size_1 and ___PTD" 32 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_2 and ___PTD" 33 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_1 and ___PTD" 34 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_2 and ___PTD" 35 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___tp_degree_2___batch_size_2 and ___PTD" 36 | torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___cp_degree_2___batch_size_1 and ___PTD" 37 | 38 | # world_size=4 tests 39 | torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___dp_shards_2___batch_size_1 and ___PTD" 40 | torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___cp_degree_2___batch_size_1 and ___PTD" 41 | ``` 42 | 43 | ## CP tests 44 | 45 | PTD: 46 | 47 | ``` 48 | # world_size=2 tests 49 | torchrun --nnodes 1 --nproc_per_node 2 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP2Test 50 | 51 | # world_size=4 tests 52 | torchrun --nnodes 1 --nproc_per_node 4 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP4Test 53 | ``` 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import List 3 | 4 | from diffusers.utils import export_to_video 5 | from PIL import Image 6 | 7 | from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES # noqa 8 | 9 | 10 | def create_dummy_directory_structure( 11 | directory_structure: List[str], tmpdir, num_data_files: int, caption: str, metadata_extension: str 12 | ): 13 | for item in directory_structure: 14 | # TODO(aryan): this should be improved 15 | if item in COMMON_CAPTION_FILES: 16 | data_file = pathlib.Path(tmpdir.name) / item 17 | with open(data_file.as_posix(), "w") as f: 18 | for _ in range(num_data_files): 19 | f.write(f"{caption}\n") 20 | elif item in COMMON_IMAGE_FILES: 21 | data_file = pathlib.Path(tmpdir.name) / item 22 | with open(data_file.as_posix(), "w") as f: 23 | for i in range(num_data_files): 24 | f.write(f"images/{i}.jpg\n") 25 | elif item in COMMON_VIDEO_FILES: 26 | data_file = pathlib.Path(tmpdir.name) / item 27 | with open(data_file.as_posix(), "w") as f: 28 | for i in range(num_data_files): 29 | f.write(f"videos/{i}.mp4\n") 30 | elif item == "metadata.csv": 31 | data_file = pathlib.Path(tmpdir.name) / item 32 | with open(data_file.as_posix(), "w") as f: 33 | f.write("file_name,caption\n") 34 | for i in range(num_data_files): 35 | f.write(f"{i}.{metadata_extension},{caption}\n") 36 | elif item == "metadata.jsonl": 37 | data_file = pathlib.Path(tmpdir.name) / item 38 | with open(data_file.as_posix(), "w") as f: 39 | for i in range(num_data_files): 40 | f.write(f'{{"file_name": "{i}.{metadata_extension}", "caption": "{caption}"}}\n') 41 | elif item.endswith(".txt"): 42 | data_file = pathlib.Path(tmpdir.name) / item 43 | with open(data_file.as_posix(), "w") as f: 44 | f.write(caption) 45 | elif item.endswith(".jpg") or item.endswith(".png"): 46 | data_file = pathlib.Path(tmpdir.name) / item 47 | Image.new("RGB", (64, 64)).save(data_file.as_posix()) 48 | elif item.endswith(".mp4"): 49 | data_file = pathlib.Path(tmpdir.name) / item 50 | export_to_video([Image.new("RGB", (64, 64))] * 4, data_file.as_posix(), fps=2) 51 | else: 52 | data_file = pathlib.Path(tmpdir.name, item) 53 | data_file.mkdir(exist_ok=True, parents=True) 54 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/cogvideox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/cogvideox/__init__.py -------------------------------------------------------------------------------- /tests/models/cogvideox/base_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXTransformer3DModel 3 | from transformers import AutoTokenizer, T5EncoderModel 
4 | 5 | from finetrainers.models.cogvideox import CogVideoXModelSpecification 6 | 7 | 8 | class DummyCogVideoXModelSpecification(CogVideoXModelSpecification): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | def load_condition_models(self): 13 | text_encoder = T5EncoderModel.from_pretrained( 14 | "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype 15 | ) 16 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") 17 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 18 | 19 | def load_latent_models(self): 20 | torch.manual_seed(0) 21 | vae = AutoencoderKLCogVideoX( 22 | in_channels=3, 23 | out_channels=3, 24 | down_block_types=( 25 | "CogVideoXDownBlock3D", 26 | "CogVideoXDownBlock3D", 27 | "CogVideoXDownBlock3D", 28 | "CogVideoXDownBlock3D", 29 | ), 30 | up_block_types=( 31 | "CogVideoXUpBlock3D", 32 | "CogVideoXUpBlock3D", 33 | "CogVideoXUpBlock3D", 34 | "CogVideoXUpBlock3D", 35 | ), 36 | block_out_channels=(8, 8, 8, 8), 37 | latent_channels=4, 38 | layers_per_block=1, 39 | norm_num_groups=2, 40 | temporal_compression_ratio=4, 41 | ) 42 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 43 | # Doing so overrides things like _keep_in_fp32_modules 44 | vae.to(self.vae_dtype) 45 | self.vae_config = vae.config 46 | return {"vae": vae} 47 | 48 | def load_diffusion_models(self): 49 | torch.manual_seed(0) 50 | transformer = CogVideoXTransformer3DModel( 51 | num_attention_heads=4, 52 | attention_head_dim=16, 53 | in_channels=4, 54 | out_channels=4, 55 | time_embed_dim=2, 56 | text_embed_dim=32, 57 | num_layers=2, 58 | sample_width=24, 59 | sample_height=24, 60 | sample_frames=9, 61 | patch_size=2, 62 | temporal_compression_ratio=4, 63 | max_text_seq_length=16, 64 | use_rotary_positional_embeddings=True, 65 | ) 66 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 
67 | # Doing so overrides things like _keep_in_fp32_modules 68 | transformer.to(self.transformer_dtype) 69 | self.transformer_config = transformer.config 70 | scheduler = CogVideoXDDIMScheduler() 71 | return {"transformer": transformer, "scheduler": scheduler} 72 | -------------------------------------------------------------------------------- /tests/models/cogview4/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/cogview4/__init__.py -------------------------------------------------------------------------------- /tests/models/cogview4/base_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler 3 | from transformers import AutoTokenizer, GlmModel 4 | 5 | from finetrainers.models.cogview4 import CogView4ModelSpecification 6 | 7 | 8 | class DummyCogView4ModelSpecification(CogView4ModelSpecification): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | def load_condition_models(self): 13 | text_encoder = GlmModel.from_pretrained( 14 | "hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype 15 | ) 16 | tokenizer = AutoTokenizer.from_pretrained( 17 | "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True 18 | ) 19 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 20 | 21 | def load_latent_models(self): 22 | torch.manual_seed(0) 23 | vae = AutoencoderKL.from_pretrained( 24 | "hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype 25 | ) 26 | self.vae_config = vae.config 27 | return {"vae": vae} 28 | 29 | def load_diffusion_models(self): 30 | torch.manual_seed(0) 31 | transformer = CogView4Transformer2DModel.from_pretrained( 32 | "hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype 33 | ) 34 | scheduler = FlowMatchEulerDiscreteScheduler() 35 | return {"transformer": transformer, "scheduler": scheduler} 36 | -------------------------------------------------------------------------------- /tests/models/cogview4/control_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler 3 | from transformers import AutoTokenizer, GlmConfig, GlmModel 4 | 5 | from finetrainers.models.cogview4 import CogView4ControlModelSpecification 6 | from finetrainers.models.utils import _expand_linear_with_zeroed_weights 7 | 8 | 9 | class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | # This needs to be updated for the test to work correctly. 
14 | # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded 15 | # with ModelSpecification::_load_configs 16 | self.transformer_config.in_channels = 4 17 | 18 | def load_condition_models(self): 19 | text_encoder_config = GlmConfig( 20 | hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8 21 | ) 22 | text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype) 23 | # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer 24 | tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True) 25 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 26 | 27 | def load_latent_models(self): 28 | torch.manual_seed(0) 29 | vae = AutoencoderKL( 30 | block_out_channels=[32, 64], 31 | in_channels=3, 32 | out_channels=3, 33 | down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], 34 | up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], 35 | latent_channels=4, 36 | sample_size=128, 37 | ).to(self.vae_dtype) 38 | return {"vae": vae} 39 | 40 | def load_diffusion_models(self, new_in_features: int): 41 | torch.manual_seed(0) 42 | transformer = CogView4Transformer2DModel( 43 | patch_size=2, 44 | in_channels=4, 45 | num_layers=2, 46 | attention_head_dim=4, 47 | num_attention_heads=4, 48 | out_channels=4, 49 | text_embed_dim=32, 50 | time_embed_dim=8, 51 | condition_dim=4, 52 | ).to(self.transformer_dtype) 53 | actual_new_in_features = new_in_features * transformer.config.patch_size**2 54 | transformer.patch_embed.proj = _expand_linear_with_zeroed_weights( 55 | transformer.patch_embed.proj, new_in_features=actual_new_in_features 56 | ) 57 | transformer.register_to_config(in_channels=new_in_features) 58 | 59 | scheduler = FlowMatchEulerDiscreteScheduler() 60 | 61 | return {"transformer": transformer, "scheduler": scheduler} 62 | -------------------------------------------------------------------------------- /tests/models/flux/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/flux/__init__.py -------------------------------------------------------------------------------- /tests/models/flux/base_specification.py: -------------------------------------------------------------------------------- 1 | from finetrainers.models.flux import FluxModelSpecification 2 | 3 | 4 | class DummyFluxModelSpecification(FluxModelSpecification): 5 | def __init__(self, **kwargs): 6 | super().__init__(pretrained_model_name_or_path="hf-internal-testing/tiny-flux-pipe", **kwargs) 7 | -------------------------------------------------------------------------------- /tests/models/ltx_video/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/ltx_video/__init__.py -------------------------------------------------------------------------------- /tests/models/ltx_video/base_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXVideoTransformer3DModel 3 | from transformers import AutoTokenizer, T5EncoderModel 4 | 5 | from finetrainers.models.ltx_video import LTXVideoModelSpecification 6 | 7 | 8 | class 
DummyLTXVideoModelSpecification(LTXVideoModelSpecification): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | def load_condition_models(self): 13 | text_encoder = T5EncoderModel.from_pretrained( 14 | "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype 15 | ) 16 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") 17 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 18 | 19 | def load_latent_models(self): 20 | torch.manual_seed(0) 21 | vae = AutoencoderKLLTXVideo( 22 | in_channels=3, 23 | out_channels=3, 24 | latent_channels=8, 25 | block_out_channels=(8, 8, 8, 8), 26 | decoder_block_out_channels=(8, 8, 8, 8), 27 | layers_per_block=(1, 1, 1, 1, 1), 28 | decoder_layers_per_block=(1, 1, 1, 1, 1), 29 | spatio_temporal_scaling=(True, True, False, False), 30 | decoder_spatio_temporal_scaling=(True, True, False, False), 31 | decoder_inject_noise=(False, False, False, False, False), 32 | upsample_residual=(False, False, False, False), 33 | upsample_factor=(1, 1, 1, 1), 34 | timestep_conditioning=False, 35 | patch_size=1, 36 | patch_size_t=1, 37 | encoder_causal=True, 38 | decoder_causal=False, 39 | ) 40 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 41 | # Doing so overrides things like _keep_in_fp32_modules 42 | vae.to(self.vae_dtype) 43 | self.vae_config = vae.config 44 | return {"vae": vae} 45 | 46 | def load_diffusion_models(self): 47 | torch.manual_seed(0) 48 | transformer = LTXVideoTransformer3DModel( 49 | in_channels=8, 50 | out_channels=8, 51 | patch_size=1, 52 | patch_size_t=1, 53 | num_attention_heads=4, 54 | attention_head_dim=8, 55 | cross_attention_dim=32, 56 | num_layers=1, 57 | caption_channels=32, 58 | ) 59 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 
60 | # Doing so overrides things like _keep_in_fp32_modules 61 | transformer.to(self.transformer_dtype) 62 | scheduler = FlowMatchEulerDiscreteScheduler() 63 | return {"transformer": transformer, "scheduler": scheduler} 64 | -------------------------------------------------------------------------------- /tests/models/wan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a-r-r-o-w/finetrainers/2494f411a77c11cd7dab4493a20e2c6551cba768/tests/models/wan/__init__.py -------------------------------------------------------------------------------- /tests/models/wan/base_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel 3 | from transformers import AutoTokenizer, T5EncoderModel 4 | 5 | from finetrainers.models.wan import WanModelSpecification 6 | 7 | 8 | class DummyWanModelSpecification(WanModelSpecification): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | def load_condition_models(self): 13 | text_encoder = T5EncoderModel.from_pretrained( 14 | "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype 15 | ) 16 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") 17 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 18 | 19 | def load_latent_models(self): 20 | torch.manual_seed(0) 21 | vae = AutoencoderKLWan( 22 | base_dim=3, 23 | z_dim=16, 24 | dim_mult=[1, 1, 1, 1], 25 | num_res_blocks=1, 26 | temperal_downsample=[False, True, True], 27 | ) 28 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 29 | # Doing so overrides things like _keep_in_fp32_modules 30 | vae.to(self.vae_dtype) 31 | self.vae_config = vae.config 32 | return {"vae": vae} 33 | 34 | def load_diffusion_models(self): 35 | torch.manual_seed(0) 36 | transformer = WanTransformer3DModel( 37 | patch_size=(1, 2, 2), 38 | num_attention_heads=2, 39 | attention_head_dim=12, 40 | in_channels=16, 41 | out_channels=16, 42 | text_dim=32, 43 | freq_dim=256, 44 | ffn_dim=32, 45 | num_layers=2, 46 | cross_attn_norm=True, 47 | qk_norm="rms_norm_across_heads", 48 | rope_max_seq_len=32, 49 | ) 50 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 51 | # Doing so overrides things like _keep_in_fp32_modules 52 | transformer.to(self.transformer_dtype) 53 | scheduler = FlowMatchEulerDiscreteScheduler() 54 | return {"transformer": transformer, "scheduler": scheduler} 55 | -------------------------------------------------------------------------------- /tests/models/wan/control_specification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel 3 | from transformers import AutoTokenizer, T5EncoderModel 4 | 5 | from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights 6 | from finetrainers.models.wan import WanControlModelSpecification 7 | 8 | 9 | class DummyWanControlModelSpecification(WanControlModelSpecification): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | # This needs to be updated for the test to work correctly. 
14 | # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded 15 | # with ModelSpecification::_load_configs 16 | self.transformer_config.in_channels = 16 17 | 18 | def load_condition_models(self): 19 | text_encoder = T5EncoderModel.from_pretrained( 20 | "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype 21 | ) 22 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") 23 | return {"text_encoder": text_encoder, "tokenizer": tokenizer} 24 | 25 | def load_latent_models(self): 26 | torch.manual_seed(0) 27 | vae = AutoencoderKLWan( 28 | base_dim=3, 29 | z_dim=16, 30 | dim_mult=[1, 1, 1, 1], 31 | num_res_blocks=1, 32 | temperal_downsample=[False, True, True], 33 | ) 34 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 35 | # Doing so overrides things like _keep_in_fp32_modules 36 | vae.to(self.vae_dtype) 37 | self.vae_config = vae.config 38 | return {"vae": vae} 39 | 40 | def load_diffusion_models(self, new_in_features: int): 41 | torch.manual_seed(0) 42 | transformer = WanTransformer3DModel( 43 | patch_size=(1, 2, 2), 44 | num_attention_heads=2, 45 | attention_head_dim=12, 46 | in_channels=16, 47 | out_channels=16, 48 | text_dim=32, 49 | freq_dim=256, 50 | ffn_dim=32, 51 | num_layers=2, 52 | cross_attn_norm=True, 53 | qk_norm="rms_norm_across_heads", 54 | rope_max_seq_len=32, 55 | ).to(self.transformer_dtype) 56 | 57 | transformer.patch_embedding = _expand_conv3d_with_zeroed_weights( 58 | transformer.patch_embedding, new_in_channels=new_in_features 59 | ) 60 | transformer.register_to_config(in_channels=new_in_features) 61 | 62 | # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. 63 | # Doing so overrides things like _keep_in_fp32_modules 64 | transformer.to(self.transformer_dtype) 65 | scheduler = FlowMatchEulerDiscreteScheduler() 66 | return {"transformer": transformer, "scheduler": scheduler} 67 | -------------------------------------------------------------------------------- /tests/scripts/dummy_cogvideox_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_IDS="0,1" 4 | DATA_ROOT="$ROOT_DIR/video-dataset-disney" 5 | CAPTION_COLUMN="prompt.txt" 6 | VIDEO_COLUMN="videos.txt" 7 | OUTPUT_DIR="cogvideox" 8 | ID_TOKEN="BW_STYLE" 9 | 10 | # Model arguments 11 | model_cmd="--model_name cogvideox \ 12 | --pretrained_model_name_or_path THUDM/CogVideoX-5b" 13 | 14 | # Dataset arguments 15 | dataset_cmd="--data_root $DATA_ROOT \ 16 | --video_column $VIDEO_COLUMN \ 17 | --caption_column $CAPTION_COLUMN \ 18 | --id_token $ID_TOKEN \ 19 | --video_resolution_buckets 49x480x720 \ 20 | --caption_dropout_p 0.05" 21 | 22 | # Dataloader arguments 23 | dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions" 24 | 25 | # Training arguments 26 | training_cmd="--training_type lora \ 27 | --seed 42 \ 28 | --batch_size 1 \ 29 | --precompute_conditions \ 30 | --train_steps 10 \ 31 | --rank 128 \ 32 | --lora_alpha 128 \ 33 | --target_modules to_q to_k to_v to_out.0 \ 34 | --gradient_accumulation_steps 1 \ 35 | --gradient_checkpointing \ 36 | --checkpointing_steps 5 \ 37 | --checkpointing_limit 2 \ 38 | --resume_from_checkpoint=latest \ 39 | --enable_slicing \ 40 | --enable_tiling" 41 | 42 | # Optimizer arguments 43 | optimizer_cmd="--optimizer adamw \ 44 | --lr 3e-5 \ 45 | --beta1 0.9 \ 46 | --beta2 0.95 \ 47 | --weight_decay 1e-4 \ 48 | --epsilon 1e-8 \ 49 | --max_grad_norm 
1.0" 50 | 51 | # Validation arguments 52 | validation_prompts=$(cat <