├── README.md ├── diffusion ├── .ipynb_checkpoints │ └── __init__-checkpoint.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-39.pyc │ ├── download.cpython-310.pyc │ ├── dpm_solver.cpython-310.pyc │ ├── dpm_solver.cpython-311.pyc │ ├── dpm_solver.cpython-39.pyc │ ├── iddpm.cpython-310.pyc │ ├── iddpm.cpython-311.pyc │ ├── iddpm.cpython-39.pyc │ ├── sa_sampler.cpython-310.pyc │ ├── sa_sampler.cpython-311.pyc │ └── sa_sampler.cpython-39.pyc ├── configs │ └── config_relactrl_pixart_1024.py ├── data │ ├── .ipynb_checkpoints │ │ ├── builder-checkpoint.py │ │ └── transforms-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── builder.cpython-310.pyc │ │ ├── builder.cpython-39.pyc │ │ ├── transforms.cpython-310.pyc │ │ └── transforms.cpython-39.pyc │ ├── builder.py │ ├── datasets │ │ ├── Dreambooth.py │ │ ├── InternalData.py │ │ ├── InternalData_ms.py │ │ ├── SA.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── Dreambooth.cpython-310.pyc │ │ │ ├── Dreambooth.cpython-39.pyc │ │ │ ├── InternalData.cpython-310.pyc │ │ │ ├── InternalData.cpython-39.pyc │ │ │ ├── InternalData_ms.cpython-310.pyc │ │ │ ├── InternalData_ms.cpython-39.pyc │ │ │ ├── SA.cpython-310.pyc │ │ │ ├── SA.cpython-39.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── pixart_control.cpython-310.pyc │ │ │ ├── pixart_control.cpython-39.pyc │ │ │ ├── pixart_controldit.cpython-310.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ └── utils.cpython-39.pyc │ │ ├── pixart_control.py │ │ └── utils.py │ └── transforms.py ├── download.py ├── dpm_solver.py ├── iddpm.py ├── lcm_scheduler.py ├── model │ ├── .ipynb_checkpoints │ │ └── builder-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── builder.cpython-310.pyc │ │ ├── builder.cpython-311.pyc │ │ 
├── builder.cpython-39.pyc │ │ ├── diffusion_utils.cpython-310.pyc │ │ ├── diffusion_utils.cpython-311.pyc │ │ ├── diffusion_utils.cpython-39.pyc │ │ ├── dpm_solver.cpython-310.pyc │ │ ├── dpm_solver.cpython-311.pyc │ │ ├── dpm_solver.cpython-39.pyc │ │ ├── gaussian_diffusion.cpython-310.pyc │ │ ├── gaussian_diffusion.cpython-311.pyc │ │ ├── gaussian_diffusion.cpython-39.pyc │ │ ├── hed.cpython-310.pyc │ │ ├── hed.cpython-39.pyc │ │ ├── respace.cpython-310.pyc │ │ ├── respace.cpython-311.pyc │ │ ├── respace.cpython-39.pyc │ │ ├── sa_solver.cpython-310.pyc │ │ ├── sa_solver.cpython-39.pyc │ │ ├── t5.cpython-310.pyc │ │ ├── t5.cpython-39.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── utils.cpython-311.pyc │ │ └── utils.cpython-39.pyc │ ├── builder.py │ ├── diffusion_utils.py │ ├── dpm_solver.py │ ├── edm_sample.py │ ├── gaussian_diffusion.py │ ├── hed.py │ ├── llava │ │ ├── __init__.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── nets │ │ ├── PixArt.py │ │ ├── PixArtMS.py │ │ ├── PixArt_blocks.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── PixArt.cpython-310.pyc │ │ │ ├── PixArt.cpython-311.pyc │ │ │ ├── PixArt.cpython-39.pyc │ │ │ ├── PixArtMS.cpython-310.pyc │ │ │ ├── PixArtMS.cpython-311.pyc │ │ │ ├── PixArtMS.cpython-39.pyc │ │ │ ├── PixArt_blocks.cpython-310.pyc │ │ │ ├── PixArt_blocks.cpython-311.pyc │ │ │ ├── PixArt_blocks.cpython-39.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── pixart_controlnet.cpython-310.pyc │ │ │ ├── pixart_controlnet.cpython-311.pyc │ │ │ ├── pixart_controlnet.cpython-39.pyc │ │ │ ├── pixart_controlnet_adamamba.cpython-310.pyc │ │ │ ├── pixart_controlnet_adapter.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe2.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe3_mlp.cpython-310.pyc │ │ 
│ ├── pixart_controlnet_moe4.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe4_allmlp.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe4_show.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior1.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior11.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior2.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior3.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior4.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior4s.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior51.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior52.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior53.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior54.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior551.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior552.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e1.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e2.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e3.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e3_show.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e4.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e50.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e51.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_e52.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_same1.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_same2.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_x0bug.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior55_x0e1.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior56.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior57.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior58.cpython-310.pyc │ │ │ ├── pixart_controlnet_moe5_prior59.cpython-310.pyc │ │ │ ├── pixart_controlnet_output.cpython-310.pyc │ │ │ ├── pixart_controlnet_output.cpython-39.pyc │ │ │ ├── 
pixart_controlnet_output2.cpython-310.pyc │ │ │ ├── pixart_controlnet_p1o1.cpython-310.pyc │ │ │ ├── pixart_controlnet_p1o2.cpython-310.pyc │ │ │ ├── pixart_controlnet_p1o3.cpython-310.pyc │ │ │ ├── pixart_controlnet_p2o1.cpython-310.pyc │ │ │ ├── pixart_controlnet_p2o2.cpython-310.pyc │ │ │ ├── pixart_controlnet_remove.cpython-310.pyc │ │ │ ├── pixart_controlnet_skipcopy.cpython-310.pyc │ │ │ ├── pixart_controlnet_skiporigin.cpython-310.pyc │ │ │ ├── pixart_controlnet_skipshow.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage10_style.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_cop27.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_cop27_shunxu.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy12_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy12_our2.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13_shunxu.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13_skip.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13_skip_new.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13_top.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage1_copy13_top2.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage2_our11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage2_our12_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage2_our13_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our1.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our2.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our3.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our4.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our5.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our6.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage3_our12_our7.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage4_copy10_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage4_copy11_our.cpython-310.pyc │ │ │ ├── 
pixart_controlnet_stage4_copy12_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage4_copy13_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage4_copy7_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage5_copy10_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage5_copy11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage5_copy12_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage5_copy13_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage5_copy7_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage6_copy11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage6_copy11_our2.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage6_our11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage7_our11_our.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage7_our11_our2.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage7_our11_our3.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage7_our11_our4.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage8_our11_our5.cpython-310.pyc │ │ │ ├── pixart_controlnet_stage9_our11_our4_test.cpython-310.pyc │ │ │ ├── pixart_main_output.cpython-310.pyc │ │ │ ├── pixart_relactrl_v1.cpython-310.pyc │ │ │ └── relactrl_v1.cpython-310.pyc │ │ ├── pixart_controlnet.py │ │ └── pixart_relactrl_v1.py │ ├── respace.py │ ├── sa_solver.py │ ├── t5.py │ ├── timestep_sampler.py │ └── utils.py ├── sa_sampler.py ├── sa_solver_diffusers.py └── utils │ ├── .ipynb_checkpoints │ ├── dist_utils-checkpoint.py │ └── logger-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-39.pyc │ ├── checkpoint.cpython-310.pyc │ ├── checkpoint.cpython-39.pyc │ ├── data_sampler.cpython-310.pyc │ ├── data_sampler.cpython-39.pyc │ ├── dist_utils.cpython-310.pyc │ ├── dist_utils.cpython-311.pyc │ ├── dist_utils.cpython-39.pyc │ ├── logger.cpython-310.pyc │ ├── logger.cpython-311.pyc │ ├── logger.cpython-39.pyc │ ├── lr_scheduler.cpython-310.pyc │ ├── lr_scheduler.cpython-39.pyc │ ├── 
misc.cpython-310.pyc │ ├── misc.cpython-39.pyc │ ├── optimizer.cpython-310.pyc │ └── optimizer.cpython-39.pyc │ ├── checkpoint.py │ ├── data_sampler.py │ ├── dist_utils.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── misc.py │ └── optimizer.py ├── output ├── result1.png ├── style1_gufeng.png ├── style2_oil.png ├── style3_gufeng.png ├── style3_oil.png ├── style3_paint.png └── style4_3d.png ├── pipeline └── test_relactrl_pixart_1024.py ├── requirements.txt └── resources └── demos ├── prompts_examples.md └── reference_images ├── example1.png ├── example2.png ├── example3.png ├── style1.png ├── style2.png ├── style3.png └── style4.png /README.md: -------------------------------------------------------------------------------- 1 | # RelaCtrl 2 | 3 | This is the official reproduction of [RelaCtrl](https://360cvgroup.github.io/RelaCtrl/), which represents an efficient controlnet-like architecture designed for DiTs. 4 | 5 | **[RelaCtrl: Relevance-Guided Efficient Control for Diffusion Transformers](https://arxiv.org/pdf/2502.14377)** 6 |
7 | Ke Cao*, Jing Wang*, Ao Ma*, Jiasong Feng, Zhanjie Zhang, Xuanhua He, Shanyuan Liu, Bo Cheng, Dawei Leng‡, Yuhui Yin, Jie Zhang‡(*Equal Contribution, ‡Corresponding Authors) 8 |
9 | [![arXiv](https://img.shields.io/badge/arXiv-2502.14377-b31b1b.svg)](https://arxiv.org/pdf/2502.14377) 10 | [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://360cvgroup.github.io/RelaCtrl/) 11 | 12 | 13 | ## 📰 News 14 | - **[2025.04.07]** We released the inference pipeline and some weights of RelaCtrl-PixArt. 15 | - **[2025.02.21]** We have released our paper [RelaCtrl](https://arxiv.org/pdf/2502.14377) and created a dedicated [project homepage](https://360cvgroup.github.io/RelaCtrl/). 16 | 17 | 18 | ## Inference with RelaCtrl on PixArt 19 | ### Dependencies and Installation 20 | ``` python 21 | conda create -n relactrl python=3.10 22 | conda activate relactrl 23 | 24 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 25 | 26 | git clone https://github.com/360CVGroup/RelaCtrl.git 27 | cd RelaCtrl 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Download Models 32 | 33 | ### 1. Required PixArt-related Weights 34 | 35 | Download the necessary model weights for PixArt from the links below: 36 | 37 | | Model | Parameters | Download Link | 38 | |--------------|------------|----------------------------------------------------------------| 39 | | **T5** | 4.3B | [T5](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) | 40 | | **VAE** | 80M | [VAE](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema) | 41 | | **PixArt-α-1024** | 0.6B | [PixArt-XL-2-1024-MS.pth](https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth) or [Diffusers Version](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) | 42 | 43 | ### 2. 
RelaCtrl Conditional Weights 44 | 45 | Download the required conditional weights for RelaCtrl: 46 | | Model | Parameters | Download Link | 47 | |--------------|------------|----------------------------------------------------------------| 48 | | **RelaCtrl_PixArt_Canny** | 45M | [Canny](https://huggingface.co/qihoo360/RelaCtrl/tree/main) | 49 | | **RelaCtrl_PixArt_Canny_Style** | 45M | [Style](https://huggingface.co/qihoo360/RelaCtrl/tree/main) | 50 | 51 | 52 | ### Inference with Conditions 53 | ``` python 54 | python pipeline/test_relactrl_pixart_1024.py diffusion/configs/config_relactrl_pixart_1024.py 55 | ``` 56 | Prompt examples for different models can be found in the [prompt_examples](resources/demos/prompts_examples.md). 57 | 58 | ### Acknowledgment 59 | The PixArt model weights are derived from the open-source project [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha). 60 | Please refer to the original repository for detailed license information. 61 | 62 | ## BibTeX 63 | ``` 64 | @misc{cao2025relactrl, 65 | title={RelaCtrl: Relevance-Guided Efficient Control for Diffusion Transformers}, 66 | author={Ke Cao and Jing Wang and Ao Ma and Jiasong Feng and Zhanjie Zhang and Xuanhua He and Shanyuan Liu and Bo Cheng and Dawei Leng and Yuhui Yin and Jie Zhang}, 67 | year={2025}, 68 | eprint={2502.14377}, 69 | archivePrefix={arXiv}, 70 | primaryClass={cs.CV}, 71 | url={https://arxiv.org/abs/2502.14377}, 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /diffusion/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from .iddpm import IDDPM 7 | from .dpm_solver import DPMS 8 | from .sa_sampler import SASolverSampler 9 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from .iddpm import IDDPM 7 | from .dpm_solver import DPMS 8 | from .sa_sampler import SASolverSampler 9 | from .download import * 10 | -------------------------------------------------------------------------------- /diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /diffusion/__pycache__/download.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/download.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/dpm_solver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/dpm_solver.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/dpm_solver.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/dpm_solver.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/dpm_solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/dpm_solver.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/iddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/iddpm.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/iddpm.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/iddpm.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/iddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/iddpm.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/sa_sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/sa_sampler.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/sa_sampler.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/sa_sampler.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/sa_sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/__pycache__/sa_sampler.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/configs/config_relactrl_pixart_1024.py: -------------------------------------------------------------------------------- 1 | # model setting 2 | # model_name = 'PixArt-XL-2-1024-MS.pth' 3 | num_sampling_steps = 20 4 | cfg_scale = 4.5 5 | image_size = 1024 6 | model_path = 
'/home/jovyan/maao-data-cephfs-0/dataspace/maao/projects/Common/models/PixArt-alpha/PixArt-ControlNet/PixArt-XL-2-1024-ControlNet.pth' 7 | tokenizer_path = '/home/jovyan/maao-data-cephfs-0/dataspace/maao/projects/Common/models/PixArt-alpha/PixArt-XL-2-1024-MS/vae' 8 | llm_model = 't5' 9 | sampling_algo = 'dpm-solver' 10 | port = 7788 11 | condition_strength = 1 12 | 13 | 14 | -------------------------------------------------------------------------------- /diffusion/data/.ipynb_checkpoints/builder-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from mmcv import Registry, build_from_cfg 5 | from torch.utils.data import DataLoader 6 | 7 | from diffusion.data.transforms import get_transform 8 | from diffusion.utils.logger import get_root_logger 9 | 10 | DATASETS = Registry('datasets') 11 | 12 | DATA_ROOT = '/cache/data' 13 | 14 | 15 | def set_data_root(data_root): 16 | global DATA_ROOT 17 | DATA_ROOT = data_root 18 | 19 | 20 | def get_data_path(data_dir): 21 | if os.path.isabs(data_dir): 22 | return data_dir 23 | global DATA_ROOT 24 | return os.path.join(DATA_ROOT, data_dir) 25 | 26 | 27 | def build_dataset(cfg, resolution=224, **kwargs): 28 | logger = get_root_logger() 29 | 30 | dataset_type = cfg.get('type') 31 | logger.info(f"Constructing dataset {dataset_type}...") 32 | t = time.time() 33 | transform = cfg.pop('transform', 'default_train') 34 | transform = get_transform(transform, resolution) 35 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs)) 36 | logger.info(f"Dataset {dataset_type} constructed. 
time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}") 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): 41 | return ( 42 | DataLoader( 43 | dataset, 44 | batch_sampler=kwargs['batch_sampler'], 45 | num_workers=num_workers, 46 | pin_memory=True, 47 | ) 48 | if 'batch_sampler' in kwargs 49 | else DataLoader( 50 | dataset, 51 | batch_size=batch_size, 52 | shuffle=shuffle, 53 | num_workers=num_workers, 54 | pin_memory=True, 55 | **kwargs 56 | ) 57 | ) 58 | -------------------------------------------------------------------------------- /diffusion/data/.ipynb_checkpoints/transforms-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | TRANSFORMS = {} 4 | 5 | 6 | def register_transform(transform): 7 | name = transform.__name__ 8 | if name in TRANSFORMS: 9 | raise RuntimeError(f'Transform {name} has already registered.') 10 | TRANSFORMS.update({name: transform}) 11 | 12 | 13 | def get_transform(type, resolution): 14 | transform = TRANSFORMS[type](resolution) 15 | transform = T.Compose(transform) 16 | transform.image_size = resolution 17 | return transform 18 | 19 | 20 | @register_transform 21 | def default_train(n_px): 22 | return [ 23 | T.Lambda(lambda img: img.convert('RGB')), 24 | T.Resize(n_px), # Image.BICUBIC 25 | T.CenterCrop(n_px), 26 | # T.RandomHorizontalFlip(), 27 | T.ToTensor(), 28 | T.Normalize([0.5], [0.5]), 29 | ] 30 | -------------------------------------------------------------------------------- /diffusion/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .transforms import get_transform 3 | -------------------------------------------------------------------------------- /diffusion/data/__pycache__/__init__.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/__pycache__/builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/builder.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from mmcv import Registry, build_from_cfg 5 | from torch.utils.data import DataLoader 6 | 7 | from diffusion.data.transforms import get_transform 8 | from diffusion.utils.logger import get_root_logger 9 | 10 | DATASETS = Registry('datasets') 11 | 12 | DATA_ROOT = '/cache/data' 13 | 14 | 15 | def set_data_root(data_root): 16 | global DATA_ROOT 17 | DATA_ROOT = data_root 18 | 19 | 20 | def get_data_path(data_dir): 21 | if os.path.isabs(data_dir): 22 | return data_dir 23 | global DATA_ROOT 24 | return os.path.join(DATA_ROOT, data_dir) 25 | 26 | 27 | def build_dataset(cfg, resolution=224, **kwargs): 28 | logger = get_root_logger() 29 | 30 | dataset_type = cfg.get('type') 31 | logger.info(f"Constructing dataset {dataset_type}...") 32 | t = time.time() 33 | transform = cfg.pop('transform', 'default_train') 34 | transform = get_transform(transform, resolution) 35 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs)) 36 | logger.info(f"Dataset {dataset_type} constructed. 
time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}") 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): 41 | return ( 42 | DataLoader( 43 | dataset, 44 | batch_sampler=kwargs['batch_sampler'], 45 | num_workers=num_workers, 46 | pin_memory=True, 47 | ) 48 | if 'batch_sampler' in kwargs 49 | else DataLoader( 50 | dataset, 51 | batch_size=batch_size, 52 | shuffle=shuffle, 53 | num_workers=num_workers, 54 | pin_memory=True, 55 | **kwargs 56 | ) 57 | ) 58 | -------------------------------------------------------------------------------- /diffusion/data/datasets/Dreambooth.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 5 | from torch.utils.data import Dataset 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from torchvision import transforms as T 8 | import pathlib 9 | from diffusers.models import AutoencoderKL 10 | 11 | from diffusion.data.builder import get_data_path, DATASETS 12 | from diffusion.data.datasets.utils import * 13 | 14 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp', 'JPEG'} 15 | 16 | 17 | @DATASETS.register_module() 18 | class DreamBooth(Dataset): 19 | def __init__(self, 20 | root, 21 | transform=None, 22 | resolution=1024, 23 | **kwargs): 24 | self.root = get_data_path(root) 25 | path = pathlib.Path(self.root) 26 | self.transform = transform 27 | self.resolution = resolution 28 | self.img_samples = sorted( 29 | [file for ext in IMAGE_EXTENSIONS for file in path.glob(f'*.{ext}')] 30 | ) 31 | self.ori_imgs_nums = len(self) 32 | self.loader = default_loader 33 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) 34 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio 35 | 
self.ratio_nums = {} 36 | for k, v in self.aspect_ratio.items(): 37 | self.ratio_nums[float(k)] = 0 # used for batch-sampler 38 | self.data_info = {'img_hw': torch.tensor([resolution, resolution], dtype=torch.float32), 'aspect_ratio': 1.} 39 | 40 | # image related 41 | with torch.inference_mode(): 42 | vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema") 43 | imgs = [] 44 | for img_path in self.img_samples: 45 | img = self.loader(img_path) 46 | self.ratio_nums[1.0] += 1 47 | if self.transform is not None: 48 | imgs.append(self.transform(img)) 49 | imgs = torch.stack(imgs, dim=0) 50 | self.img_vae = vae.encode(imgs).latent_dist.sample() 51 | del vae 52 | 53 | def __getitem__(self, index): 54 | return self.img_vae[index], self.data_info 55 | 56 | @staticmethod 57 | def vae_feat_loader(path): 58 | # [mean, std] 59 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 60 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 61 | return mean + std * sample 62 | 63 | def load_ori_img(self, img_path): 64 | # 加载图像并转换为Tensor 65 | transform = T.Compose([ 66 | T.Resize(256), # Image.BICUBIC 67 | T.CenterCrop(256), 68 | T.ToTensor(), 69 | ]) 70 | return transform(Image.open(img_path)) 71 | 72 | def __len__(self): 73 | return len(self.img_samples) 74 | 75 | def __getattr__(self, name): 76 | if name == "set_epoch": 77 | return lambda epoch: None 78 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 79 | 80 | def get_data_info(self, idx): 81 | return {'height': self.resolution, 'width': self.resolution} 82 | -------------------------------------------------------------------------------- /diffusion/data/datasets/InternalData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 7 | from 
torch.utils.data import Dataset 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from torchvision import transforms as T 10 | from diffusion.data.builder import get_data_path, DATASETS 11 | from diffusion.utils.logger import get_root_logger 12 | 13 | import json 14 | 15 | 16 | @DATASETS.register_module() 17 | class InternalData(Dataset): 18 | def __init__(self, 19 | root, 20 | image_list_json='data_info.json', 21 | transform=None, 22 | resolution=256, 23 | sample_subset=None, 24 | load_vae_feat=False, 25 | input_size=32, 26 | patch_size=2, 27 | mask_ratio=0.0, 28 | load_mask_index=False, 29 | max_length=120, 30 | config=None, 31 | **kwargs): 32 | self.root = get_data_path(root) 33 | self.transform = transform 34 | self.load_vae_feat = load_vae_feat 35 | self.ori_imgs_nums = 0 36 | self.resolution = resolution 37 | self.N = int(resolution // (input_size // patch_size)) 38 | self.mask_ratio = mask_ratio 39 | self.load_mask_index = load_mask_index 40 | self.max_lenth = max_length 41 | self.meta_data_clean = [] 42 | self.img_samples = [] 43 | self.txt_feat_samples = [] 44 | self.vae_feat_samples = [] 45 | self.mask_index_samples = [] 46 | self.prompt_samples = [] 47 | 48 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 49 | for json_file in image_list_json: 50 | meta_data = self.load_json(os.path.join(self.root, 'partition', json_file)) 51 | self.ori_imgs_nums += len(meta_data) 52 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] 53 | self.meta_data_clean.extend(meta_data_clean) 54 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) 55 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) 56 | self.vae_feat_samples.extend([os.path.join(self.root, 
f'img_vae_features/{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) 57 | # self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', item['path'].rsplit('/', 1).replace('.png', '.npy')) for item in meta_data_clean]) 58 | self.prompt_samples.extend([item['prompt'] for item in meta_data_clean]) 59 | 60 | # Set loader and extensions 61 | if load_vae_feat: 62 | self.transform = None 63 | self.loader = self.vae_feat_loader 64 | else: 65 | self.loader = default_loader 66 | 67 | if sample_subset is not None: 68 | self.sample_subset(sample_subset) # sample dataset for local debug 69 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 70 | logger.info(f"T5 max token length: {self.max_lenth}") 71 | 72 | def getdata(self, index): 73 | img_path = self.img_samples[index] 74 | npz_path = self.txt_feat_samples[index] 75 | npy_path = self.vae_feat_samples[index] 76 | prompt = self.prompt_samples[index] 77 | data_info = { 78 | 'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32), 79 | 'aspect_ratio': torch.tensor(1.) 
80 | } 81 | 82 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path) 83 | txt_info = np.load(npz_path) 84 | txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096 85 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) # 1x1xT 86 | if 'attention_mask' in txt_info.keys(): 87 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 88 | if txt_fea.shape[1] != self.max_lenth: 89 | txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1) 90 | attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) 91 | 92 | if self.transform: 93 | img = self.transform(img) 94 | 95 | data_info['prompt'] = prompt 96 | return img, txt_fea, attention_mask, data_info 97 | 98 | def __getitem__(self, idx): 99 | for _ in range(20): 100 | try: 101 | return self.getdata(idx) 102 | except Exception as e: 103 | print(f"Error details: {str(e)}") 104 | idx = np.random.randint(len(self)) 105 | raise RuntimeError('Too many bad data.') 106 | 107 | def get_data_info(self, idx): 108 | data_info = self.meta_data_clean[idx] 109 | return {'height': data_info['height'], 'width': data_info['width']} 110 | 111 | @staticmethod 112 | def vae_feat_loader(path): 113 | # [mean, std] 114 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 115 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 116 | return mean + std * sample 117 | 118 | def load_ori_img(self, img_path): 119 | # 加载图像并转换为Tensor 120 | transform = T.Compose([ 121 | T.Resize(256), # Image.BICUBIC 122 | T.CenterCrop(256), 123 | T.ToTensor(), 124 | ]) 125 | return transform(Image.open(img_path)) 126 | 127 | def load_json(self, file_path): 128 | with open(file_path, 'r') as f: 129 | meta_data = json.load(f) 130 | 131 | return meta_data 132 | 133 | def sample_subset(self, ratio): 134 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) 135 | 
def get_closest_ratio(height: float, width: float, ratios: dict):
    """Return ``(bucket_size, bucket_ratio)`` for the ratio bucket nearest to height/width.

    ``ratios`` maps stringified aspect ratios to [H, W] bucket sizes; ties keep
    the first key in iteration order, matching ``min`` semantics.
    """
    target = height / width
    best_key = None
    best_gap = float('inf')
    for key in ratios:
        gap = abs(float(key) - target)
        if gap < best_gap:
            best_key, best_gap = key, gap
    return ratios[best_key], float(best_key)
self.mask_type = mask_type 45 | self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) 46 | self.max_lenth = max_length 47 | self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio 48 | self.meta_data_clean = [] 49 | self.img_samples = [] 50 | self.txt_feat_samples = [] 51 | self.vae_feat_samples = [] 52 | self.mask_index_samples = [] 53 | self.ratio_index = {} 54 | self.ratio_nums = {} 55 | for k, v in self.aspect_ratio.items(): 56 | self.ratio_index[float(k)] = [] # used for self.getitem 57 | self.ratio_nums[float(k)] = 0 # used for batch-sampler 58 | 59 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 60 | for json_file in image_list_json: 61 | meta_data = self.load_json(os.path.join(self.root, 'partition_filter', json_file)) 62 | self.ori_imgs_nums += len(meta_data) 63 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] 64 | self.meta_data_clean.extend(meta_data_clean) 65 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) 66 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) 67 | self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) 68 | 69 | # Set loader and extensions 70 | if load_vae_feat: 71 | self.transform = None 72 | self.loader = self.vae_feat_loader 73 | else: 74 | self.loader = default_loader 75 | 76 | if sample_subset is not None: 77 | self.sample_subset(sample_subset) # sample dataset for local debug 78 | 79 | # scan the dataset for ratio static 80 | for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]): 81 | ori_h, ori_w = info['height'], info['width'] 82 | closest_size, closest_ratio = 
get_closest_ratio(ori_h, ori_w, self.aspect_ratio) 83 | self.ratio_nums[closest_ratio] += 1 84 | if len(self.ratio_index[closest_ratio]) == 0: 85 | self.ratio_index[closest_ratio].append(i) 86 | # print(self.ratio_nums) 87 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 88 | logger.info(f"T5 max token length: {self.max_lenth}") 89 | 90 | def getdata(self, index): 91 | img_path = self.img_samples[index] 92 | npz_path = self.txt_feat_samples[index] 93 | npy_path = self.vae_feat_samples[index] 94 | ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width'] 95 | 96 | # Calculate the closest aspect ratio and resize & crop image[w, h] 97 | closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) 98 | closest_size = list(map(lambda x: int(x), closest_size)) 99 | self.closest_ratio = closest_ratio 100 | 101 | if self.load_vae_feat: 102 | try: 103 | img = self.loader(npy_path) 104 | if index not in self.ratio_index[closest_ratio]: 105 | self.ratio_index[closest_ratio].append(index) 106 | except Exception: 107 | index = random.choice(self.ratio_index[closest_ratio]) 108 | return self.getdata(index) 109 | h, w = (img.shape[1], img.shape[2]) 110 | assert h, w == (ori_h//8, ori_w//8) 111 | else: 112 | img = self.loader(img_path) 113 | h, w = (img.size[1], img.size[0]) 114 | assert h, w == (ori_h, ori_w) 115 | 116 | data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)} 117 | data_info['aspect_ratio'] = closest_ratio 118 | data_info["mask_type"] = self.mask_type 119 | 120 | txt_info = np.load(npz_path) 121 | txt_fea = torch.from_numpy(txt_info['caption_feature']) 122 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) 123 | if 'attention_mask' in txt_info.keys(): 124 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 125 | 126 | if not self.load_vae_feat: 127 | if closest_size[0] / ori_h > closest_size[1] / ori_w: 128 
| resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) 129 | else: 130 | resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] 131 | self.transform = T.Compose([ 132 | T.Lambda(lambda img: img.convert('RGB')), 133 | T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC 134 | T.CenterCrop(closest_size), 135 | T.ToTensor(), 136 | T.Normalize([.5], [.5]), 137 | ]) 138 | 139 | if self.transform: 140 | img = self.transform(img) 141 | 142 | return img, txt_fea, attention_mask, data_info 143 | 144 | def __getitem__(self, idx): 145 | for _ in range(20): 146 | try: 147 | return self.getdata(idx) 148 | except Exception as e: 149 | print(f"Error details: {str(e)}") 150 | idx = random.choice(self.ratio_index[self.closest_ratio]) 151 | raise RuntimeError('Too many bad data.') 152 | -------------------------------------------------------------------------------- /diffusion/data/datasets/SA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 8 | from torch.utils.data import Dataset 9 | from diffusers.utils.torch_utils import randn_tensor 10 | 11 | from diffusion.data.builder import get_data_path, DATASETS 12 | 13 | 14 | @DATASETS.register_module() 15 | class SAM(Dataset): 16 | def __init__(self, 17 | root, 18 | image_list_txt='part0.txt', 19 | transform=None, 20 | resolution=256, 21 | sample_subset=None, 22 | load_vae_feat=False, 23 | mask_ratio=0.0, 24 | mask_type='null', 25 | **kwargs): 26 | self.root = get_data_path(root) 27 | self.transform = transform 28 | self.load_vae_feat = load_vae_feat 29 | self.mask_type = mask_type 30 | self.mask_ratio = mask_ratio 31 | self.resolution = resolution 32 | self.img_samples = [] 33 | self.txt_feat_samples = [] 34 | self.vae_feat_samples = [] 35 | image_list_txt = image_list_txt if 
isinstance(image_list_txt, list) else [image_list_txt] 36 | if image_list_txt == 'all': 37 | image_list_txts = os.listdir(os.path.join(self.root, 'partition')) 38 | for txt in image_list_txts: 39 | image_list = os.path.join(self.root, 'partition', txt) 40 | with open(image_list, 'r') as f: 41 | lines = [line.strip() for line in f.readlines()] 42 | self.img_samples.extend([os.path.join(self.root, 'images', i+'.jpg') for i in lines]) 43 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i+'.npz') for i in lines]) 44 | elif isinstance(image_list_txt, list): 45 | for txt in image_list_txt: 46 | image_list = os.path.join(self.root, 'partition', txt) 47 | with open(image_list, 'r') as f: 48 | lines = [line.strip() for line in f.readlines()] 49 | self.img_samples.extend([os.path.join(self.root, 'images', i + '.jpg') for i in lines]) 50 | self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', i + '.npz') for i in lines]) 51 | self.vae_feat_samples.extend([os.path.join(self.root, 'img_vae_feature/train_vae_256/noflip', i + '.npy') for i in lines]) 52 | 53 | self.ori_imgs_nums = len(self) 54 | # self.img_samples = self.img_samples[:10000] 55 | # Set loader and extensions 56 | if load_vae_feat: 57 | self.transform = None 58 | self.loader = self.vae_feat_loader 59 | else: 60 | self.loader = default_loader 61 | 62 | if sample_subset is not None: 63 | self.sample_subset(sample_subset) # sample dataset for local debug 64 | 65 | def getdata(self, idx): 66 | img_path = self.img_samples[idx] 67 | npz_path = self.txt_feat_samples[idx] 68 | npy_path = self.vae_feat_samples[idx] 69 | data_info = {'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32), 70 | 'aspect_ratio': torch.tensor(1.)} 71 | 72 | img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path) 73 | npz_info = np.load(npz_path) 74 | txt_fea = torch.from_numpy(npz_info['caption_feature']) 75 | attention_mask = torch.ones(1, 1, 
txt_fea.shape[1]) 76 | if 'attention_mask' in npz_info.keys(): 77 | attention_mask = torch.from_numpy(npz_info['attention_mask'])[None] 78 | 79 | if self.transform: 80 | img = self.transform(img) 81 | 82 | data_info["mask_type"] = self.mask_type 83 | 84 | return img, txt_fea, attention_mask, data_info 85 | 86 | def __getitem__(self, idx): 87 | for _ in range(20): 88 | try: 89 | return self.getdata(idx) 90 | except Exception: 91 | print(self.img_samples[idx], ' info is not correct') 92 | idx = np.random.randint(len(self)) 93 | raise RuntimeError('Too many bad data.') 94 | 95 | @staticmethod 96 | def vae_feat_loader(path): 97 | # [mean, std] 98 | mean, std = torch.from_numpy(np.load(path)).chunk(2) 99 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 100 | return mean + std * sample 101 | # return mean 102 | 103 | def sample_subset(self, ratio): 104 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) 105 | self.img_samples = [self.img_samples[i] for i in sampled_idx] 106 | self.txt_feat_samples = [self.txt_feat_samples[i] for i in sampled_idx] 107 | 108 | def __len__(self): 109 | return len(self.img_samples) 110 | -------------------------------------------------------------------------------- /diffusion/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .SA import SAM 2 | from .InternalData import InternalData 3 | from .InternalData_ms import InternalDataMS 4 | from .Dreambooth import DreamBooth 5 | from .pixart_control import InternalDataHed 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/Dreambooth.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/Dreambooth.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/Dreambooth.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/Dreambooth.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/InternalData.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/InternalData.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/InternalData.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/InternalData.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/InternalData_ms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/InternalData_ms.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/InternalData_ms.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/InternalData_ms.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/SA.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/SA.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/SA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/SA.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/pixart_control.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/pixart_control.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/pixart_control.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/pixart_control.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/pixart_controldit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/pixart_controldit.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/data/datasets/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/data/datasets/pixart_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import 
torch 6 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 7 | from torch.utils.data import Dataset 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from torchvision import transforms as T 10 | from diffusion.data.builder import get_data_path, DATASETS 11 | 12 | import json, time 13 | 14 | 15 | @DATASETS.register_module() 16 | class InternalDataHed(Dataset): 17 | def __init__(self, 18 | root, 19 | image_list_json='data_info.json', 20 | transform=None, 21 | resolution=1024, 22 | sample_subset=None, 23 | load_vae_feat=False, 24 | input_size=32, 25 | patch_size=2, 26 | mask_ratio=0.0, 27 | load_mask_index=False, 28 | train_ratio=1.0, 29 | mode='train', 30 | **kwargs): 31 | self.root = get_data_path(root) 32 | self.transform = transform 33 | self.load_vae_feat = load_vae_feat 34 | self.ori_imgs_nums = 0 35 | self.resolution = resolution 36 | self.N = int(resolution // (input_size // patch_size)) 37 | self.mask_ratio = mask_ratio 38 | self.load_mask_index = load_mask_index 39 | self.meta_data_clean = [] 40 | self.img_samples = [] 41 | self.txt_feat_samples = [] 42 | self.vae_feat_samples = [] 43 | self.hed_feat_samples = [] 44 | self.prompt_samples = [] 45 | 46 | image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] 47 | for json_file in image_list_json: 48 | meta_data = self.load_json(os.path.join(self.root, 'partition_filter', json_file)) 49 | self.ori_imgs_nums += len(meta_data) 50 | meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] 51 | self.meta_data_clean.extend(meta_data_clean) 52 | self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) 53 | self.txt_feat_samples.extend([os.path.join(self.root, 'prompt_feature', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) 54 | self.vae_feat_samples.extend([os.path.join(self.root, 
f'img_vae_features_{resolution}/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) 55 | self.hed_feat_samples.extend([os.path.join(self.root, f'hed_feature_{resolution}', item['path'].replace('.png', '.npz')) for item in meta_data_clean]) 56 | self.prompt_samples.extend([item['prompt'] for item in meta_data_clean]) 57 | 58 | total_sample = len(self.img_samples) 59 | used_sample_num = int(total_sample * train_ratio) 60 | print("using mode", mode) 61 | if mode == 'train': 62 | self.img_samples = self.img_samples[:used_sample_num] 63 | self.txt_feat_samples = self.txt_feat_samples[:used_sample_num] 64 | self.vae_feat_samples = self.vae_feat_samples[:used_sample_num] 65 | self.hed_feat_samples = self.hed_feat_samples[:used_sample_num] 66 | self.prompt_samples = self.prompt_samples[:used_sample_num] 67 | else: 68 | self.img_samples = self.img_samples[-used_sample_num:] 69 | self.txt_feat_samples = self.txt_feat_samples[-used_sample_num:] 70 | self.vae_feat_samples = self.vae_feat_samples[-used_sample_num:] 71 | self.hed_feat_samples = self.hed_feat_samples[-used_sample_num:] 72 | self.prompt_samples = self.prompt_samples[-used_sample_num:] 73 | 74 | # Set loader and extensions 75 | if load_vae_feat: 76 | self.transform = None 77 | self.loader = self.vae_feat_loader 78 | else: 79 | self.loader = default_loader 80 | 81 | if sample_subset is not None: 82 | self.sample_subset(sample_subset) # sample dataset for local debug 83 | 84 | def getdata(self, index): 85 | img_path = self.img_samples[index] 86 | npz_path = self.txt_feat_samples[index] 87 | npy_path = self.vae_feat_samples[index] 88 | hed_npz_path = self.hed_feat_samples[index] 89 | prompt = self.prompt_samples[index] 90 | # only trained on single-scale 1024 res data 91 | data_info = {'img_hw': torch.tensor([1024., 1024.], dtype=torch.float32), 'aspect_ratio': torch.tensor(1.)} 92 | 93 | if self.load_vae_feat: 94 | img = self.loader(npy_path) 95 | else: 96 | img = 
self.loader(img_path) 97 | 98 | hed_fea = self.vae_feat_loader_npz(hed_npz_path) 99 | txt_info = np.load(npz_path) 100 | txt_fea = torch.from_numpy(txt_info['caption_feature']) 101 | attention_mask = torch.ones(1, 1, txt_fea.shape[1]) 102 | if 'attention_mask' in txt_info.keys(): 103 | attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] 104 | 105 | if self.transform: 106 | img = self.transform(img) 107 | 108 | data_info['condition'] = hed_fea 109 | data_info['prompt'] = prompt 110 | 111 | return img, txt_fea, attention_mask, data_info 112 | 113 | def __getitem__(self, idx): 114 | for i in range(20): 115 | try: 116 | data = self.getdata(idx) 117 | return data 118 | except Exception as e: 119 | print(f"Error details: {str(e)}") 120 | idx = np.random.randint(len(self)) 121 | raise RuntimeError('Too many bad data.') 122 | 123 | return data 124 | 125 | 126 | def get_data_info(self, idx): 127 | data_info = self.meta_data_clean[idx] 128 | return {'height': data_info['height'], 'width': data_info['width']} 129 | 130 | @staticmethod 131 | def vae_feat_loader(path): 132 | # [mean, std] 133 | input_img = torch.from_numpy(np.load(path)) 134 | mean, std = input_img.chunk(2) 135 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 136 | result = mean + std * sample 137 | return mean + std * sample 138 | 139 | @staticmethod 140 | def vae_feat_loader_npz(path): 141 | # [mean, std] 142 | mean, std = torch.from_numpy(np.load(path)['arr_0']).chunk(2) 143 | sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) 144 | return mean + std * sample 145 | 146 | def load_json(self, file_path): 147 | with open(file_path, 'r') as f: 148 | meta_data = json.load(f) 149 | 150 | return meta_data 151 | 152 | def sample_subset(self, ratio): 153 | sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) 154 | self.img_samples = [self.img_samples[i] for i in sampled_idx] 155 | 156 | def 
__len__(self): 157 | return len(self.img_samples) 158 | 159 | def __getattr__(self, name): 160 | if name == "set_epoch": 161 | return lambda epoch: None 162 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 163 | 164 | -------------------------------------------------------------------------------- /diffusion/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ASPECT_RATIO_1024 = { 4 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], 5 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 6 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 7 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 8 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 9 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 10 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 11 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 12 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], 13 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], 14 | } 15 | 16 | ASPECT_RATIO_512 = { 17 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], 18 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 19 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 20 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 21 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 22 | '1.0': 
# NOTE(review): this chunk begins mid-way through ASPECT_RATIO_512; its leading
# entries live in the previous chunk and are not reproduced here.

# Multi-resolution bucket tables. Keys are ratio strings, values are two pixel
# dimensions whose quotient (first/second) equals the key — assumed [height,
# width] order; TODO confirm against the dataset code that indexes these.
ASPECT_RATIO_256 = {
    '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
    '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
}

# Test-time table: a pruned subset of ASPECT_RATIO_256 (fewer ratio buckets).
ASPECT_RATIO_256_TEST = {
    '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
    '4.0': [512.0, 128.0]
}

ASPECT_RATIO_512_TEST = {
    '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
    '4.0': [1024.0, 256.0]
}

ASPECT_RATIO_1024_TEST = {
    '0.25': [512., 2048.], '0.28': [512., 1856.],
    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
    '2.5': [1600., 640.], '3.0': [1728., 576.],
    '4.0': [2048., 512.],
}


def get_chunks(lst, n):
    """Yield successive n-sized chunks from *lst* (the last chunk may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


# --- diffusion/data/transforms.py ---

# Registry mapping a transform-factory's function name to the factory itself.
TRANSFORMS = {}


def register_transform(transform):
    """Register a transform factory under its function name.

    Raises:
        RuntimeError: if a factory with the same name was already registered.
    """
    name = transform.__name__
    if name in TRANSFORMS:
        # BUG FIX: error message grammar ("has already registered" ->
        # "has already been registered").
        raise RuntimeError(f'Transform {name} has already been registered.')
    TRANSFORMS.update({name: transform})


def get_transform(type, resolution):
    """Build a composed torchvision transform from the registered factory.

    `type` shadows the builtin but is kept: callers may pass it by keyword.
    """
    # Local import keeps this module importable when torchvision is absent.
    import torchvision.transforms as T
    transform = TRANSFORMS[type](resolution)
    transform = T.Compose(transform)
    # Stash the target resolution on the composed transform for downstream use.
    transform.image_size = resolution
    return transform


@register_transform
def default_train(n_px):
    """Default training pipeline: RGB convert, resize+center-crop to n_px,
    to-tensor, normalize to [-1, 1]."""
    import torchvision.transforms as T
    return [
        T.Lambda(lambda img: img.convert('RGB')),
        T.Resize(n_px),  # Image.BICUBIC
        T.CenterCrop(n_px),
        # T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ]


# --- diffusion/download.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Functions for downloading pre-trained PixArt models
"""
import torch
import os
import argparse


# Canonical checkpoint names that can be fetched from the PixArt-alpha HF repo.
pretrained_models = {'PixArt-XL-2-512x512.pth', 'PixArt-XL-2-1024-MS.pth', 'PixArt-XL-2-1024-ControlNet.pth'}
vae_models = {
    'sd-vae-ft-ema/config.json',
    'sd-vae-ft-ema/diffusion_pytorch_model.bin'
}
t5_models = {
    't5-v1_1-xxl/config.json', 't5-v1_1-xxl/pytorch_model-00001-of-00002.bin',
    't5-v1_1-xxl/pytorch_model-00002-of-00002.bin', 't5-v1_1-xxl/pytorch_model.bin.index.json',
    't5-v1_1-xxl/special_tokens_map.json', 't5-v1_1-xxl/spiece.model',
    't5-v1_1-xxl/tokenizer_config.json',
}


def find_model(model_name):
    """
    Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path.

    Raises:
        AssertionError: if model_name is neither a canonical name nor an existing file.
    """
    if model_name in pretrained_models:
        return download_model(model_name)
    assert os.path.isfile(model_name), f'Could not find PixArt checkpoint at {model_name}'
    # map_location keeps tensors on CPU regardless of where they were saved.
    return torch.load(model_name, map_location=lambda storage, loc: storage)


def my_load_model(model_name):
    """Load a checkpoint from an explicit local path onto CPU."""
    return torch.load(model_name, map_location=lambda storage, loc: storage)


def download_model(model_name):
    """
    Downloads a pre-trained PixArt model from the web, then loads it onto CPU.
    """
    assert model_name in pretrained_models
    local_path = f'output/pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('output/pretrained_models', exist_ok=True)
        web_path = f'https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/{model_name}'
        # Local import: torchvision is only needed on the (rare) download path.
        from torchvision.datasets.utils import download_url
        # BUG FIX: the file was previously downloaded into an unrelated
        # hard-coded '/home/jovyan/...' directory while torch.load below reads
        # from local_path, so a fresh download could never be loaded. Download
        # into the directory local_path points at instead.
        download_url(web_path, 'output/pretrained_models')
    return torch.load(local_path, map_location=lambda storage, loc: storage)


def download_other(model_name, model_zoo, output_dir):
    """
    Downloads a pre-trained PixArt model from the web.
    """
    assert model_name in model_zoo
    local_path = os.path.join(output_dir, model_name)
    if not os.path.isfile(local_path):
        os.makedirs(output_dir, exist_ok=True)
        web_path = f'https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/{model_name}'
        print(web_path)
        from torchvision.datasets.utils import download_url
        # Files like 't5-v1_1-xxl/config.json' are nested one level deep, so
        # download into the model's own sub-directory.
        download_url(web_path, os.path.join(output_dir, model_name.split('/')[0]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_names', nargs='+', type=str, default=pretrained_models)
    args = parser.parse_args()
    model_names = set(args.model_names)

    # Download PixArt checkpoints plus the T5 / VAE auxiliaries.
    # NOTE(review): the destination directories below are machine-specific
    # hard-coded paths kept from the original script.
    for t5_model in t5_models:
        download_other(t5_model, t5_models, '/home/jovyan/maao-data-cephfs-2/dataspace/caoke/pretrained_models/t5_ckpts')
    for vae_model in vae_models:
        download_other(vae_model, vae_models, '/home/jovyan/maao-data-cephfs-2/dataspace/caoke/pretrained_models/')
    for model in model_names:
        download_model(model)
    print('Done.')
# --- diffusion/dpm_solver.py ---
import torch
from .model import gaussian_diffusion as gd
from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP


def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
    """Build a multistep DPM-Solver++ sampler around a discrete-time diffusion model.

    `condition` / `uncondition` are the conditional and unconditional model
    inputs used for classifier-free guidance at scale `cfg_scale`.
    """
    extra_kwargs = {} if model_kwargs is None else model_kwargs

    # 1. Discrete VP noise schedule from the named beta schedule.
    beta_tensor = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
    schedule = NoiseScheduleVP(schedule='discrete', betas=beta_tensor)

    # 2. Wrap the discrete-time model as a continuous-time noise-prediction
    #    model with classifier-free guidance baked in.
    wrapped_model = model_wrapper(
        model,
        schedule,
        model_type=model_type,
        model_kwargs=extra_kwargs,
        guidance_type=guidance_type,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
    )

    # 3. Sample with multistep DPM-Solver++.
    return DPM_Solver(wrapped_model, schedule, algorithm_type="dpmsolver++")


# --- diffusion/iddpm.py ---
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from .model.respace import SpacedDiffusion, space_timesteps
from .model import gaussian_diffusion as gd


def IDDPM(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    pred_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
    snr=False,
    return_startx=False,
):
    """Construct a SpacedDiffusion configured by the given IDDPM options."""
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # Loss selection: KL variants take precedence over plain MSE.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    # An unset respacing spec means "use every diffusion step".
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]

    # Mean parameterization: predict x_0 directly or predict the noise eps.
    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    # Variance parameterization; None when the model predicts no sigma at all.
    if not pred_sigma:
        var_type = None
    elif learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        snr=snr,
        return_startx=return_startx,
        # rescale_timesteps=rescale_timesteps,
    )
-------------------------------------------------------------------------------- /diffusion/model/.ipynb_checkpoints/builder-checkpoint.py: -------------------------------------------------------------------------------- 1 | from mmcv import Registry 2 | 3 | from diffusion.model.utils import set_grad_checkpoint 4 | 5 | MODELS = Registry('models') 6 | 7 | 8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): 9 | if isinstance(cfg, str): 10 | cfg = dict(type=cfg) 11 | model = MODELS.build(cfg, default_args=kwargs) 12 | if use_grad_checkpoint: 13 | set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step) 14 | return model 15 | -------------------------------------------------------------------------------- /diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .nets import * 2 | -------------------------------------------------------------------------------- /diffusion/model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/builder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/builder.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/builder.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/diffusion_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/diffusion_utils.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/diffusion_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/diffusion_utils.cpython-311.pyc 
-------------------------------------------------------------------------------- /diffusion/model/__pycache__/diffusion_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/diffusion_utils.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/dpm_solver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/dpm_solver.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/dpm_solver.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/dpm_solver.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/dpm_solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/dpm_solver.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/gaussian_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/gaussian_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/__pycache__/gaussian_diffusion.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/gaussian_diffusion.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/gaussian_diffusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/gaussian_diffusion.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/hed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/hed.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/hed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/hed.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/respace.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/respace.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/respace.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/respace.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/respace.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/respace.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/sa_solver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/sa_solver.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/sa_solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/sa_solver.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/t5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/t5.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/t5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/t5.cpython-39.pyc 
-------------------------------------------------------------------------------- /diffusion/model/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv import Registry 2 | 3 | from diffusion.model.utils import set_grad_checkpoint 4 | 5 | MODELS = Registry('models') 6 | 7 | 8 | def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): 9 | if isinstance(cfg, str): 10 | cfg = dict(type=cfg) 11 | model = MODELS.build(cfg, default_args=kwargs) 12 | if use_grad_checkpoint: 13 | set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step) 14 | return model 15 | -------------------------------------------------------------------------------- /diffusion/model/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # 
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any tensor argument so scalar log-variances can be placed on the
    # same device.
    tensor = None
    for candidate in (mean1, logvar1, mean2, logvar2):
        if isinstance(candidate, th.Tensor):
            tensor = candidate
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors: broadcasting converts scalars for
    # arithmetic, but th.exp() requires real tensors.
    converted = []
    for value in (logvar1, logvar2):
        if not isinstance(value, th.Tensor):
            value = th.tensor(value, device=tensor.device)
        converted.append(value)
    logvar1, logvar2 = converted

    log_ratio = logvar2 - logvar1
    return 0.5 * (
        -1.0
        + log_ratio
        + th.exp(-log_ratio)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal (tanh-based).
    """
    coef = np.sqrt(2.0 / np.pi)
    return 0.5 * (1.0 + th.tanh(coef * (x + 0.044715 * x ** 3)))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    # Standardize, then score against a unit normal.
    normalized = (x - means) * th.exp(-log_scales)
    std_normal = th.distributions.Normal(th.zeros_like(x), th.ones_like(x))
    return std_normal.log_prob(normalized)


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    inv_stdv = th.exp(-log_scales)
    centered = x - means
    # Probability mass of the 1/255-wide bin around each pixel value.
    cdf_plus = approx_standard_normal_cdf(inv_stdv * (centered + 1.0 / 255.0))
    cdf_min = approx_standard_normal_cdf(inv_stdv * (centered - 1.0 / 255.0))
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    log_cdf_delta = th.log((cdf_plus - cdf_min).clamp(min=1e-12))
    # Edge bins absorb the full tails at -1 and +1.
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta),
    )
    assert log_probs.shape == x.shape
    return log_probs


# --- diffusion/model/edm_sample.py ---
# ----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).
10 | 11 | def edm_sampler( 12 | net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like, 13 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 14 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs 15 | ): 16 | # Adjust noise levels based on what's supported by the network. 17 | sigma_min = max(sigma_min, net.sigma_min) 18 | sigma_max = min(sigma_max, net.sigma_max) 19 | 20 | # Time step discretization. 21 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 22 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( 23 | sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 24 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 25 | 26 | # Main sampling loop. 27 | x_next = latents.to(torch.float64) * t_steps[0] 28 | for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1 29 | x_cur = x_next 30 | 31 | # Increase noise temporarily. 32 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 33 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 34 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 35 | 36 | # Euler step. 37 | denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 38 | d_cur = (x_hat - denoised) / t_hat 39 | x_next = x_hat + (t_next - t_hat) * d_cur 40 | 41 | # Apply 2nd order correction. 42 | if i < num_steps - 1: 43 | denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) 44 | d_prime = (x_next - denoised) / t_next 45 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 46 | 47 | return x_next 48 | 49 | 50 | # ---------------------------------------------------------------------------- 51 | # Generalized ablation sampler, representing the superset of all sampling 52 | # methods discussed in the paper. 
def ablation_sampler(
    net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Generalized ablation sampler covering the superset of sampling methods
    (per the module comment above): euler/heun solvers combined with
    vp/ve/iddpm/edm discretizations, vp/ve/linear schedules and vp/none scaling.

    :param net: denoiser; called as net(x, sigma, class_labels, cfg_scale, feat=feat)
        and assumed to return a dict with key 'x' (see the indexing below); must
        expose sigma_min / sigma_max / round_sigma.
    :param sigma_min/sigma_max: noise-level range; None selects a
        discretization-specific default.
    :return: the final sample, in float64.
    """
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    # NOTE: the vp_sigma_deriv closure reads `sigma`, the schedule function
    # assigned further below — it must not be called before that assignment.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
        sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        # Recurrence over M discrete levels, then subsample num_steps of them.
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule: sigma(t), its derivative, and its inverse.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule s(t) and its derivative.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily (churn) when sigma(t_cur) lies in [S_min, S_max].
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
            t_hat) * S_noise * randn_like(x_cur)

        # Euler step (ODE derivative under the chosen sigma/s schedules).
        h = t_next - t_hat
        denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
            torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
            t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (heun), except on the very last step.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
                torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
                t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next


# --- diffusion/model/hed.py ---
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
# and in this way it works better for gradio's RGB protocol
import sys
from pathlib import Path
# Make the repo root importable when this file is run as a script.
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent.parent))
from torch import nn
import torch
import numpy as np
from torchvision import transforms as T
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
import torchvision.transforms.functional as TF
from accelerate import Accelerator
from diffusers.models import AutoencoderKL
import os

# Target edge length used by the dataset transform below.
image_resize = 1024


class DoubleConvBlock(nn.Module):
    """Stack of `layer_number` 3x3 convolutions (ReLU applied in forward) plus
    a 1x1 projection to a single-channel edge map."""

    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        # 1x1 projection producing the per-scale single-channel output.
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def forward(self, x, down_sampling=False):
        """Return (features, projection). Optionally 2x max-pool the input first."""
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(nn.Module):
    """VGG-style HED edge detector: five DoubleConvBlocks at decreasing
    resolution; returns the five per-scale projections."""

    def __init__(self):
        super().__init__()
        # Learned per-channel input shift (initialized to zero).
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def forward(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5


class InternData(Dataset):
    """Dataset over data_info.json entries; yields (RGB tensor, relative path).

    NOTE(review): reads from hard-coded 'data/InternData' / 'data/InternImgs'
    locations relative to the working directory.
    """

    def __init__(self):
        ####
        with open('data/InternData/partition/data_info.json', 'r') as f:
            self.j = json.load(f)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize(image_resize),  # Image.BICUBIC
            T.CenterCrop(image_resize),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.j)

    def getdata(self, idx):
        # Raises on unreadable/corrupt images; __getitem__ handles retries.
        path = self.j[idx]['path']
        image = Image.open("data/InternImgs/" + path)
        image = self.transform(image)
        return image, path

    def __getitem__(self, idx):
        # Retry up to 20 times with a random index on bad samples.
        for i in range(20):
            try:
                data = self.getdata(idx)
                return data
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

class HEDdetector(nn.Module):
    # NOTE(review): this class is truncated at the end of the chunk; the
    # remainder of __init__ continues beyond this view.
    def __init__(self, feature=False, vae=None):
        super().__init__()
        self.model = ControlNetHED_Apache2()
        # Machine-specific hard-coded checkpoint path kept from the original.
        self.model.load_state_dict(torch.load('/home/jovyan/maao-data-cephfs-0/dataspace/maao/projects/Common/models/PixArt-alpha/ControlNetHED.pth', map_location='cpu'))
        self.model.eval()
        self.model.requires_grad_(False)
        if feature:
            if vae is None:
107 | self.vae = AutoencoderKL.from_pretrained("/home/jovyan/maao-data-cephfs-0/dataspace/maao/projects/Common/models/PixArt-alpha/sd-vae-ft-ema") 108 | else: 109 | self.vae = vae 110 | self.vae.eval() 111 | self.vae.requires_grad_(False) 112 | else: 113 | self.vae = None 114 | 115 | def forward(self, input_image): 116 | C, H, W = input_image.shape 117 | with torch.inference_mode(): 118 | edges = self.model(input_image * 255.) 119 | edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1) 120 | edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True))) 121 | edge.clip_(0, 1) 122 | if self.vae: 123 | edge = TF.normalize(edge, [.5], [.5]) 124 | edge = edge.repeat(1, 3, 1, 1) 125 | posterior = self.vae.encode(edge).latent_dist 126 | edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy() 127 | return edge 128 | 129 | 130 | def main(): 131 | dataset = InternData() 132 | dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True) 133 | hed = HEDdetector() 134 | 135 | accelerator = Accelerator() 136 | hed, dataloader = accelerator.prepare(hed, dataloader) 137 | 138 | 139 | for img, path in tqdm(dataloader): 140 | out = hed(img.cuda()) 141 | for p, o in zip(path, out): 142 | save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz') 143 | if os.path.exists(save): 144 | continue 145 | os.makedirs(os.path.dirname(save), exist_ok=True) 146 | np.savez_compressed(save, o) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /diffusion/model/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/blocks.py: 
-------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict = None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): 23 | if attn_config is None: 24 | attn_config = { 25 | 'attn_type': 'multihead_attention', 26 | 'attn_pdrop': 0.0, 27 | 'attn_impl': 'triton', 28 | 'qk_ln': False, 29 | 'clip_qkv': None, 30 | 'softmax_scale': None, 31 | 'prefix_lm': False, 32 | 'attn_uses_sequence_id': False, 33 | 'alibi': False, 34 | 'alibi_bias_max': 8, 35 | } 36 | del kwargs 37 | super().__init__() 38 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 39 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 40 | self.norm_1 = norm_class(d_model, device=device) 41 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) 42 | self.norm_2 = norm_class(d_model, device=device) 43 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 44 | 
self.resid_attn_dropout = nn.Dropout(resid_pdrop) 45 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 46 | 47 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 48 | a = self.norm_1(x) 49 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 50 | x = x + self.resid_attn_dropout(b) 51 | m = self.norm_2(x) 52 | n = self.ffn(m) 53 | x = x + self.resid_ffn_dropout(n) 54 | return (x, past_key_value) -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, 
**kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 
36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 
59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if 
self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError( 119 | 'Positional information must be provided to the model using either learned_pos_emb or alibi.' 
120 | ) -------------------------------------------------------------------------------- /diffusion/model/llava/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | return output * weight if weight is not None else output 30 | 31 | class RMSNorm(torch.nn.Module): 32 | 33 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 34 | super().__init__() 35 | self.eps = eps 36 | if weight: 37 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 38 | else: 39 | self.register_parameter('weight', None) 40 | 41 | def forward(self, x): 42 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 
43 | 44 | class LPRMSNorm(RMSNorm): 45 | 46 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 47 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 48 | 49 | def forward(self, x): 50 | downcast_x = _cast_if_autocast_enabled(x) 51 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 52 | with torch.autocast(enabled=False, device_type=x.device.type): 53 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 54 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /diffusion/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .PixArt import PixArt, PixArt_XL_2 2 | from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock 3 | from .pixart_controlnet import ControlPixArtHalf, ControlPixArtMSHalf 4 | from .pixart_relactrl_v1 import ControlPixArtMSHalf_RelaCtrl 5 | 6 | 7 | -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArt.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt.cpython-311.pyc 
-------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArtMS.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArtMS.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArtMS.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArtMS.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArtMS.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArtMS.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArt_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt_blocks.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/PixArt_blocks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt_blocks.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/PixArt_blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/PixArt_blocks.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_adamamba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_adamamba.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_adapter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_adapter.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_moe.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe3_mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe3_mlp.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe4.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe4_allmlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe4_allmlp.cpython-310.pyc 
-------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe4_show.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe4_show.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior11.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior11.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior3.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior3.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior4.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior4s.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior4s.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior51.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior51.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior52.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior52.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior53.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior53.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior54.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior54.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior551.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior551.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior552.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior552.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e3.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e3.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e3_show.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e3_show.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e4.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e50.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e50.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e51.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e51.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e52.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_e52.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_same1.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_same1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_same2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_same2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_x0bug.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_x0bug.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_x0e1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior55_x0e1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior56.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior56.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior57.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior57.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior58.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior58.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior59.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_moe5_prior59.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_output.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_output.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_output.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_output.cpython-39.pyc 
-------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_output2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_output2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_p1o1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_p1o1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_p1o2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_p1o2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_p1o3.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_p1o3.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_p2o1.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_p2o1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_p2o2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_p2o2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_remove.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_remove.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_skipcopy.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_skipcopy.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_skiporigin.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_skiporigin.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_skipshow.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_skipshow.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage10_style.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage10_style.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_cop27.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_cop27.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_cop27_shunxu.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_cop27_shunxu.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy11_our.cpython-310.pyc 
-------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy12_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy12_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy12_our2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy12_our2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_shunxu.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_shunxu.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_skip.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_skip.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_skip_new.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_skip_new.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_top.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_top.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_top2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage1_copy13_top2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our11_our.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our12_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our12_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our13_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage2_our13_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our3.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our3.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our4.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our5.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our6.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our6.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our7.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage3_our12_our7.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy10_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy10_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy11_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy12_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy12_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy13_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy13_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy7_our.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage4_copy7_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy10_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy10_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy11_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy12_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy12_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy13_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy13_our.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy7_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage5_copy7_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage6_copy11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage6_copy11_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage6_copy11_our2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage6_copy11_our2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage6_our11_our.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage6_our11_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our2.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our3.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our3.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage7_our11_our4.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_controlnet_stage8_our11_our5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage8_our11_our5.cpython-310.pyc -------------------------------------------------------------------------------- 
/diffusion/model/nets/__pycache__/pixart_controlnet_stage9_our11_our4_test.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_controlnet_stage9_our11_our4_test.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_main_output.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_main_output.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/pixart_relactrl_v1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/pixart_relactrl_v1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/nets/__pycache__/relactrl_v1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/model/nets/__pycache__/relactrl_v1.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/model/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: 
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields the requested DDIM step
                        count, or if a section is smaller than its step count.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Search for an integer stride that produces exactly the
            # requested number of steps.
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Bug fix: report the count the caller asked for
            # (desired_count), not the original num_timesteps.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder so the sizes sum
        # to num_timesteps.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the section's first and last timesteps are
        # always included; a single-step section just takes its start.
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
68 | """ 69 | 70 | def __init__(self, use_timesteps, **kwargs): 71 | self.use_timesteps = set(use_timesteps) 72 | self.timestep_map = [] 73 | self.original_num_steps = len(kwargs["betas"]) 74 | 75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 76 | last_alpha_cumprod = 1.0 77 | new_betas = [] 78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 79 | if i in self.use_timesteps: 80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 81 | last_alpha_cumprod = alpha_cumprod 82 | self.timestep_map.append(i) 83 | kwargs["betas"] = np.array(new_betas) 84 | super().__init__(**kwargs) 85 | 86 | def p_mean_variance( 87 | self, model, *args, **kwargs 88 | ): # pylint: disable=signature-differs 89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 90 | 91 | def training_losses( 92 | self, model, *args, **kwargs 93 | ): # pylint: disable=signature-differs 94 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 95 | 96 | def training_losses_diffusers( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) 100 | 101 | def condition_mean(self, cond_fn, *args, **kwargs): 102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def condition_score(self, cond_fn, *args, **kwargs): 105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 106 | 107 | def _wrap_model(self, model): 108 | if isinstance(model, _WrappedModel): 109 | return model 110 | return _WrappedModel( 111 | model, self.timestep_map, self.original_num_steps 112 | ) 113 | 114 | def _scale_timesteps(self, t): 115 | # Scaling is done by the wrapped model. 
116 | return t 117 | 118 | 119 | class _WrappedModel: 120 | def __init__(self, model, timestep_map, original_num_steps): 121 | self.model = model 122 | self.timestep_map = timestep_map 123 | # self.rescale_timesteps = rescale_timesteps 124 | self.original_num_steps = original_num_steps 125 | 126 | def __call__(self, x, timestep, **kwargs): 127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) 128 | new_ts = map_tensor[timestep] 129 | # if self.rescale_timesteps: 130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 131 | return self.model(x, timestep=new_ts, **kwargs) 132 | -------------------------------------------------------------------------------- /diffusion/model/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 
31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 
81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs, device=local_ts.device) for _ in batch_sizes] 96 | loss_batches = [th.zeros(max_bs, device=local_losses.device) for _ in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /diffusion/sa_sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver 7 | from .model import gaussian_diffusion as gd 8 | 9 | 10 | class SASolverSampler(object): 11 | def __init__(self, model, 12 | noise_schedule="linear", 13 | diffusion_steps=1000, 14 | device='cpu', 15 | ): 16 | super().__init__() 17 | self.model = model 18 | self.device = device 19 | to_torch = lambda x: 
    def register_buffer(self, name, attr):
        # Stores `attr` as a plain attribute (this class is not an nn.Module).
        # NOTE(review): tensors are unconditionally moved to CUDA here,
        # ignoring the `device` passed to __init__ — confirm this is intended
        # before running on CPU-only hosts.
        if type(attr) == torch.Tensor and attr.device != torch.device("cuda"):
            attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, model_kwargs=None, **kwargs):
        """Generate a batch of samples with the SA-Solver.

        :param S: number of solver steps.
        :param batch_size: number of samples to generate.
        :param shape: (C, H, W) of a single sample.
        :param conditioning: conditioning tensor, or dict of tensors.
        :param eta: stochasticity; fed to the solver's tau schedule below.
        :param x_T: optional pre-drawn initial noise of full batch shape.
        :param unconditional_guidance_scale: classifier-free guidance scale.
        :param unconditional_conditioning: unconditional/negative conditioning.
        :param model_kwargs: extra kwargs forwarded to the model.
        :return: tuple (samples, None).

        NOTE(review): many parameters (callback, mask, x0, temperature, ...)
        appear to be accepted only for interface compatibility and are unused
        in this body.
        """
        if model_kwargs is None:
            model_kwargs = {}
        # Warn (but do not fail) when the conditioning batch size disagrees
        # with the requested batch size.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.device
        img = torch.randn(size, device=device) if x_T is None else x_T
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        # Wrap the raw model for classifier-free guidance under this schedule.
        model_fn = model_wrapper(
            self.model,
            ns,
            model_type="noise",
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
            model_kwargs=model_kwargs,
        )

        sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")

        # Inject stochasticity only in the middle of the trajectory
        # (t in [0.2, 0.8]); elsewhere the solver runs deterministically.
        tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0

        x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False)

        return x.to(device), None
def rename_file_with_creation_time(file_path):
    """Rename a file so its name carries the file's creation timestamp.

    ``<dir>/name.ext`` becomes ``<dir>/name_YYYY-MM-DD_HH-MM-SS.ext``, using
    the filesystem's creation time (``os.path.getctime``).

    Args:
        file_path (str): Path of the file to rename.
    """
    # Format the creation time into a filename-safe timestamp.
    created_at = os.path.getctime(file_path)
    stamp = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d_%H-%M-%S')

    # Assemble the new name alongside the original file.
    directory, base_name = os.path.split(file_path)
    stem, suffix = os.path.splitext(base_name)
    new_file_path = os.path.join(directory, f"{stem}_{stamp}{suffix}")

    os.rename(file_path, new_file_path)
    print(f"File renamed to: {new_file_path}")
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__init__.py -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/checkpoint.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/checkpoint.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/checkpoint.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/checkpoint.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/data_sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/data_sampler.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/data_sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/data_sampler.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/dist_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/dist_utils.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/dist_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/dist_utils.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/dist_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/dist_utils.cpython-39.pyc 
-------------------------------------------------------------------------------- /diffusion/utils/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/logger.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/logger.cpython-311.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/lr_scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/lr_scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/lr_scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/diffusion/utils/__pycache__/lr_scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /diffusion/utils/__pycache__/misc.cpython-310.pyc: 
def save_checkpoint(work_dir,
                    epoch,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    keep_last=False,
                    step=None,
                    ):
    """Serialize model (and optional EMA/optimizer/scheduler) state to
    ``<work_dir>/epoch_<epoch>[_step_<step>].pth``.

    :param keep_last: when True, delete checkpoints of all *earlier* epochs
                      after the current one is written.
    :param step: optional global step appended to the filename.
    """
    os.makedirs(work_dir, exist_ok=True)
    state_dict = dict(state_dict=model.state_dict())
    if model_ema is not None:
        state_dict['state_dict_ema'] = model_ema.state_dict()
    if optimizer is not None:
        state_dict['optimizer'] = optimizer.state_dict()
    if lr_scheduler is not None:
        state_dict['scheduler'] = lr_scheduler.state_dict()
    if epoch is not None:
        state_dict['epoch'] = epoch
    file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
    if step is not None:
        file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
    logger = get_root_logger()
    torch.save(state_dict, file_path)
    logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
    if keep_last:
        # BUGFIX: the old code computed `file_path.format(i)`, which is a
        # no-op (the path has no '{}' placeholder) and therefore deleted the
        # checkpoint that was just written. Match earlier-epoch files by name
        # instead, including any `_step_*` variants.
        for fname in os.listdir(work_dir):
            m = re.fullmatch(r"epoch_(\d+)(_step_\d+)?\.pth", fname)
            if m and int(m.group(1)) < epoch:
                os.remove(os.path.join(work_dir, fname))


def load_checkpoint(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True
                    ):
    """Load a checkpoint into ``model`` (and optionally EMA/optimizer/
    scheduler).

    :param checkpoint: path to a ``.pth`` file.
    :param load_ema: load the EMA weights into ``model`` instead of raw ones.
    :return: ``(epoch, missing, unexpected)`` when an optimizer is resumed,
             otherwise ``(missing, unexpected)`` from ``load_state_dict``.
    """
    assert isinstance(checkpoint, str)
    ckpt_file = checkpoint
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    # Drop positional embeddings so checkpoints transfer across resolutions.
    state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
    for key in state_dict_keys:
        if key in checkpoint['state_dict']:
            del checkpoint['state_dict'][key]
            if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
                del checkpoint['state_dict_ema'][key]
            break

    if load_ema:
        state_dict = checkpoint['state_dict_ema']
    else:
        state_dict = checkpoint.get('state_dict', checkpoint)  # to be compatible with the official checkpoint

    # Non-strict load so architecture variants still pick up shared weights.
    missing, unexpect = model.load_state_dict(state_dict, strict=False)
    if model_ema is not None:
        model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
    if optimizer is not None and resume_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None and resume_lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['scheduler'])
    logger = get_root_logger()
    if optimizer is not None:
        # BUGFIX: the old fallback `re.match(...).group()[0]` returned the
        # first *character* of the whole match, not the epoch number; use
        # group(1) and fall back to 0 when the filename carries no epoch tag.
        epoch = checkpoint.get('epoch')
        if epoch is None:
            m = re.match(r'.*epoch_(\d+).*\.pth', ckpt_file)
            epoch = int(m.group(1)) if m else 0
        logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
                    f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
        return epoch, missing, unexpect
    logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
    return missing, unexpect
20 | """ 21 | 22 | def __init__(self, 23 | sampler: Sampler, 24 | dataset: Dataset, 25 | batch_size: int, 26 | aspect_ratios: dict, 27 | drop_last: bool = False, 28 | config=None, 29 | valid_num=0, # take as valid aspect-ratio when sample number >= valid_num 30 | **kwargs) -> None: 31 | if not isinstance(sampler, Sampler): 32 | raise TypeError('sampler should be an instance of ``Sampler``, ' 33 | f'but got {sampler}') 34 | if not isinstance(batch_size, int) or batch_size <= 0: 35 | raise ValueError('batch_size should be a positive integer value, ' 36 | f'but got batch_size={batch_size}') 37 | self.sampler = sampler 38 | self.dataset = dataset 39 | self.batch_size = batch_size 40 | self.aspect_ratios = aspect_ratios 41 | self.drop_last = drop_last 42 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 43 | self.config = config 44 | assert self.ratio_nums_gt 45 | # buckets for each aspect ratio 46 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} 47 | self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num] 48 | logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) 49 | logger.warning(f"Using valid_num={valid_num} in config file. 
Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") 50 | 51 | def __iter__(self) -> Sequence[int]: 52 | for idx in self.sampler: 53 | data_info = self.dataset.get_data_info(idx) 54 | height, width = data_info['height'], data_info['width'] 55 | ratio = height / width 56 | # find the closest aspect ratio 57 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) 58 | if closest_ratio not in self.current_available_bucket_keys: 59 | continue 60 | bucket = self._aspect_ratio_buckets[closest_ratio] 61 | bucket.append(idx) 62 | # yield a batch of indices in the same aspect ratio group 63 | if len(bucket) == self.batch_size: 64 | yield bucket[:] 65 | del bucket[:] 66 | 67 | # yield the rest data and reset the buckets 68 | for bucket in self._aspect_ratio_buckets.values(): 69 | while len(bucket) > 0: 70 | if len(bucket) <= self.batch_size: 71 | if not self.drop_last: 72 | yield bucket[:] 73 | bucket = [] 74 | else: 75 | yield bucket[:self.batch_size] 76 | bucket = bucket[self.batch_size:] 77 | 78 | 79 | class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler): 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | # Assign samples to each bucket 83 | self.ratio_nums_gt = kwargs.get('ratio_nums', None) 84 | assert self.ratio_nums_gt 85 | self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()} 86 | self.original_buckets = {} 87 | self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000] 88 | self.all_available_keys = deepcopy(self.current_available_bucket_keys) 89 | self.exhausted_bucket_keys = [] 90 | self.total_batches = len(self.sampler) // self.batch_size 91 | self._aspect_ratio_count = {} 92 | for k in self.all_available_keys: 93 | self._aspect_ratio_count[float(k)] = 0 94 | self.original_buckets[float(k)] = [] 95 | logger = get_root_logger(os.path.join(self.config.work_dir, 
    def __iter__(self) -> Sequence[int]:
        # Number of batches yielded so far in the streaming phase.
        i = 0
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info['height'], data_info['width']
            ratio = height / width
            # Nearest predefined ratio, as a float key.
            closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
            if closest_ratio not in self.all_available_keys:
                continue
            # Cap each bucket at its ground-truth sample count.
            if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
                self._aspect_ratio_count[closest_ratio] += 1
                self._aspect_ratio_buckets[closest_ratio].append(idx)
                self.original_buckets[closest_ratio].append(idx)  # Save the original samples for each bucket
            # Once every key has yielded a batch, recycle the exhausted keys
            # so ratios take turns (the "balanced" rotation).
            if not self.current_available_bucket_keys:
                self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []

            if closest_ratio not in self.current_available_bucket_keys:
                continue
            key = closest_ratio
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) == self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]
                i += 1
                # Retire this key until the rotation resets.
                self.exhausted_bucket_keys.append(key)
                self.current_available_bucket_keys.remove(key)

        # Top up to the expected number of batches by re-drawing from
        # previously seen samples.
        for _ in range(self.total_batches - i):
            key = choice(self.all_available_keys)
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) >= self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]

                # If a bucket is exhausted
                if not bucket:
                    self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                    shuffle(self._aspect_ratio_buckets[key])
            else:
                # NOTE(review): this refill is identical to the exhausted-bucket
                # branch above; the if/else appears redundant — confirm whether
                # a different behavior (e.g. yielding the partial batch) was
                # intended before simplifying.
                self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                shuffle(self._aspect_ratio_buckets[key])
-------------------------------------------------------------------------------- /diffusion/utils/dist_utils.py: --------------------------------------------------------------------------------
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import os
import pickle
import shutil

import gc
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def is_distributed():
    """Return True when more than one process participates in training."""
    return get_world_size() > 1


def get_world_size():
    """Return the number of processes, or 1 when torch.distributed is unavailable/uninitialized."""
    if not dist.is_available():
        return 1
    return dist.get_world_size() if dist.is_initialized() else 1


def get_rank():
    """Return the global rank of this process (0 when not distributed)."""
    if not dist.is_available():
        return 0
    return dist.get_rank() if dist.is_initialized() else 0


def get_local_rank():
    """Return the node-local rank read from the LOCAL_RANK env var (0 when not distributed)."""
    if not dist.is_available():
        return 0
    return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0


def is_master():
    """True for the single global rank-0 process."""
    return get_rank() == 0


def is_local_master():
    """True for the rank-0 process of each node."""
    return get_local_rank() == 0


def get_local_proc_group(group_size=8):
    """Return the process group of ``group_size`` consecutive ranks containing this rank.

    Returns None when grouping is pointless (world fits in one group, or
    group_size == 1). Groups are created once and memoized on the function
    object (see continuation below).
    """
    world_size = get_world_size()
    if world_size <= group_size or group_size == 1:
        return None
    assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).'
    # memoize created groups on the function object so new_group is called once per size
    process_groups = getattr(get_local_proc_group, 'process_groups', {})
    if group_size not in process_groups:
        num_groups = dist.get_world_size() // group_size
        groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
        process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
        get_local_proc_group.process_groups = process_groups

    group_idx = get_rank() // group_size
    return get_local_proc_group.process_groups.get(group_size)[group_idx]


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    to_device = torch.device("cuda")
    # to_device = torch.device("cpu")

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(to_device)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to(to_device)
    size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensors from all ranks must have the same (max) size,
    # so pad the local tensor before gathering
    tensor_list = [
        torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size -
local_size,)).to(to_device) 109 | tensor = torch.cat((tensor, padding), dim=0) 110 | dist.all_gather(tensor_list, tensor) 111 | 112 | data_list = [] 113 | for size, tensor in zip(size_list, tensor_list): 114 | buffer = tensor.cpu().numpy().tobytes()[:size] 115 | data_list.append(pickle.loads(buffer)) 116 | 117 | return data_list 118 | 119 | 120 | def reduce_dict(input_dict, average=True): 121 | """ 122 | Args: 123 | input_dict (dict): all the values will be reduced 124 | average (bool): whether to do average or sum 125 | Reduce the values in the dictionary from all processes so that process with rank 126 | 0 has the averaged results. Returns a dict with the same fields as 127 | input_dict, after reduction. 128 | """ 129 | world_size = get_world_size() 130 | if world_size < 2: 131 | return input_dict 132 | with torch.no_grad(): 133 | reduced_dict = _extracted_from_reduce_dict_14(input_dict, average, world_size) 134 | return reduced_dict 135 | 136 | 137 | # TODO Rename this here and in `reduce_dict` 138 | def _extracted_from_reduce_dict_14(input_dict, average, world_size): 139 | names = [] 140 | values = [] 141 | # sort the keys so that they are consistent across processes 142 | for k in sorted(input_dict.keys()): 143 | names.append(k) 144 | values.append(input_dict[k]) 145 | values = torch.stack(values, dim=0) 146 | dist.reduce(values, dst=0) 147 | if dist.get_rank() == 0 and average: 148 | # only main process gets accumulated, so only divide by 149 | # world_size in this case 150 | values /= world_size 151 | return dict(zip(names, values)) 152 | 153 | 154 | def broadcast(data, **kwargs): 155 | if get_world_size() == 1: 156 | return data 157 | data = [data] 158 | dist.broadcast_object_list(data, **kwargs) 159 | return data[0] 160 | 161 | 162 | def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True): 163 | rank, world_size = get_dist_info() 164 | if tmpdir is None: 165 | tmpdir = './tmp' 166 | if rank == 0: 167 | mmcv.mkdir_or_exist(tmpdir) 168 | 
    synchronize()
    # dump the part result to the dir
    mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
    synchronize()
    if collect_by_master and rank != 0:
        # non-master ranks contribute their part and return nothing
        return None
    # load results of all parts from tmp dir
    results = []
    for i in range(world_size):
        part_file = os.path.join(tmpdir, f'part_{i}.pkl')
        results.append(mmcv.load(part_file))
    if not collect_by_master:
        # everyone read the files; barrier before rank 0 deletes them
        synchronize()
    # remove tmp dir
    if rank == 0:
        shutil.rmtree(tmpdir)
    return results

def all_gather_tensor(tensor, group_size=None, group=None):
    """All-gather a tensor across ``group`` and return the list of per-rank tensors."""
    if group_size is None:
        group_size = get_world_size()
    if group_size == 1:
        output = [tensor]
    else:
        output = [torch.zeros_like(tensor) for _ in range(group_size)]
        dist.all_gather(output, tensor, group=group)
    return output


def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
    """Gather tensors whose first dimension differs per rank.

    Each rank zero-pads its tensor to ``max(num_samples_list)`` rows before
    gathering, then the padding is trimmed with each rank's true row count.
    Returns a single concatenated tensor when ``concat`` else a list.
    """
    world_size = get_world_size()
    if world_size == 1:
        return feat if concat else [feat]
    num_samples, *feat_dim = feat.size()
    # padding to max number of samples
    feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
    feat_padding[:num_samples] = feat
    # gather
    feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
    for r, num in enumerate(num_samples_list):
        feat_gather[r] = feat_gather[r][:num]
    if concat:
        feat_gather = torch.cat(feat_gather)
    return feat_gather


class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all process, supporting backward propagation.
216 | ''' 217 | 218 | @staticmethod 219 | def forward(ctx, input): 220 | ctx.save_for_backward(input) 221 | num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device) 222 | ctx.num_samples_list = all_gather_tensor(num_samples) 223 | output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False) 224 | return tuple(output) 225 | 226 | @staticmethod 227 | def backward(ctx, *grads): # tuple(output)'s grad 228 | input, = ctx.saved_tensors 229 | num_samples_list = ctx.num_samples_list 230 | rank = get_rank() 231 | start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1]) 232 | grads = torch.cat(grads) 233 | if is_distributed(): 234 | dist.all_reduce(grads) 235 | grad_out = torch.zeros_like(input) 236 | grad_out[:] = grads[start:end] 237 | return grad_out, None, None 238 | 239 | 240 | class GatherLayerWithGroup(torch.autograd.Function): 241 | '''Gather tensors from all process, supporting backward propagation. 242 | ''' 243 | 244 | @staticmethod 245 | def forward(ctx, input, group, group_size): 246 | ctx.save_for_backward(input) 247 | ctx.group_size = group_size 248 | output = all_gather_tensor(input, group=group, group_size=group_size) 249 | return tuple(output) 250 | 251 | @staticmethod 252 | def backward(ctx, *grads): # tuple(output)'s grad 253 | input, = ctx.saved_tensors 254 | grads = torch.stack(grads) 255 | if is_distributed(): 256 | dist.all_reduce(grads) 257 | grad_out = torch.zeros_like(input) 258 | grad_out[:] = grads[get_rank() % ctx.group_size] 259 | return grad_out, None, None 260 | 261 | 262 | def gather_layer_with_group(data, group=None, group_size=None): 263 | if group_size is None: 264 | group_size = get_world_size() 265 | return GatherLayer.apply(data, group, group_size) 266 | 267 | from typing import Union 268 | import math 269 | # from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm 270 | 271 | @torch.no_grad() 272 | def clip_grad_norm_( 273 | self, max_norm: 
Union[float, int], norm_type: Union[float, int] = 2.0
) -> None:
    # NOTE(review): TrainingState_ and _calc_grad_norm come from the FSDP import
    # that is commented out above — calling this function as-is raises NameError.
    # Presumably meant to be monkey-patched onto an FSDP root module; verify.
    self._lazy_init()
    self._wait_for_previous_optim_step()
    assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
    self._assert_state(TrainingState_.IDLE)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # Computes the max norm for this shard's gradients and sync's across workers
    local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda()  # type: ignore[arg-type]
    if norm_type == math.inf:
        # inf-norm: the global norm is the max over shards
        total_norm = local_norm
        dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
    else:
        # p-norm: sum the p-th powers across shards, then take the p-th root
        total_norm = local_norm ** norm_type
        dist.all_reduce(total_norm, group=self.process_group)
        total_norm = total_norm ** (1.0 / norm_type)

    clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
    if clip_coef < 1:
        # multiply by clip_coef, aka, (max_norm/total_norm).
        for p in self.params_with_grad:
            assert p.grad is not None
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm


def flush():
    """Force a GC pass and release cached CUDA memory back to the driver."""
    gc.collect()
    torch.cuda.empty_cache()
-------------------------------------------------------------------------------- /diffusion/utils/logger.py: --------------------------------------------------------------------------------
import logging
import os
import torch.distributed as dist
from datetime import datetime
from .dist_utils import is_local_master
# from mmcv.utils.logging import logger_initialized


def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
    """Get root logger.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.
        name (str): logger name
    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    if log_file is None:
        # no file requested: still hand get_logger a path, but a discarding one
        log_file = '/dev/null'
    return get_logger(name=name, log_file=log_file, log_level=log_level)


def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    logger.propagate = False  # disable root logger to avoid duplicate logging

    # if name in logger_initialized:
    #     return logger
    # # handle hierarchical names
    # # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # # initialization since it is a child of "a".
    # for logger_name in logger_initialized:
    #     if name.startswith(logger_name):
    #         return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # only rank0 for each node will print logs
    log_level = log_level if is_local_master() else logging.ERROR
    logger.setLevel(log_level)

    # logger_initialized[name] = True

    return logger

def rename_file_with_creation_time(file_path):
    """Rename ``file_path`` in place, appending its creation timestamp to the stem."""
    # get the file's creation time
    creation_time = os.path.getctime(file_path)
    creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')

    # build the new file name: <stem>_<YYYY-mm-dd_HH-MM-SS><ext>
    dir_name, file_name = os.path.split(file_path)
    name, ext = os.path.splitext(file_name)
    new_file_name = f"{name}_{creation_time_str}{ext}"
    new_file_path = os.path.join(dir_name, new_file_name)

    # rename the file
    os.rename(file_path, new_file_path)
    print(f"File renamed to: {new_file_path}")
-------------------------------------------------------------------------------- /diffusion/utils/lr_scheduler.py: --------------------------------------------------------------------------------
from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
import math

from diffusion.utils.logger import get_root_logger


def build_lr_scheduler(config, optimizer, train_dataloader,
                       lr_scale_ratio):
    """Build an LR scheduler ('cosine', 'constant' or 'cosine_decay_to_constant') from config."""
    if not config.get('lr_schedule_args', None):
        config.lr_schedule_args = {}
    if config.get('lr_warmup_steps', None):
        # NOTE(review): this writes num_warmup_steps onto `config`, but the
        # schedulers below only receive **config.lr_schedule_args — verify the
        # warmup value actually reaches them.
        config['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version

    logger = get_root_logger()
    logger.info(
        f'Lr schedule: {config.lr_schedule}, ' + ",".join(
            [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
    if config.lr_schedule == 'cosine':
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    elif config.lr_schedule == 'constant':
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
        )
    elif config.lr_schedule == 'cosine_decay_to_constant':
        assert lr_scale_ratio >= 1
        lr_scheduler = get_cosine_decay_to_constant_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            final_lr=1 / lr_scale_ratio,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    else:
        raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
    return lr_scheduler


def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
                                             num_warmup_steps: int,
                                             num_training_steps: int,
                                             final_lr: float = 0.0,
                                             num_decay: float = 0.667,
                                             num_cycles: float = 0.5,
                                             last_epoch: int = -1
                                             ):
    """
    Create a schedule with a cosine annealing lr followed by a constant lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The number of total training steps.
        final_lr (`int`):
            The final constant lr after cosine decay.
        num_decay (`int`):
            The fraction of ``num_training_steps`` spent in cosine decay before
            the lr is held constant at ``final_lr``.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # linear warmup from 0 to 1
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # after the decay window, hold at final_lr
        num_decay_steps = int(num_training_steps * num_decay)
        if current_step > num_decay_steps:
            return final_lr

        # cosine from 1 down to final_lr over [num_warmup_steps, num_decay_steps]
        progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
        return (
            max(
                0.0,
                0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
            )
            * (1 - final_lr)
        ) + final_lr

    return LambdaLR(optimizer, lr_lambda, last_epoch)
-------------------------------------------------------------------------------- /diffusion/utils/optimizer.py: --------------------------------------------------------------------------------
import math

from mmcv import Config
from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
    OPTIMIZERS
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm

from .logger import get_root_logger

from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer


def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
    """Scale optimizer_cfg['lr'] in place by effective batch size; returns the scale ratio."""
    assert rule in ['linear', 'sqrt']
    logger = get_root_logger()
    # scale by world size
    if rule == 'sqrt':
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == 'linear':
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg['lr'] *= scale_ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
    return
scale_ratio


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    # Custom optimizer constructor: adds per-parameter lr/weight-decay rules,
    # including (scope, key) tuple custom_keys (see build_optimizer below).

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module

        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            # bias_lr_mult applies to plain biases only (not norm-layer biases)
            if name == 'bias' and not is_norm and not is_dcn_module:
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            # norm decay
            if is_norm:
                if self.base_wd is not None:
                    base_wd *= norm_decay_mult
            elif name == 'bias' and not is_dcn_module:
                if self.base_wd is not None:
                    # TODO: current bias_decay_mult will have affect on DCN
                    base_wd *= bias_decay_mult

            param_group = {'params': [param]}
            if not param.requires_grad:
                # frozen parameter: record it but apply no lr/wd rules
                param_group['requires_grad'] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warn(f'{prefix} is duplicate.
It is skipped since '
                            f'bypass_duplicate={bypass_duplicate}')
                continue
            # if the parameter match one of the custom keys, ignore other rules
            is_custom = False
            for key in custom_keys:
                # a key may be a bare name or a (scope, name) tuple restricting the module prefix
                scope, key_name = key if isinstance(key, tuple) else (None, key)
                if scope is not None and scope not in f'{prefix}':
                    continue
                if key_name in f'{prefix}.{name}':
                    is_custom = True
                    if 'lr_mult' in custom_keys[key]:
                        # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
                        #     param_group['lr'] = self.base_lr
                        # else:
                        param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
                    elif 'lr' not in param_group:
                        param_group['lr'] = base_lr
                    if self.base_wd is not None:
                        if 'decay_mult' in custom_keys[key]:
                            param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
                        elif 'weight_decay' not in param_group:
                            param_group['weight_decay'] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group['lr'] = base_lr
                if base_wd != self.base_wd:
                    param_group['weight_decay'] = base_wd
            params.append(param_group)

        # recurse into children with a dotted prefix
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)


def build_optimizer(model, optimizer_cfg):
    """Build an optimizer via mmcv, zeroing weight decay for modules that declare
    ``zero_weight_decay``, then log a summary of the resulting param groups."""
    # default parameter-wise config
    logger = get_root_logger()

    if hasattr(model, 'module'):
        # unwrap DataParallel/DistributedDataParallel
        model = model.module
    # set optimizer constructor
    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
    # parameter-wise setting: cancel weight decay for some specific modules
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, 'zero_weight_decay'):
            custom_keys |= {
                (name, key): dict(decay_mult=0)
                for key in module.zero_weight_decay
            }

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    if given_cfg := optimizer_cfg.get('paramwise_cfg'):
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
    # build optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)

    # group param groups by lr and by weight decay for the summary log
    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get('requires_grad', True): continue
        lr_groups.setdefault(group['lr'], []).append(group)
        weight_decay_groups.setdefault(group['weight_decay'], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1
    fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
    lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
    opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
    logger.info(opt_info)

    return optimizer


@OPTIMIZERS.register_module()
class Lion(Optimizer):
    # Lion optimizer (sign-of-interpolated-momentum update with decoupled
    # weight decay); registered with mmcv so it can be built from config.
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.
        assert all(0. <= beta <= 1.
                   for beta in betas)

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)

        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        """Apply one Lion update to parameter ``p`` in place."""
        # stepweight decay
        p.data.mul_(1 - lr * wd)

        # weight update
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        # helper: "is not None" predicate
        return val is not None

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        """Perform a single optimization step; re-evaluates ``closure`` with grads enabled if given."""

        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
                    self.state[p]

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
-------------------------------------------------------------------------------- /output/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/result1.png -------------------------------------------------------------------------------- /output/style1_gufeng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style1_gufeng.png
-------------------------------------------------------------------------------- /output/style2_oil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style2_oil.png -------------------------------------------------------------------------------- /output/style3_gufeng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style3_gufeng.png -------------------------------------------------------------------------------- /output/style3_oil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style3_oil.png -------------------------------------------------------------------------------- /output/style3_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style3_paint.png -------------------------------------------------------------------------------- /output/style4_3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/output/style4_3d.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 2 | torchvision==0.16.1 4 | mmcv==1.7.0 5 | git+https://github.com/huggingface/diffusers 6 | timm==0.6.12 7 | accelerate 8 | tensorboard 9 | tensorboardX 10 | transformers==4.45.2 11 | ftfy 12 | beautifulsoup4 13 | protobuf==3.20.2 14 | gradio==4.1.1 15 
| yapf==0.40.1 16 | opencv-python 17 | bs4 18 | einops 19 | xformers==0.0.23 20 | optimum 21 | peft==0.6.2 22 | sentencepiece==0.2.0 -------------------------------------------------------------------------------- /resources/demos/prompts_examples.md: -------------------------------------------------------------------------------- 1 | ## 📌 Prompt Examples for RelaCtrl 2 | 3 | This section provides example prompts for using the RelaCtrl model under different control modes. 4 | 5 | ### 🔧 Conditional Control Mode 6 | For the standard conditional control model (**RelaCtrl_PixArt_Canny**), the demo image and corresponding prompt examples are shown below: 7 | ``` 8 | 1.example1.png 9 | a large, well-maintained estate with a red brick driveway and a beautifully landscaped yard. The property is surrounded by a forest, giving it a serene and peaceful atmosphere. The house is situated in a neighborhood with other homes nearby, creating a sense of community. In the yard, there are several potted plants, adding to the lush greenery of the area. A bench is also present, providing a place for relaxation and enjoyment of the surroundings. The overall scene is picturesque and inviting, making it an ideal location for a family home. 10 | 11 | 2.example2.png 12 | a white dress displayed on a mannequin, showcasing its elegant design. The dress is a short, white, and lacy dress, with a fitted waist and a full skirt. The mannequin is positioned in the center of the scene, showcasing the dress's style and fit. The dress appears to be a wedding dress, as it is white and has a classic, elegant appearance. 13 | 14 | 3.example3.png 15 | a beautiful blue bird perched on a branch, surrounded by a lush green field. The bird is positioned in the center of the scene, with its wings spread wide, showcasing its vibrant blue feathers. The branch it is perched on is filled with pink flowers, adding a touch of color to the scene. 
The bird appears to be enjoying its time in the serene environment, surrounded by the natural beauty of the field and flowers. 16 | ``` 17 | 18 | ### 🎨 Style Control Mode 19 | For the style-guided control model (**RelaCtrl_PixArt_Canny_Style**), the demo image and prompt examples are provided below: 20 | ``` 21 | 1.style1.png 22 | gufeng_A tranquil mountain range with snow-capped peaks and their clear reflection in a calm lake, surrounded by trees, creates a stunning, serene landscape. 23 | 24 | 2.style2.png 25 | oil_A man stands in a moonlit snowy field with a scythe, gazing at the moon, amidst trees, exuding mystery. 26 | 27 | 3.style3.png 28 | paint_A vintage car drives down a dirt road, dusting up as it passes, center stage with two people observing on the sides, evoking nostalgia. 29 | 30 | 4.style4.png 31 | 3d_The painting is a close-up portrait of a bearded man with a mustache, focusing on his facial features in a blurred background. 32 | ``` 33 | 34 | Note: When using a style image for inference, you must prepend a **style_** annotation before the actual prompt. 
35 | Available style options include: 36 | 37 | - gufeng 38 | 39 | - 3d 40 | 41 | - paint 42 | 43 | - oil 44 | 45 | -------------------------------------------------------------------------------- /resources/demos/reference_images/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/example1.png -------------------------------------------------------------------------------- /resources/demos/reference_images/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/example2.png -------------------------------------------------------------------------------- /resources/demos/reference_images/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/example3.png -------------------------------------------------------------------------------- /resources/demos/reference_images/style1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/style1.png -------------------------------------------------------------------------------- /resources/demos/reference_images/style2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/style2.png -------------------------------------------------------------------------------- /resources/demos/reference_images/style3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/style3.png -------------------------------------------------------------------------------- /resources/demos/reference_images/style4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/360CVGroup/RelaCtrl/88b81c255cd50398ba9ea0ccab4bf703955bbe48/resources/demos/reference_images/style4.png --------------------------------------------------------------------------------