├── layout_planner
│ ├── training
│ │ ├── __init__.py
│ │ ├── datasets
│ │ │ ├── __init__.py
│ │ │ ├── layout_dataset.py
│ │ │ ├── layoutnew_dataset.py
│ │ │ └── quantizer.py
│ │ ├── utils.py
│ │ └── trainer_layout.py
│ ├── requirements_part1.txt
│ ├── requirements_part2.txt
│ ├── scripts
│ │ └── inference_template.sh
│ ├── config
│ │ └── 13b_z3_no_offload.json
│ ├── models
│ │ └── modeling_layout.py
│ └── inference_layout.py
├── assets
│ ├── teaser.png
│ ├── natural_image_data
│ │ ├── case_1_seed_0
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_1_seed_1
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_1_seed_8
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_2_seed_0
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_2_seed_3
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_2_seed_6
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_3_seed_0
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_3_seed_1
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_3_seed_2
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_4_seed_0
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ ├── case_4_seed_3
│ │ │ ├── bg.png
│ │ │ ├── whole.png
│ │ │ ├── layer_0.png
│ │ │ ├── layer_1.png
│ │ │ ├── layer_2.png
│ │ │ ├── layer_3.png
│ │ │ └── composite.png
│ │ └── case_4_seed_9
│ │   ├── bg.png
│ │   ├── whole.png
│ │   ├── layer_0.png
│ │   ├── layer_1.png
│ │   ├── layer_2.png
│ │   ├── layer_3.png
│ │   └── composite.png
│ └── semi_transparent_image_data
│   ├── bg
│   │ ├── case_1_output_1.png
│   │ ├── case_2_output_0.png
│   │ ├── case_2_output_2.png
│   │ ├── case_3_output_0.png
│   │ ├── case_4_output_2.png
│   │ ├── case_5_output_3.png
│   │ ├── case_6_output_3.png
│   │ ├── case_7_output_3.png
│   │ ├── case_8_output_2.png
│   │ ├── case_9_output_1.png
│   │ ├── case_10_output_1.png
│   │ └── case_11_output_1.png
│   ├── fg
│   │ ├── case_1_output_1.png
│   │ ├── case_2_output_0.png
│   │ ├── case_2_output_2.png
│   │ ├── case_3_output_0.png
│   │ ├── case_4_output_2.png
│   │ ├── case_5_output_3.png
│   │ ├── case_6_output_3.png
│   │ ├── case_7_output_3.png
│   │ ├── case_8_output_2.png
│   │ ├── case_9_output_1.png
│   │ ├── case_10_output_1.png
│   │ └── case_11_output_1.png
│   └── merged
│     ├── case_10_output_1.png
│     ├── case_11_output_1.png
│     ├── case_1_output_1.png
│     ├── case_2_output_0.png
│     ├── case_2_output_2.png
│     ├── case_3_output_0.png
│     ├── case_4_output_2.png
│     ├── case_5_output_3.png
│     ├── case_6_output_3.png
│     ├── case_7_output_3.png
│     ├── case_8_output_2.png
│     └── case_9_output_1.png
├── CODE_OF_CONDUCT.md
├── multi_layer_gen
│ ├── configs
│ │ ├── base.py
│ │ ├── multi_layer_resolution1024_test.py
│ │ └── multi_layer_resolution512_test.py
│ ├── test.py
│ ├── custom_model_transp_vae.py
│ ├── custom_model_mmdit.py
│ └── custom_pipeline.py
├── LICENSE
├── SUPPORT.md
├── README.md
├── SECURITY.md
├── example.py
└── .gitignore
/layout_planner/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/layout_planner/training/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/bg.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/whole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/whole.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_0.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_1.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_2.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_3.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_0/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_1/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_1_seed_8/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_0/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_3/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_2_seed_6/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_0/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_1/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_3_seed_2/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_0/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_3/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/composite.png
--------------------------------------------------------------------------------
/assets/natural_image_data/case_4_seed_9/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/composite.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_1_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_1_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_2_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_2_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_2_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_2_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_3_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_3_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_4_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_4_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_5_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_5_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_6_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_6_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_7_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_7_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_8_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_8_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_9_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_9_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_1_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_1_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_2_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_2_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_2_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_2_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_3_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_3_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_4_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_4_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_5_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_5_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_6_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_6_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_7_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_7_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_8_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_8_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_9_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_9_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_10_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_10_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/bg/case_11_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_11_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_10_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_10_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/fg/case_11_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_11_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_10_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_10_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_11_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_11_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_1_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_1_output_1.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_2_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_2_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_2_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_2_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_3_output_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_3_output_0.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_4_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_4_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_5_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_5_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_6_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_6_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_7_output_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_7_output_3.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_8_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_8_output_2.png
--------------------------------------------------------------------------------
/assets/semi_transparent_image_data/merged/case_9_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_9_output_1.png
--------------------------------------------------------------------------------
/layout_planner/requirements_part1.txt:
--------------------------------------------------------------------------------
1 | gradio
2 | matplotlib
3 | diffusers
4 | accelerate==0.27.2
5 | warmup_scheduler
6 | webdataset
7 | tensorboardx
8 | protobuf==3.20.0
9 | tensorboard
10 | torchmetrics
11 | clean-fid
12 | pynvml
13 | open_clip_torch
14 | datasets
15 | skia-python
16 | colorama
17 | opencv-python
18 | deepspeed==0.14.2
--------------------------------------------------------------------------------
/layout_planner/requirements_part2.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.27.2
2 | click>=8.0.4,<9
3 | datasets>=2.10.0,<3
4 | transformers==4.39.1
5 | langchain>=0.0.139
6 | wandb
7 | ninja
8 | tensorboard
9 | tensorboardx
10 | peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
11 | webdataset
12 | tiktoken
13 | einops
14 | bitsandbytes
15 | evaluate
16 | bert_score
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/multi_layer_gen/configs/base.py:
--------------------------------------------------------------------------------
1 | ### Model Settings
2 | pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
3 | revision = None
4 | variant = None
5 | cache_dir = None
6 |
7 | ### Training Settings
8 | seed = 42
9 | report_to = "wandb"
10 | tracker_project_name = "multilayer"
11 | wandb_job_name = "YOU_FORGET_TO_SET"
12 | logging_dir = "logs"
13 | max_train_steps = None
14 | checkpoints_total_limit = None
15 |
16 | # gpu
17 | allow_tf32 = True
18 | gradient_checkpointing = True
19 | mixed_precision = "bf16"
20 |
--------------------------------------------------------------------------------
/layout_planner/scripts/inference_template.sh:
--------------------------------------------------------------------------------
1 | python inference_layout.py \
2 | --input_model "Your base model path" \
3 | --resume "Your checkpoint path" \
4 | --width Your_data_width --height Your_data_height \
5 | --save_path "Your output save path" \
6 | --do_sample Whether_do_sample_when_decoding \
7 | --temperature The_temperature_when_decoding \
8 | --inference_caption "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
9 |
10 |
--------------------------------------------------------------------------------
/layout_planner/training/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, padding, topk=(1,)):
5 | """Computes the accuracy over the k top predictions for the specified values of k"""
6 | with torch.no_grad():
7 | maxk = max(topk)
8 | if output.shape[-1] < maxk:
9 | print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
10 |
11 | maxk = min(maxk, output.shape[-1])
12 | batch_size = target.size(0)
13 |
14 | # Take topk along the last dimension.
15 | _, pred = output.topk(maxk, -1, True, True) # (N, T, topk)
16 |
17 | mask = (target != padding).type(target.dtype)
18 | target_expand = target[..., None].expand_as(pred)
19 | correct = pred.eq(target_expand)
20 | correct = correct * mask[..., None].expand_as(correct)
21 |
22 | res = []
23 | for k in topk:
24 | correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
25 | res.append(correct_k.mul_(100.0 / mask.sum()))
26 | return res
27 |
--------------------------------------------------------------------------------
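
Note: a minimal usage sketch for the accuracy() helper above. The tensor shapes, the padding id of 0, and the import path (running from the layout_planner directory) are illustrative assumptions, not taken from the training code.

import torch

from training.utils import accuracy  # assumes layout_planner/ is the working directory

# Fake batch: 2 sequences, 5 time steps, 10-token vocabulary.
logits = torch.randn(2, 5, 10)          # (N, T, vocab) unnormalized scores
targets = torch.randint(1, 10, (2, 5))  # token ids; 0 is treated as padding here
targets[0, -2:] = 0                     # pretend the first sequence is padded

# Top-1 / top-5 accuracy over non-padding positions, returned as percentages.
top1, top5 = accuracy(logits, targets, padding=0, topk=(1, 5))
print(top1.item(), top5.item())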
/layout_planner/training/datasets/layout_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 |
4 |
5 | class LayoutDataset(torch.utils.data.Dataset):
6 | def __init__(self, layout_data_path, split_name):
7 | self.split_name = split_name
8 | self.layout_data_path = layout_data_path
9 | with open(self.layout_data_path, 'r') as file:
10 | self.json_datas = json.load(file)
11 |
12 | self.ids = [x for x in range(len(self.json_datas))]
13 | if self.split_name == 'train':
14 | self.ids = self.ids[:-200]
15 | self.json_datas = self.json_datas[:-200]
16 | elif self.split_name == 'val':
17 | self.ids = self.ids[-200:-100]
18 | self.json_datas = self.json_datas[-200:-100]
19 | elif self.split_name == 'test':
20 | self.ids = self.ids[-100:]
21 | self.json_datas = self.json_datas[-100:]
22 | else:
23 | raise NotImplementedError
24 |
25 | def __getitem__(self, idx):
26 | return self.json_datas[idx]['train']
27 |
28 | def __len__(self):
29 | return len(self.json_datas)
--------------------------------------------------------------------------------
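
Note: a hedged sketch of how LayoutDataset might be instantiated. The JSON path, batch size, and collate behavior are placeholders rather than values taken from trainer_layout.py.

from torch.utils.data import DataLoader

from training.datasets.layout_dataset import LayoutDataset  # assumes layout_planner/ as working directory

# The dataset expects one JSON list; each element carries a prepared sample
# under the 'train' key. The last 200 entries are reserved: 100 for val, 100 for test.
train_set = LayoutDataset("path/to/layout_data.json", split_name="train")
val_set = LayoutDataset("path/to/layout_data.json", split_name="val")

# __getitem__ returns the raw 'train' field, so a tokenizing collate_fn would
# normally be supplied; the default collate is used here only for illustration.
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)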
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
22 |
--------------------------------------------------------------------------------
/layout_planner/config/13b_z3_no_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupLR",
16 | "params": {
17 | "warmup_min_lr": "auto",
18 | "warmup_max_lr": "auto",
19 | "warmup_num_steps": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 | "gradient_accumulation_steps": "auto",
35 | "gradient_clipping": "auto",
36 | "steps_per_print": 2000,
37 | "train_batch_size": "auto",
38 | "train_micro_batch_size_per_gpu": "auto",
39 | "wall_clock_breakdown": false
40 | }
--------------------------------------------------------------------------------
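
Note: the "auto" values in this ZeRO-3 config follow the Hugging Face Trainer/DeepSpeed integration, which resolves them from TrainingArguments at launch time. Below is a minimal sketch of that wiring with placeholder hyperparameters; the actual training entry point is not shown in this dump.

from transformers import TrainingArguments

# "auto" entries (lr, betas, eps, bf16, batch sizes, gradient clipping, ...)
# are filled in from these arguments by the Trainer's DeepSpeed integration.
args = TrainingArguments(
    output_dir="./ckpt",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    bf16=True,
    deepspeed="config/13b_z3_no_offload.json",
)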
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/multi_layer_gen/configs/multi_layer_resolution1024_test.py:
--------------------------------------------------------------------------------
1 | _base_ = "./base.py"
2 |
3 | ### path & device settings
4 |
5 | output_path_base = "./output/"
6 | cache_dir = None
7 |
8 |
9 | ### wandb settings
10 | wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}'
11 |
12 | resolution = 1024
13 |
14 | ### Model Settings
15 | rank = 64
16 | text_encoder_rank = 64
17 | train_text_encoder = False
18 | max_layer_num = 50 + 2
19 | learnable_proj = True
20 |
21 | ### Training Settings
22 | weighting_scheme = "none"
23 | logit_mean = 0.0
24 | logit_std = 1.0
25 | mode_scale = 1.29
26 | guidance_scale = 1.0 ###IMPORTANT
27 | layer_weighting = 5.0
28 |
29 | # steps
30 | train_batch_size = 1
31 | num_train_epochs = 1
32 | max_train_steps = None
33 | checkpointing_steps = 2000
34 | resume_from_checkpoint = "latest"
35 | gradient_accumulation_steps = 1
36 |
37 | # lr
38 | optimizer = "prodigy"
39 | learning_rate = 1.0
40 | scale_lr = False
41 | lr_scheduler = "constant"
42 | lr_warmup_steps = 0
43 | lr_num_cycles = 1
44 | lr_power = 1.0
45 |
46 | # optim
47 | adam_beta1 = 0.9
48 | adam_beta2 = 0.999
49 | adam_weight_decay = 1e-3
50 | adam_epsilon = 1e-8
51 | prodigy_beta3 = None
52 | prodigy_decouple = True
53 | prodigy_use_bias_correction = True
54 | prodigy_safeguard_warmup = True
55 | max_grad_norm = 1.0
56 |
57 | # logging
58 | tracker_task_name = '{{fileBasenameNoExtension}}'
59 | output_dir = output_path_base + "{{fileBasenameNoExtension}}"
60 |
61 | ### Validation Settings
62 | num_validation_images = 1
63 | validation_steps = 2000
--------------------------------------------------------------------------------
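
Note: the _base_ = "./base.py" inheritance and the {{fileBasenameNoExtension}} placeholders follow the MMEngine-style Python config convention. Whether this repository actually loads the configs through MMEngine is not shown here, so the snippet below is only an illustrative sketch under that assumption.

from mmengine.config import Config  # assumption: an MMEngine-compatible loader is available

cfg = Config.fromfile("multi_layer_gen/configs/multi_layer_resolution1024_test.py")

# Fields from base.py are merged in, and {{fileBasenameNoExtension}} expands to
# "multi_layer_resolution1024_test" during loading.
print(cfg.resolution, cfg.max_layer_num, cfg.output_dir)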
/README.md:
--------------------------------------------------------------------------------
1 | # ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation
2 |
3 |
4 |
5 |
6 | ![teaser](assets/teaser.png)
7 |
8 | This repository supports [generating multi-layer transparent images](#multi-layer-generation) (constructed with multiple RGBA image layers) based on a global text prompt and an anonymous region layout (bounding boxes without layer captions). The anonymous region layout can be either [predicted by LLM](#llm-for-layout-planning) or manually specified by users.
9 |
10 |
11 | ## 🌟 Features
12 | - **Anonymous Layout**: Requires only a single global caption to generate multiple layers, eliminating the need for individual captions for each layer.
13 | - **High Layer Capacity**: Supports the generation of 50+ layers, enabling complex multi-layer outputs.
14 | - **Efficiency**: Maintains high efficiency compared to full attention and spatial-temporal attention mechanisms.
15 |
16 | ## 🛑 Important Notice (updated 2025/07/23)
17 |
18 | This repository previously contained code and pretrained model weights for generating multi-layer transparent images using a global text prompt and anonymous region layout. However, since the model was trained using data that may have come from illegal sources, we have removed the model weights and inference checkpoints from this repository, along with all associated download links. If you have any questions, please contact the original authors through official channels.
19 |
20 | As a result:
21 |
22 | - The pretrained models and associated checkpoints are no longer available for download.
23 | - No runnable or usable code for inference or training is provided.
24 | - We do not provide any means to use or reproduce the model at this time.
25 | - The training code, which relied on this data, has also been removed.
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/layout_planner/training/trainer_layout.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from queue import Queue
3 | from collections import defaultdict
4 |
5 | import torch
6 | from transformers import Trainer
7 |
8 | from .utils import accuracy
9 |
10 |
11 | def l2_loss(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
12 | """
13 | Args:
14 | u: (N, T, D) tensor.
15 | v: (N, T, D) tensor.
16 | Returns:
17 |         l2_loss: (N, T) tensor of L2 (Euclidean) distances between u and v.
18 | """
19 | assert u.shape == v.shape, (u.shape, v.shape)
20 | return ((u - v) ** 2).sum(dim=-1) ** 0.5
21 |
22 |
23 | def batch_purity(input_ids, tokenizer):
24 | strings = tokenizer.batch_decode(input_ids)
25 | strings = [
26 | string.replace(" ", "").replace('> "', '>"').replace("", "")
27 | for string in strings
28 | ]
29 | return strings
30 |
31 |
32 | class Meter:
33 | def __init__(self,size):
34 | self.size = size
35 | self.reset()
36 |
37 | def reset(self):
38 | self.bins = defaultdict(Queue)
39 |
40 | def update(self,metrics):
41 | for k,v in metrics.items():
42 | self.bins[k].put(v)
43 |
44 | def get(self):
45 | metrics = {}
46 | for k,v in self.bins.items():
47 | metrics[k] = np.mean(list(v.queue))
48 | return metrics
49 |
50 |
51 | class TrainerLayout(Trainer):
52 | def __init__(self, extra_args, **kwargs):
53 | self.quantizer = kwargs.pop("quantizer", None)
54 | super().__init__(**kwargs)
55 | self.extra_args = extra_args
56 | weight = torch.ones(self.extra_args.vocab_size)
57 |         weight[self.extra_args.old_vocab_size :] = self.extra_args.new_token_weight  # reweight the newly added layout tokens in the CE loss
58 | self.weighted_ce_loss = torch.nn.CrossEntropyLoss(weight=weight).cuda()
59 | if 'opt-' in self.extra_args.opt_version:
60 | if self.args.fp16:
61 | self.weighted_ce_loss = self.weighted_ce_loss.half()
62 | elif self.args.bf16:
63 | self.weighted_ce_loss = self.weighted_ce_loss.bfloat16()
64 |
65 | self.meter = Meter(self.args.logging_steps)
66 |
67 | def compute_loss(self, model, inputs, return_outputs=False):
68 | labels = inputs.pop("labels").long()
69 |
70 | (
71 | model_output,
72 | full_labels,
73 | input_embs_norm,
74 | ) = model(labels=labels)
75 |
76 | output = model_output.logits
77 |
78 | masked_full_labels = full_labels.clone()
79 |         masked_full_labels[masked_full_labels < self.extra_args.old_vocab_size] = -100  # score only the newly added layout tokens in the masked accuracy metrics
80 |
81 | weighted_ce_loss = self.weighted_ce_loss(
82 | output[:, :-1, :].reshape(-1, output.shape[-1]),
83 | full_labels[:, 1:].reshape(-1),
84 | )
85 | ce_loss = weighted_ce_loss * self.extra_args.ce_loss_scale
86 |
87 | loss = ce_loss
88 | acc1, acc5 = accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5))
89 | masked_acc1, masked_acc5 = accuracy(
90 | output[:, :-1, :], masked_full_labels[:, 1:], -100, topk=(1, 5)
91 | )
92 |
93 | metrics = {
94 | "loss": loss.item(),
95 | "ce_loss": ce_loss.item(),
96 | "top1": float(acc1),
97 | "top5": float(acc5),
98 | "masked_top1": float(masked_acc1),
99 | "masked_top5": float(masked_acc5),
100 | "inp_emb_norm": input_embs_norm.item(),
101 | }
102 | self.meter.update(metrics)
103 | if self.state.global_step % self.args.logging_steps == 0:
104 | metrics = self.meter.get()
105 | self.meter.reset()
106 | self.log(metrics)
107 |
108 | outputs = model_output
109 |
110 | return (loss, outputs) if return_outputs else loss
111 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import argparse
5 | from PIL import Image
6 | from multi_layer_gen.custom_model_mmdit import CustomFluxTransformer2DModel
7 | from multi_layer_gen.custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
8 | from multi_layer_gen.custom_pipeline import CustomFluxPipelineCfg
9 |
10 | def test_sample(pipeline, transp_vae, batch, args):
11 |
12 | def adjust_coordinate(value, floor_or_ceil, k=16, min_val=0, max_val=1024):
13 | # Round the value to the nearest multiple of k
14 | if floor_or_ceil == "floor":
15 | rounded_value = math.floor(value / k) * k
16 | else:
17 | rounded_value = math.ceil(value / k) * k
18 | # Clamp the value between min_val and max_val
19 | return max(min_val, min(rounded_value, max_val))
20 |
21 | validation_prompt = batch["wholecaption"]
22 | validation_box_raw = batch["layout"]
23 | validation_box = [
24 | (
25 | adjust_coordinate(rect[0], floor_or_ceil="floor"),
26 | adjust_coordinate(rect[1], floor_or_ceil="floor"),
27 | adjust_coordinate(rect[2], floor_or_ceil="ceil"),
28 | adjust_coordinate(rect[3], floor_or_ceil="ceil"),
29 | )
30 | for rect in validation_box_raw
31 | ]
32 | if len(validation_box) > 52:
33 | validation_box = validation_box[:52]
34 |
35 | generator = torch.Generator(device=torch.device("cuda", index=args.gpu_id)).manual_seed(args.seed) if args.seed else None
36 | output, rgba_output, _, _ = pipeline(
37 | prompt=validation_prompt,
38 | validation_box=validation_box,
39 | generator=generator,
40 | height=args.resolution,
41 | width=args.resolution,
42 | num_layers=len(validation_box),
43 | guidance_scale=args.cfg,
44 | num_inference_steps=args.steps,
45 | transparent_decoder=transp_vae,
46 | )
47 | images = output.images # list of PIL, len=layers
48 | rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
49 |     this_index = batch["index"]  # subfolder name for this sample's outputs
50 | os.makedirs(os.path.join(args.save_dir, this_index), exist_ok=True)
51 | for frame_idx, frame_pil in enumerate(images):
52 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}.png"))
53 | if frame_idx == 0:
54 | frame_pil.save(os.path.join(args.save_dir, this_index, "merged.png"))
55 | merged_pil = images[1].convert('RGBA')
56 | for frame_idx, frame_pil in enumerate(rgba_images):
57 | if frame_idx < 2:
58 | frame_pil = images[frame_idx].convert('RGBA') # merged and background
59 | else:
60 | merged_pil = Image.alpha_composite(merged_pil, frame_pil)
61 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}_rgba.png"))
62 |
63 | merged_pil = merged_pil.convert('RGB')
64 | merged_pil.save(os.path.join(args.save_dir, this_index, "merged_rgba.png"))
65 |
66 |
67 | args = dict(
68 | save_dir="output/",
69 | resolution=512,
70 | cfg=4.0,
71 | steps=28,
72 | seed=41,
73 | gpu_id=0,
74 | )
75 | args = argparse.Namespace(**args)
76 |
77 | transformer = CustomFluxTransformer2DModel.from_pretrained("ART-Release/ART_v1.0", subfolder="transformer", torch_dtype=torch.bfloat16)
78 | transp_vae = CustomVAE.from_pretrained("ART-Release/ART_v1.0", subfolder="transp_vae", torch_dtype=torch.float32)
79 | pipeline = CustomFluxPipelineCfg.from_pretrained(
80 | "black-forest-labs/FLUX.1-dev",
81 | transformer=transformer,
82 | torch_dtype=torch.bfloat16,
83 | ).to(torch.device("cuda", index=args.gpu_id))
84 | pipeline.enable_model_cpu_offload(gpu_id=args.gpu_id) # Save GPU memory
85 |
86 | sample = {
87 | "index": "reso512_3",
88 | "wholecaption": 'Floral wedding invitation: green leaves, white flowers; circular border. Center: "JOIN US CELEBRATING OUR WEDDING" (cursive), "DONNA AND HARPER" (bold), "03 JUNE 2023" (small bold). White, green color scheme, elegant, natural.',
89 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)],
90 | }
91 |
92 | test_sample(pipeline=pipeline, transp_vae=transp_vae, batch=sample, args=args)
93 |
94 | del pipeline
95 | if torch.cuda.is_available():
96 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | diffusers
165 | *.png
166 | *__pycache__*
167 | wandb
168 | *.safetensors
169 | work_dirs
170 | evaluation
171 | .gradio
172 | *.pdf
173 | *.ttf
174 | tools
175 | BACKUP
176 | exps
177 |
178 | output*
179 | multi_layer_gen/output*
180 | multi_layer_gen/pretrained*
181 | test_semi.ipynb
182 |
183 | # not ignore the following files
184 | !assets/teaser.png
185 | !assets/natural_image_data/*/*.png
186 | !assets/semi_transparent_image_data/*/*.png
--------------------------------------------------------------------------------
/layout_planner/training/datasets/layoutnew_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Any, Dict, List, Union
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | from transformers import DataCollatorForLanguageModeling
8 |
9 | from .layout_dataset import LayoutDataset
10 |
11 |
12 | def get_dataset(args, split: str, quantizer, tokenizer, **kwargs) -> Dataset:
13 | dataset = LayoutNew(quantizer, tokenizer, split=split, max_len=args.max_len, **kwargs)
14 | return dataset
15 |
16 |
17 | class CenterWrapperNew(torch.utils.data.Dataset):
18 | def __init__(self, dataset):
19 | self.__inner_dataset = dataset
20 |
21 | def __getitem__(self, idx):
22 | example = self.__inner_dataset[idx]
23 | return example
24 |
25 | def __len__(self):
26 | return len(self.__inner_dataset)
27 |
28 |
29 | class LayoutNew(torch.utils.data.Dataset):
30 | def __init__(self, quantizer, tokenizer,
31 | split='train', max_len: int = 32,
32 | return_index=False, inf_length=False,
33 | with_layout=True,
34 | split_version='default',
35 | layout_path = None,
36 | **kwargs):
37 | self.tokenizer = tokenizer
38 | self.max_len = max_len
39 | self.quantizer = quantizer
40 | poster_datasets = []
41 | self.split_version = split_version
42 | self.layout_path = layout_path
43 | if with_layout:
44 | if split_version == 'layout':
45 | print(f"with {split_version} data: {self.layout_path}")
46 | inner_dataset = LayoutDataset(self.layout_path, split)
47 | else:
48 | raise NotImplementedError
49 |
50 | poster_datasets.append(inner_dataset)
51 |
52 | self.inner_dataset = torch.utils.data.ConcatDataset(poster_datasets)
53 | self.inner_dataset = CenterWrapperNew(self.inner_dataset)
54 | self.split = split
55 | self.size = len(self.inner_dataset)
56 | self.return_index = return_index
57 | self.inf_length = inf_length
58 | if self.inf_length:
59 | self.max_len = 1000000
60 |
61 | def __getitem__(self, idx):
62 | example = self.inner_dataset[idx]
63 | json_example = self.quantizer.convert2layout(example)
64 | max_len = self.max_len
65 |
66 | content = self.quantizer.dump2json(json_example)
67 |
68 | raw_input_ids = self.tokenizer(content, return_tensors="pt").input_ids[0]
69 |
70 | if len(raw_input_ids) > max_len and self.split == 'train':
71 | start = random.randint(0, len(raw_input_ids) - max_len)
72 | end = start + max_len
73 | input_ids = raw_input_ids[start:end]
74 |
75 | else:
76 | input_ids = raw_input_ids
77 | if not self.inf_length:
78 | if input_ids.shape[0] > max_len:
79 | input_ids = input_ids[:max_len]
80 |
81 | if not self.inf_length:
82 | if input_ids.shape[0] > self.max_len:
83 | input_ids = input_ids[:self.max_len]
84 | elif input_ids.shape[0] < self.max_len:
85 | padding_1 = torch.zeros((1,), dtype=torch.long) + self.tokenizer.bos_token_id
86 | padding_2 = torch.zeros((self.max_len - input_ids.shape[0] - 1,), dtype=torch.long) + self.tokenizer.pad_token_id
87 | input_ids = torch.cat([input_ids, padding_1, padding_2], dim=0)
88 | assert input_ids.shape[0] == self.max_len
89 |
90 | return_dict = {
91 | 'labels': input_ids,
92 | }
93 | if self.return_index:
94 | return_dict['index'] = idx
95 |
96 | return return_dict
97 |
98 | def __len__(self):
99 | return self.size
100 |
101 |
102 | def layout_collate_fn(batch):
103 | input_ids = [item['labels'] for item in batch]
104 | input_ids = torch.stack(input_ids, dim=0)
105 | return_dict = {
106 | 'labels': input_ids,
107 | }
108 | if 'index' in batch[0]:
109 | index = [item['index'] for item in batch]
110 | return_dict['index'] = index
111 |
112 | return return_dict
113 |
114 |
115 | class DataCollatorForLayout(DataCollatorForLanguageModeling):
116 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
117 | return layout_collate_fn(examples)
--------------------------------------------------------------------------------
/layout_planner/training/datasets/quantizer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import copy
3 | import numpy as np
4 | from functools import lru_cache
5 | from collections import OrderedDict
6 |
7 |
8 | class BaseQuantizer:
9 | @property
10 | def ignore_tokens(self):
11 | return []
12 |
13 | def __init__(self, simplify_json=False, **kwargs):
14 | self.simplify_json=simplify_json
15 | self.io_ignore_replace_tokens = ['']
16 |
17 | def dump2json(self, json_example):
18 | if self.simplify_json:
19 | content = json.dumps(json_example, separators=(',',':'))
20 | for token in self.additional_special_tokens:
21 | content = content.replace(f'"{token}"', token)
22 | else:
23 | content = json.dumps(json_example)
24 | return content
25 |
26 | def load_json(self, content):
27 | replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens) # sirui change
28 | if self.simplify_json:
29 | for token in replace_tokens:
30 | content = content.replace(token, f'"{token}"')
31 | return json.loads(content)
32 |
33 |
34 | specs={
35 | "width":"size",
36 | "height":"size",
37 | "x":"pos", # center x
38 | "y":"pos", # center y
39 | "color":"color",
40 | "font":"font"
41 | }
42 |
43 |
44 | min_max_bins = {
45 | 'size': (0,1,256),
46 | 'pos': (0,1,256),
47 | 'color': (0,137,138),
48 | 'font': (0,511,512)
49 | }
50 |
51 |
52 | class QuantizerV4(BaseQuantizer):
53 | def __init__(self, quant=True, **kwargs):
54 | super().__init__(**kwargs)
55 | self.min = min
56 | self.max = max
57 | self.quant = quant
58 | self.text_split_token = ''
59 | self.set_min_max_bins(min_max_bins)
60 | self.width = kwargs.get('width', 1024)
61 | self.height = kwargs.get('height', 1024)
62 | self.width = int(self.width)
63 | self.height = int(self.height)
64 |
65 | def set_min_max_bins(self, min_max_bins):
66 | min_max_bins = copy.deepcopy(min_max_bins)
67 | # adjust the bins to plus one
68 | for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
69 | assert n_bins % 2 == 0
70 | min_max_bins[type_name] = (min_val, max_val, n_bins+1)
71 | self.min_max_bins = min_max_bins
72 |
73 | def setup_tokenizer(self, tokenizer):
74 | additional_special_tokens = [self.text_split_token]
75 | rest_types = [key for key in self.min_max_bins.keys()]
76 | for type_name in rest_types:
77 | min_val, max_val,n_bins = self.min_max_bins[type_name]
78 | additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)]
79 |
80 | print('additional_special_tokens', additional_special_tokens)
81 |
82 | tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
83 | self.additional_special_tokens = set(additional_special_tokens)
84 | return tokenizer
85 |
86 |
87 | @lru_cache(maxsize=128)
88 | def get_bins(self, real_type):
89 | min_val, max_val, n_bins = self.min_max_bins[real_type]
90 | return min_val, max_val, np.linspace(min_val, max_val, n_bins)
91 |
92 | def quantize(self, x, type):
93 |         """Quantize a float array x into n_bins discrete values."""
94 |         if not self.quant:
95 |             return x
96 | real_type = specs[type]
97 | min_val, max_val, bins = self.get_bins(real_type)
98 | x = np.clip(float(x), min_val, max_val)
99 | val = np.digitize(x, bins) - 1
100 | n_bins = len(bins)
101 | assert val >= 0 and val < n_bins
102 | return f'<{real_type}-{val}>'
103 |
104 | def dequantize(self, x):
105 | # ->1
106 | val = x.split('-')[1].strip('>')
107 | # ->pos
108 | real_type = x.split('-')[0][1:]
109 | min_val, max_val, bins = self.get_bins(real_type)
110 | return bins[int(val)]
111 |
112 | def construct_map_dict(self):
113 | map_dict = {}
114 | for i in range(self.min_max_bins['size'][2]):
115 |             name = "<size-%d>" % i
116 | value = self.dequantize(name)
117 | map_dict[name] = str(value)
118 | for i in range(self.min_max_bins['pos'][2]):
119 |             name = "<pos-%d>" % i
120 | value = self.dequantize(name)
121 | map_dict[name] = str(value)
122 | return map_dict
123 |
124 | def postprocess_colorandfont(self, json_example):
125 | import re
126 |         json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
127 |         json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
128 | return json_example
129 |
130 | def convert2layout(self, example):
131 | new_example = OrderedDict()
132 | new_example['wholecaption'] = example['wholecaption']
133 | new_layout = []
134 | for meta_layer in example['layout']:
135 | new_layout.append({
136 | "layer": meta_layer["layer"],
137 | "x": self.quantize(meta_layer["x"]/self.width, 'x'),
138 | "y": self.quantize(meta_layer["y"]/self.height, 'y'),
139 | "width": self.quantize(meta_layer["width"]/self.width, 'width'),
140 | "height": self.quantize(meta_layer["height"]/self.height, 'height')
141 | })
142 | new_example['layout'] = new_layout
143 | return new_example
144 |
145 |
146 | def get_quantizer(version='v1', **kwargs):
147 | if version == 'v4':
148 | quantizer = QuantizerV4(**kwargs)
149 | else:
150 | raise NotImplementedError
151 |
152 | return quantizer
153 |
154 |
--------------------------------------------------------------------------------
/layout_planner/models/modeling_layout.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import Optional, List
4 | from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM, BitsAndBytesConfig
5 |
6 |
7 | def kmp_preprocess(pattern):
8 | pattern_len = len(pattern)
9 | prefix_suffix = [0] * pattern_len
10 | j = 0
11 |
12 | for i in range(1, pattern_len):
13 | while j > 0 and pattern[i] != pattern[j]:
14 | j = prefix_suffix[j - 1]
15 |
16 | if pattern[i] == pattern[j]:
17 | j += 1
18 |
19 | prefix_suffix[i] = j
20 |
21 | return prefix_suffix
22 |
23 |
24 | def kmp_search(text, pattern):
25 | text_len = len(text)
26 | pattern_len = len(pattern)
27 | prefix_suffix = kmp_preprocess(pattern)
28 | matches = []
29 |
30 | j = 0
31 | for i in range(text_len):
32 | while j > 0 and text[i] != pattern[j]:
33 | j = prefix_suffix[j - 1]
34 |
35 | if text[i] == pattern[j]:
36 | j += 1
37 |
38 | if j == pattern_len:
39 | matches.append(i - j + 1)
40 | j = prefix_suffix[j - 1]
41 |
42 | return matches
43 |
44 |
45 | class ModelWrapper:
46 | def __init__(self, model):
47 | self.model = model
48 |
49 | def __getattr__(self, name):
50 | return getattr(self.model, name)
51 |
52 | @torch.no_grad()
53 | def __call__(self, pixel_values):
54 | return self.model(pixel_values)
55 |
56 | def eval(self):
57 | pass
58 |
59 | def train(self):
60 | pass
61 |
62 | def parameters(self):
63 | return self.model.parameters()
64 |
65 |
66 | class LayoutModelConfig(PretrainedConfig):
67 | def __init__(
68 | self,
69 | old_vocab_size: int = 32000,
70 | vocab_size: int = 32000,
71 | pad_token_id: int = 2,
72 | freeze_lm: bool = True,
73 | opt_version: str = 'facebook/opt-6.7b',
74 | hidden_size: int = -1,
75 | load_in_4bit: Optional[bool] = False,
76 | ignore_ids: List[int] = [],
77 | **kwargs,
78 | ):
79 | super().__init__(**kwargs)
80 | assert old_vocab_size > 0, 'old_vocab_size must be positive'
81 | assert vocab_size > 0, 'vocab_size must be positive'
82 |
83 | self.old_vocab_size = old_vocab_size
84 | self.vocab_size = vocab_size
85 | self.pad_token_id = pad_token_id
86 | self.freeze_lm = freeze_lm
87 | self.opt_version = opt_version
88 | self.hidden_size = hidden_size
89 | self.load_in_4bit = load_in_4bit
90 | self.ignore_ids = ignore_ids
91 |
92 |
93 | class LayoutModel(PreTrainedModel):
94 | config_class = LayoutModelConfig
95 | supports_gradient_checkpointing = True
96 |
97 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
98 | self.lm.gradient_checkpointing_enable()
99 |
100 | def __init__(self, config: LayoutModelConfig):
101 | super().__init__(config)
102 | self.pad_token_id = config.pad_token_id
103 |
104 | self.args = config
105 |
106 | opt_version = config.opt_version
107 |
108 | print(f"Using {opt_version} for the language model.")
109 |
110 | if config.load_in_4bit:
111 | print("\n would load_in_4bit")
112 | quantization_config = BitsAndBytesConfig(
113 | load_in_4bit=config.load_in_4bit
114 | )
115 | # This means: fit the entire model on the GPU:0
116 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
117 | device_map = {"": local_rank}
118 | torch_dtype = torch.bfloat16
119 | else:
120 | print("\n wouldn't load_in_4bit")
121 | device_map = None
122 | quantization_config = None
123 | torch_dtype = None
124 |
125 | self.lm = AutoModelForCausalLM.from_pretrained(
126 | opt_version,
127 | quantization_config=quantization_config,
128 | device_map=device_map,
129 | trust_remote_code=True,
130 | attn_implementation="flash_attention_2",
131 | torch_dtype=torch.bfloat16,
132 | )
133 | self.config.hidden_size = self.lm.config.hidden_size
134 | self.opt_version = opt_version
135 |
136 | if self.args.freeze_lm:
137 | self.lm.eval()
138 | print("Freezing the LM.")
139 | for param in self.lm.parameters():
140 | param.requires_grad = False
141 | else:
142 |             print("\n not freezing the LM, so the LM will be trained")
143 | self.lm.train()
144 | self.lm.config.gradient_checkpointing = True
145 |
146 | print('resize token embeddings to match the tokenizer', config.vocab_size)
147 | self.lm.resize_token_embeddings(config.vocab_size)
148 | self.input_embeddings = self.lm.get_input_embeddings()
149 |
150 | def train(self, mode=True):
151 | super().train(mode=mode)
152 | # Overwrite train() to ensure frozen models remain frozen.
153 | if self.args.freeze_lm:
154 | self.lm.eval()
155 |
156 | def forward(
157 | self,
158 | labels: torch.LongTensor,
159 | ):
160 | batch_size = labels.shape[0]
161 | full_labels = labels.detach().clone()
162 |
163 | input_embs = self.input_embeddings(labels) # (N, T, D)
164 | input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
165 |
166 | for ignore_id in self.config.ignore_ids:
167 | full_labels[full_labels == ignore_id] = -100
168 |
169 | pad_idx = []
170 | # -100 is the ignore index for cross entropy loss. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
171 | for label in full_labels:
172 | for k, token in enumerate(label):
173 | # Mask out pad tokens if they exist.
174 | if token in [self.pad_token_id]:
175 | label[k:] = -100
176 | pad_idx.append(k)
177 | break
178 | if k == len(label) - 1: # No padding found.
179 | pad_idx.append(k + 1)
180 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
181 |
182 | output = self.lm( inputs_embeds=input_embs,
183 | labels=full_labels,
184 | output_hidden_states=True)
185 |
186 | return output, full_labels, input_embs_norm
--------------------------------------------------------------------------------
/layout_planner/inference_layout.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import argparse
4 | from typing import List
5 |
6 | import torch
7 | from transformers import AutoTokenizer, set_seed
8 | from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
9 | STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings
10 |
11 | from models.modeling_layout import LayoutModel, LayoutModelConfig
12 | from training.datasets.quantizer import get_quantizer
13 |
14 |
15 | class StopAtSpecificTokenCriteria(StoppingCriteria):
16 | def __init__(self, token_id_list: List[int] = None):
17 | self.token_id_list = token_id_list
18 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
19 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20 | return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
21 |
22 |
23 | # build model and tokenizer
24 | def buildmodel(device='cuda:0',**kwargs):
25 | # seed / input model / resume
26 | resume = kwargs.get('resume', None)
27 | seed = kwargs.get('seed', None)
28 | input_model = kwargs.get('input_model', None)
29 | quantizer_version = kwargs.get('quantizer_version', 'v4')
30 |
31 | set_seed(seed)
32 | old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
33 | old_vocab_size = len(old_tokenizer)
34 | print(f"Old vocab size: {old_vocab_size}")
35 |
36 | tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True)
37 |
38 | new_vocab_size = len(tokenizer)
39 | print(f"New vocab size: {new_vocab_size}")
40 | quantizer = get_quantizer(quantizer_version,
41 | simplify_json = True,
42 | width = kwargs['width'],
43 | height = kwargs['height']
44 | )
45 | quantizer.setup_tokenizer(tokenizer)
46 |     print(f"latest tokenizer size: {len(tokenizer)}")
47 |
48 | model_args = LayoutModelConfig(
49 | old_vocab_size = old_vocab_size,
50 | vocab_size=len(tokenizer),
51 | pad_token_id=tokenizer.pad_token_id,
52 | ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
53 | )
54 |
55 | model_args.opt_version = input_model
56 | model_args.freeze_lm = False
57 | model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
58 |
59 |     print(f"Resuming from checkpoint {resume}, waiting for the model to load")
60 | model = LayoutModel.from_pretrained(resume, config=model_args).to(device)
61 |
62 | return model, quantizer, tokenizer
63 |
64 |
65 | def preprocess_Input(intention: str):
66 |
67 | intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
68 | intention = re.sub(r'\.\s*', '. ', intention)
69 |
70 | return intention
71 |
72 |
73 | # build data
74 | def FormulateInput(intention: str):
75 | '''
76 |     Formulate the user input string into a dict object.
77 | '''
78 | resdict = {}
79 | resdict["wholecaption"] = intention
80 | resdict["layout"] = []
81 |
82 | return resdict
83 |
84 |
85 | @torch.no_grad()
86 | def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
87 | json_example = inputs
88 | input_intention = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
89 | print("input_intention:\n", input_intention)
90 |
91 | inputs = tokenizer(
92 | input_intention, return_tensors="pt"
93 | ).to(device)
94 |
95 | stopping_criteria = StoppingCriteriaList()
96 | stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000]))
97 |
98 | outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
99 | inputs_length = inputs['input_ids'].shape[1]
100 | outputs = outputs[:, inputs_length:]
101 |
102 | outputs_word = tokenizer.batch_decode(outputs)[0]
103 | split_word = outputs_word.split('}]}')[0]+"}]}"
104 | split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word
105 |
106 | map_dict = quantizer.construct_map_dict()
107 |     for key, value in map_dict.items():
108 | split_word = split_word.replace(key, value)
109 |
110 | try:
111 | pred_json_example = json.loads(split_word)
112 | for layer in pred_json_example["layout"]:
113 | layer['x'] = round(int(width)*layer['x'])
114 | layer['y'] = round(int(height)*layer['y'])
115 | layer['width'] = round(int(width)*layer['width'])
116 | layer['height'] = round(int(height)*layer['height'])
117 | except Exception as e:
118 | print(e)
119 | pred_json_example = None
120 | return pred_json_example
121 |
122 |
123 | if __name__ == '__main__':
124 |
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument('--inference_caption', type=str, help='User input whole caption')
127 | parser.add_argument('--save_path', type=str, help='Path to save data')
128 | parser.add_argument('--device', type=str, default='cuda:0')
129 | parser.add_argument('--width', type=int, default=1024, help='Width of the layout')
130 | parser.add_argument('--height', type=int, default=1024, help='Height of the layout')
131 | parser.add_argument('--input_model', type=str, help='Path to input base model')
132 |     parser.add_argument('--resume', type=str, help='Path to test model checkpoint')
133 | parser.add_argument('--do_sample', type=bool, default=False)
134 | parser.add_argument('--temperature', type=float, default=0.5)
135 |
136 | args = parser.parse_args()
137 |
138 | inference_caption = args.inference_caption
139 | save_path = args.save_path
140 | device = args.device
141 | width = args.width
142 | height = args.height
143 | input_model = args.input_model
144 | resume = args.resume
145 | do_sample = args.do_sample
146 | temperature = args.temperature
147 |
148 | params_dict = {
149 | "input_model": input_model,
150 | "resume": resume,
151 | "seed": 0,
152 | "quantizer_version": 'v4',
153 | "width": width,
154 | "height": height,
155 | }
156 |
157 | # Init model
158 | model, quantizer, tokenizer = buildmodel(device=device, **params_dict)
159 | model = model.to(device)
160 | model = model.bfloat16()
161 | model.eval()
162 |
163 | intention = preprocess_Input(inference_caption)
164 | rawdata = FormulateInput(intention)
165 | preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature)
166 | max_try_time = 3
167 | while preddata is None and max_try_time > 0:
168 | preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature)
169 | max_try_time -= 1
170 |
171 | print("output : ", preddata)
172 |
173 | with open(save_path,'w') as file:
174 | json.dump(preddata, file, indent=4)
175 |
--------------------------------------------------------------------------------
/multi_layer_gen/configs/multi_layer_resolution512_test.py:
--------------------------------------------------------------------------------
1 | _base_ = "./base.py"
2 |
3 | ### path & device settings
4 |
5 | output_path_base = "./output/"
6 | cache_dir = None
7 |
8 |
9 | ### wandb settings
10 | wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}'
11 |
12 | resolution = 512
13 |
14 | ### Model Settings
15 | rank = 64
16 | text_encoder_rank = 64
17 | train_text_encoder = False
18 | max_layer_num = 50 + 2
19 | learnable_proj = True
20 |
21 | ### Training Settings
22 | weighting_scheme = "none"
23 | logit_mean = 0.0
24 | logit_std = 1.0
25 | mode_scale = 1.29
26 | guidance_scale = 1.0 ###IMPORTANT
27 | layer_weighting = 5.0
28 |
29 | # steps
30 | train_batch_size = 1
31 | num_train_epochs = 1
32 | max_train_steps = None
33 | checkpointing_steps = 2000
34 | resume_from_checkpoint = "latest"
35 | gradient_accumulation_steps = 1
36 |
37 | # lr
38 | optimizer = "prodigy"
39 | learning_rate = 1.0
40 | scale_lr = False
41 | lr_scheduler = "constant"
42 | lr_warmup_steps = 0
43 | lr_num_cycles = 1
44 | lr_power = 1.0
45 |
46 | # optim
47 | adam_beta1 = 0.9
48 | adam_beta2 = 0.999
49 | adam_weight_decay = 1e-3
50 | adam_epsilon = 1e-8
51 | prodigy_beta3 = None
52 | prodigy_decouple = True
53 | prodigy_use_bias_correction = True
54 | prodigy_safeguard_warmup = True
55 | max_grad_norm = 1.0
56 |
57 | # logging
58 | tracker_task_name = '{{fileBasenameNoExtension}}'
59 | output_dir = output_path_base + "{{fileBasenameNoExtension}}"
60 |
61 | ### Validation Settings
62 | num_validation_images = 1
63 | validation_steps = 2000
64 | validation_prompts = [
65 | 'The image features a background with a soft, pastel color gradient that transitions from pink to purple. There are abstract floral elements scattered throughout the background, with some appearing to be in full bloom and others in a more delicate, bud-like state. The flowers have a watercolor effect, with soft edges that blend into the background.\n\nCentered in the image is a quote in a serif font that reads, "You\'re free to be different." The text is black, which stands out against the lighter background. The overall style of the image is artistic and inspirational, with a motivational message that encourages individuality and self-expression. The image could be used for motivational purposes, as a background for a blog or social media post, or as part of a personal development or self-help theme.',
66 | 'The image features a logo for a company named "Bull Head Party Adventure." The logo is stylized with a cartoon-like depiction of a bull\'s head, which is the central element of the design. The bull has prominent horns and a fierce expression, with its mouth slightly open as if it\'s snarling or roaring. The color scheme of the bull is a mix of brown and beige tones, with the horns highlighted in a lighter shade.\n\nBelow the bull\'s head, the company name is written in a bold, sans-serif font. The text is arranged in two lines, with "Bull Head" on the top line and "Party Adventure" on the bottom line. The font color matches the color of the bull, creating a cohesive look. The overall style of the image is playful and energetic, suggesting that the company may offer exciting or adventurous party experiences.',
67 | 'The image features a festive and colorful illustration with a theme related to the Islamic holiday of Eid al-Fitr. At the center of the image is a large, ornate crescent moon with intricate patterns and decorations. Surrounding the moon are several smaller stars and crescents, also adorned with decorative elements. These smaller celestial motifs are suspended from the moon, creating a sense of depth and dimension.\n\nBelow the central moon, there is a banner with the text "Eid Mubarak" in a stylized, elegant font. The text is in a bold, dark color that stands out against the lighter background. The background itself is a gradient of light to dark green, which complements the golden and white hues of the celestial motifs.\n\nThe overall style of the image is celebratory and decorative, with a focus on the traditional symbols associated with Eid al-Fitr. The use of gold and white gives the image a luxurious and festive feel, while the green background is a color often associated with Islam. The image appears to be a digital artwork or graphic design, possibly intended for use as a greeting card or a festive decoration.',
68 | 'The image is a festive graphic with a dark background. At the center, there is a large, bold text that reads "Happy New Year 2023" in a combination of white and gold colors. The text is surrounded by numerous white balloons with gold ribbons, giving the impression of a celebratory atmosphere. The balloons are scattered around the text, creating a sense of depth and movement. Additionally, there are small gold sparkles and confetti-like elements that add to the celebratory theme. The overall design suggests a New Year\'s celebration, with the year 2023 being the focal point.',
69 | 'The image is a stylized illustration with a flat design aesthetic. It depicts a scene related to healthcare or medical care. In the center, there is a hospital bed with a patient lying down, appearing to be resting or possibly receiving treatment. The patient is surrounded by three individuals who seem to be healthcare professionals or caregivers. They are standing around the bed, with one on each side and one at the foot of the bed. The person at the foot of the bed is holding a clipboard, suggesting they might be taking notes or reviewing medical records.\n\nThe room has a window with curtains partially drawn, allowing some light to enter. The color palette is soft, with pastel tones dominating the scene. The text "INTERNATIONAL CANCER DAY" is prominently displayed at the top of the image, indicating that the illustration is related to this event. The overall impression is one of care and support, with a focus on the patient\'s well-being.',
70 | 'The image features a stylized illustration of a man with a beard and a tank top, drinking from a can. The man is depicted in a simplified, cartoon-like style with a limited color palette. Above him, there is a text that reads "Happy Eating, Friends" in a bold, friendly font. Below the illustration, there is another line of text that states "Food is a Necessity That is Not Prioritized," which is also in a bold, sans-serif font. The background of the image is a gradient of light to dark blue, giving the impression of a sky or a calm, serene environment. The overall style of the image is casual and approachable, with a focus on the message conveyed by the text.',
71 | 'The image is a digital illustration with a pastel pink background. At the top, there is a text that reads "Sending you my Easter wishes" in a simple, sans-serif font. Below this, a larger text states "May Your Heart be Happy!" in a more decorative, serif font. Underneath this main message, there is a smaller text that says "Let the miracle of the season fill you with hope and love."\n\nThe illustration features three stylized flowers with smiling faces. On the left, there is a purple flower with a yellow center. In the center, there is a blue flower with a green center. On the right, there is a pink flower with a yellow center. Each flower has a pair of eyes and a mouth, giving them a friendly appearance. The flowers are drawn with a cartoon-like style, using solid colors and simple shapes.\n\nThe overall style of the image is cheerful and whimsical, with a clear Easter theme suggested by the text and the presence of flowers, which are often associated with spring and new beginnings.',
72 | 'The image is a vibrant and colorful graphic with a pink background. In the center, there is a photograph of a man and a woman embracing each other. The man is wearing a white shirt, and the woman is wearing a patterned top. They are both smiling and appear to be in a joyful mood.\n\nSurrounding the photograph are various elements that suggest a festive or celebratory theme. There are three hot air balloons in the background, each with a different design: one with a heart, one with a gift box, and one with a basket. These balloons are floating against a clear sky.\n\nAdditionally, there are two gift boxes with ribbons, one on the left and one on the right side of the image. These gift boxes are stylized with a glossy finish and are placed at different heights, creating a sense of depth.\n\nAt the bottom of the image, there is a large red heart, which is a common symbol associated with love and Valentine\'s Day.\n\nFinally, at the very bottom of the image, there is a text that reads "Happy Valentine\'s Day," which confirms the theme of the image as a Valentine\'s Day greeting. The text is in a playful, cursive font that matches the overall cheerful and romantic tone of the image.',
73 | 'The image depicts a stylized illustration of two women sitting on stools, engaged in conversation. They are wearing traditional attire, with headscarves and patterned dresses. The woman on the left is wearing a brown dress with a purple pattern, while the woman on the right is wearing a purple dress with a brown pattern. Between them is a purple flower. Above the women, the text "INTERNATIONAL WOMEN\'S DAY" is written in bold, uppercase letters. The background is a soft, pastel pink, and there are abstract, swirling lines in a darker shade of pink above the women. The overall style of the image is simplistic and cartoonish, with a warm and friendly tone.',
74 | 'The image is a stylized illustration with a warm, peach-colored background. At the center, there is a vintage-style radio with a prominent dial and antenna. The radio is emitting a blue, star-like burst of light or energy from its top. Surrounding the radio are various objects and elements that seem to be floating or suspended in the air. These include a brown, cone-shaped object, a blue, star-like shape, and a brown, wavy, abstract shape that could be interpreted as a flower or a wave.\n\nAt the top of the image, there is text that reads "World Radio Day" in a bold, serif font. Below this, in a smaller, sans-serif font, is the date "13 February 2022." The overall style of the image is playful and cartoonish, with a clear focus on celebrating World Radio Day.',
75 | 'The image is a graphic design of a baby shower invitation. The central focus is a cute, cartoon-style teddy bear with a friendly expression, sitting upright. The bear is colored in a soft, light brown hue. Above the bear, there is a bold text that reads "YOU\'RE INVITED" in a playful, sans-serif font. Below this, the words "BABY SHOWER" are prominently displayed in a larger, more decorative font, suggesting the theme of the event.\n\nThe background of the invitation is a soft, light pink color, which adds to the gentle and welcoming atmosphere of the design. At the bottom of the image, there is additional text providing specific details about the event. It reads "27 January, 2022 - 8:00 PM" followed by "FAUGET INDUSTRIES CAFE," indicating the date, time, and location of the baby shower.\n\nThe overall style of the image is warm, inviting, and child-friendly, with a clear focus on the theme of a baby shower celebration. The use of a teddy bear as the central image reinforces the baby-related theme. The design is simple yet effective, with a clear hierarchy of information that guides the viewer\'s attention from the top to the bottom of the invitation.',
76 | ]
77 |
78 | validation_boxes = [
79 | [(0, 0, 512, 512), (0, 0, 512, 512), (368, 0, 512, 272), (0, 272, 112, 512), (160, 208, 352, 304)],
80 | [(0, 0, 512, 512), (0, 0, 512, 512), (128, 128, 384, 304), (96, 288, 416, 336), (128, 336, 384, 368)],
81 | [(0, 0, 512, 512), (0, 0, 512, 512), (112, 48, 400, 368), (0, 48, 96, 176), (128, 336, 384, 384), (240, 384, 384, 432)],
82 | [(0, 0, 512, 512), (0, 0, 512, 512), (32, 32, 480, 480), (80, 176, 432, 368), (64, 176, 448, 224), (144, 96, 368, 224)],
83 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 64, 176, 272), (0, 400, 512, 512), (16, 160, 496, 512), (224, 48, 464, 112), (208, 96, 464, 160)],
84 | [(0, 0, 512, 512), (0, 0, 512, 512), (112, 224, 512, 512), (0, 0, 240, 160), (144, 144, 512, 512), (48, 64, 432, 208), (48, 400, 256, 448)],
85 | [(0, 0, 512, 512), (0, 0, 512, 512), (160, 48, 352, 80), (64, 80, 448, 192), (128, 208, 384, 240), (320, 240, 512, 512), (80, 272, 368, 512), (0, 224, 192, 512)],
86 | [(0, 0, 512, 512), (0, 0, 512, 512), (48, 0, 464, 304), (128, 144, 384, 400), (288, 288, 384, 368), (336, 304, 400, 368), (176, 432, 336, 480), (224, 400, 288, 432)],
87 | [(0, 0, 512, 512), (0, 0, 512, 512), (32, 288, 448, 512), (144, 176, 336, 400), (224, 208, 272, 256), (160, 128, 336, 192), (192, 368, 304, 400), (368, 80, 448, 224), (48, 160, 128, 256)],
88 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 352, 512, 512), (112, 176, 368, 432), (48, 176, 128, 256), (48, 368, 128, 448), (384, 192, 480, 272), (384, 336, 432, 384), (80, 80, 432, 128), (176, 128, 336, 160)],
89 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)],
90 | ]
--------------------------------------------------------------------------------
/multi_layer_gen/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import os
17 | import sys
18 | import math
19 | import random
20 | import argparse
21 | import numpy as np
22 | from PIL import Image
23 | from tqdm import tqdm
24 | from mmengine.config import Config
25 |
26 | import torch
27 | import torch.utils.checkpoint
28 | from torchvision.utils import save_image
29 |
30 | from diffusers import FluxTransformer2DModel
31 | from diffusers.utils import check_min_version
32 | from diffusers.configuration_utils import FrozenDict
33 | from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
34 |
35 | from custom_model_mmdit import CustomFluxTransformer2DModel
36 | from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
37 | from custom_pipeline import CustomFluxPipelineCfg
38 |
39 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
40 | check_min_version("0.31.0.dev0")
41 |
42 |
43 | def seed_everything(seed):
44 | random.seed(seed)
45 | np.random.seed(seed)
46 | torch.manual_seed(seed)
47 | if torch.cuda.is_available():
48 | torch.cuda.manual_seed(seed)
49 | torch.cuda.manual_seed_all(seed)
50 | torch.backends.cudnn.deterministic = True
51 |
52 |
53 | def parse_config(path=None):
54 |
55 | if path is None:
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('config_dir', type=str)
58 | args = parser.parse_args()
59 | path = args.config_dir
60 | config = Config.fromfile(path)
61 |
62 | config.config_dir = path
63 |
64 | if "LOCAL_RANK" in os.environ:
65 | config.local_rank = int(os.environ["LOCAL_RANK"])
66 | elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
67 | config.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
68 | else:
69 | config.local_rank = -1
70 |
71 | return config
72 |
73 |
74 | def initialize_pipeline(config, args):
75 |
76 | # Load the original Transformer model from the pretrained model
77 | transformer_orig = FluxTransformer2DModel.from_pretrained(
78 | config.transformer_varient if hasattr(config, "transformer_varient") else config.pretrained_model_name_or_path,
79 | subfolder="" if hasattr(config, "transformer_varient") else "transformer",
80 | revision=config.revision,
81 | variant=config.variant,
82 | torch_dtype=torch.bfloat16,
83 | cache_dir=config.get("cache_dir", None),
84 | )
85 |
86 | # Configure the custom Transformer model
87 | mmdit_config = dict(transformer_orig.config)
88 | mmdit_config["_class_name"] = "CustomSD3Transformer2DModel"
89 | mmdit_config["max_layer_num"] = config.max_layer_num
90 | mmdit_config = FrozenDict(mmdit_config)
91 | transformer = CustomFluxTransformer2DModel.from_config(mmdit_config).to(dtype=torch.bfloat16)
92 | missing_keys, unexpected_keys = transformer.load_state_dict(transformer_orig.state_dict(), strict=False)
93 |
94 | # Fuse initial LoRA weights
95 | if args.pre_fuse_lora_dir is not None:
96 | lora_state_dict = CustomFluxPipelineCfg.lora_state_dict(args.pre_fuse_lora_dir)
97 | CustomFluxPipelineCfg.load_lora_into_transformer(lora_state_dict, None, transformer)
98 | transformer.fuse_lora(safe_fusing=True)
99 | transformer.unload_lora() # Unload LoRA parameters
100 |
101 | # Load layer_pe weights
102 | layer_pe_path = os.path.join(args.ckpt_dir, "layer_pe.pth")
103 | layer_pe = torch.load(layer_pe_path)
104 | missing_keys, unexpected_keys = transformer.load_state_dict(layer_pe, strict=False)
105 |
106 | # Initialize the custom pipeline
107 | pipeline_type = CustomFluxPipelineCfg
108 | pipeline = pipeline_type.from_pretrained(
109 | config.pretrained_model_name_or_path,
110 | transformer=transformer,
111 | revision=config.revision,
112 | variant=config.variant,
113 | torch_dtype=torch.bfloat16,
114 | cache_dir=config.get("cache_dir", None),
115 | ).to(torch.device("cuda", index=args.gpu_id))
116 | pipeline.enable_model_cpu_offload(gpu_id=args.gpu_id) # Save GPU memory
117 |
118 | # Load LoRA weights
119 | pipeline.load_lora_weights(args.ckpt_dir, adapter_name="layer")
120 |
121 | # Load additional LoRA weights
122 | if args.extra_lora_dir is not None:
123 | _SET_ADAPTER_SCALE_FN_MAPPING["CustomFluxTransformer2DModel"] = _SET_ADAPTER_SCALE_FN_MAPPING["FluxTransformer2DModel"]
124 | pipeline.load_lora_weights(args.extra_lora_dir, adapter_name="extra")
125 | pipeline.set_adapters(["layer", "extra"], adapter_weights=[1.0, 0.5])
126 |
127 | return pipeline
128 |
129 | def get_fg_layer_box(list_layer_pt):
130 | list_layer_box = []
131 | for layer_pt in list_layer_pt:
132 | alpha_channel = layer_pt[:, 3:4]
133 |
134 | if layer_pt.shape[1] == 3:
135 | list_layer_box.append(
136 | (0, 0, layer_pt.shape[3], layer_pt.shape[2])
137 | )
138 | continue
139 |
140 | # Step 1: Find the non-zero indices
141 | _, _, rows, cols = torch.nonzero(alpha_channel + 1, as_tuple=True)
142 |
143 | if (rows.numel() == 0) or (cols.numel() == 0):
144 | # If there are no non-zero indices, we can skip this layer
145 | list_layer_box.append(None)
146 | continue
147 |
148 | # Step 2: Get the minimum and maximum indices for rows and columns
149 | min_row, max_row = rows.min().item(), rows.max().item()
150 | min_col, max_col = cols.min().item(), cols.max().item()
151 |
152 | # Step 3: Quantize the minimum values down to the nearest multiple of 16
153 | quantized_min_row = (min_row // 16) * 16
154 | quantized_min_col = (min_col // 16) * 16
155 |
156 | # Step 4: Quantize the maximum values up to the nearest multiple of 16 outside of the max
157 | quantized_max_row = ((max_row // 16) + 1) * 16
158 | quantized_max_col = ((max_col // 16) + 1) * 16
159 | list_layer_box.append(
160 | (quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row)
161 | )
162 | return list_layer_box
163 |
164 |
165 | def adjust_coordinate(value, floor_or_ceil, k=16, min_val=0, max_val=1024):
166 | # Round the value down (floor) or up (ceil) to a multiple of k
167 | if floor_or_ceil == "floor":
168 | rounded_value = math.floor(value / k) * k
169 | else:
170 | rounded_value = math.ceil(value / k) * k
171 | # Clamp the value between min_val and max_val
172 | return max(min_val, min(rounded_value, max_val))
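   | # e.g. adjust_coordinate(203, "floor") -> 192 and adjust_coordinate(397, "ceil") -> 400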
173 |
174 |
175 | def test(args):
176 |
177 | if args.seed is not None:
178 | seed_everything(args.seed)
179 |
180 | cfg_path = args.cfg_path
181 | config = parse_config(cfg_path)
182 |
183 | if args.variant is not None: args.save_dir += '_' + args.variant
184 |
185 | # Initialize pipeline
186 | pipeline = initialize_pipeline(config, args)
187 |
188 | # load multi-layer-transparent-vae-decoder
189 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
190 | transp_vae = CustomVAE()
191 | transp_vae_path = args.transp_vae_ckpt
192 | missing, unexpected = transp_vae.load_state_dict(torch.load(transp_vae_path, map_location="cpu")['model'], strict=False)
193 | transp_vae.eval()
194 |
195 | test_samples = [
196 | {
197 | "index": "reso512_1",
198 | "wholecaption": "The image shows a collection of luggage items on a carpeted floor. There are three main pieces of luggage: a large suitcase, a smaller suitcase, and a duffel bag. The large suitcase is positioned in the center, with the smaller suitcase to its left and the duffel bag to its right. The luggage appears to be packed and ready for travel. In the foreground, there is a plastic bag containing what looks like a pair of shoes. The background features a white curtain, suggesting that the setting might be indoors, possibly a hotel room or a similar temporary accommodation. The image is in black and white, which gives it a timeless or classic feel.",
199 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (281, 203, 474, 397), (94, 22, 294, 406), (190, 327, 379, 471)],
200 | },
201 | {
202 | "index": "reso512_2",
203 | "wholecaption": "The image features a logo for a flower shop named ”Estelle Darcy Flower Shop.” The logo is designed with a stylized flower, which appears to be a rose, in shades of pink and green. The flower is positioned to the left of the text, which is written in a cursive font. The text is in a brown color, and the overall style of the image is simple and elegant, with a clean, light background that does not distract from the logo itself. The logo conveys a sense of freshness and natural beauty, which is fitting for a flower shop.",
204 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (320, 160, 432, 352), (128, 240, 368, 320), (128, 304, 352, 336)],
205 | },
206 | ]
207 |
208 | for idx, batch in tqdm(enumerate(test_samples)):
209 |
210 | generator = torch.Generator(device=torch.device("cuda", index=args.gpu_id)).manual_seed(args.seed) if args.seed is not None else None
211 |
212 | this_index = batch["index"]
213 |
214 | validation_prompt = batch["wholecaption"]
215 | validation_box_raw = batch["layout"]
216 | validation_box = [
217 | (
218 | adjust_coordinate(rect[0], floor_or_ceil="floor"),
219 | adjust_coordinate(rect[1], floor_or_ceil="floor"),
220 | adjust_coordinate(rect[2], floor_or_ceil="ceil"),
221 | adjust_coordinate(rect[3], floor_or_ceil="ceil"),
222 | )
223 | for rect in validation_box_raw
224 | ]
225 | if len(validation_box) > 52:
226 | validation_box = validation_box[:52]
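   | # the transformer's layer positional embedding supports at most max_layer_num (52 by default) layers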
227 |
228 | output, rgba_output, _, _ = pipeline(
229 | prompt=validation_prompt,
230 | validation_box=validation_box,
231 | generator=generator,
232 | height=config.resolution,
233 | width=config.resolution,
234 | num_layers=len(validation_box),
235 | guidance_scale=args.cfg,
236 | num_inference_steps=args.steps,
237 | transparent_decoder=transp_vae,
238 | )
239 | images = output.images # list of PIL, len=layers
240 | rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output]
241 |
242 | os.makedirs(os.path.join(args.save_dir, this_index), exist_ok=True)
243 | os.system(f"rm -rf {os.path.join(args.save_dir, this_index)}/*")
244 | for frame_idx, frame_pil in enumerate(images):
245 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}.png"))
246 | if frame_idx == 0:
247 | frame_pil.save(os.path.join(args.save_dir, this_index, "merged.png"))
248 | merged_pil = images[1].convert('RGBA')
249 | for frame_idx, frame_pil in enumerate(rgba_images):
250 | if frame_idx < 2:
251 | frame_pil = images[frame_idx].convert('RGBA') # merged and background
252 | else:
253 | merged_pil = Image.alpha_composite(merged_pil, frame_pil)
254 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}_rgba.png"))
255 |
256 | merged_pil = merged_pil.convert('RGB')
257 | merged_pil.save(os.path.join(args.save_dir, this_index, "merged_rgba.png"))
258 |
259 | del pipeline
260 | if torch.cuda.is_available():
261 | torch.cuda.empty_cache()
262 |
263 |
264 | if __name__ == "__main__":
265 | parser = argparse.ArgumentParser()
266 | parser.add_argument("--cfg_path", type=str)
267 | parser.add_argument("--ckpt_dir", type=str)
268 | parser.add_argument("--transp_vae_ckpt", type=str)
269 | parser.add_argument("--pre_fuse_lora_dir", type=str)
270 | parser.add_argument("--extra_lora_dir", type=str, default=None)
271 | parser.add_argument("--save_dir", type=str)
272 | parser.add_argument("--variant", type=str, default=None)
273 | parser.add_argument("--cfg", type=float, default=4.0)
274 | parser.add_argument("--steps", type=int, default=28)
275 | parser.add_argument("--seed", type=int, default=45)
276 | parser.add_argument("--gpu_id", type=int, default=0)
277 |
278 | args = parser.parse_args()
279 |
280 | test(args)
--------------------------------------------------------------------------------
/multi_layer_gen/custom_model_transp_vae.py:
--------------------------------------------------------------------------------
1 | import einops
2 | from collections import OrderedDict
3 | from functools import partial
4 | from typing import Callable
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torchvision
9 | from torch.utils.checkpoint import checkpoint
10 |
11 | from accelerate.utils import set_module_tensor_to_device
12 | from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed
13 | from diffusers.models.modeling_utils import ModelMixin
14 | from diffusers.configuration_utils import ConfigMixin
15 | from diffusers.loaders import FromOriginalModelMixin
16 |
17 |
18 | class MLPBlock(torchvision.ops.misc.MLP):
19 | """Transformer MLP block."""
20 |
21 | _version = 2
22 |
23 | def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
24 | super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
25 |
26 | for m in self.modules():
27 | if isinstance(m, nn.Linear):
28 | nn.init.xavier_uniform_(m.weight)
29 | if m.bias is not None:
30 | nn.init.normal_(m.bias, std=1e-6)
31 |
32 | def _load_from_state_dict(
33 | self,
34 | state_dict,
35 | prefix,
36 | local_metadata,
37 | strict,
38 | missing_keys,
39 | unexpected_keys,
40 | error_msgs,
41 | ):
42 | version = local_metadata.get("version", None)
43 |
44 | if version is None or version < 2:
45 | # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
46 | for i in range(2):
47 | for type in ["weight", "bias"]:
48 | old_key = f"{prefix}linear_{i+1}.{type}"
49 | new_key = f"{prefix}{3*i}.{type}"
50 | if old_key in state_dict:
51 | state_dict[new_key] = state_dict.pop(old_key)
52 |
53 | super()._load_from_state_dict(
54 | state_dict,
55 | prefix,
56 | local_metadata,
57 | strict,
58 | missing_keys,
59 | unexpected_keys,
60 | error_msgs,
61 | )
62 |
63 |
64 | class EncoderBlock(nn.Module):
65 | """Transformer encoder block."""
66 |
67 | def __init__(
68 | self,
69 | num_heads: int,
70 | hidden_dim: int,
71 | mlp_dim: int,
72 | dropout: float,
73 | attention_dropout: float,
74 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
75 | ):
76 | super().__init__()
77 | self.num_heads = num_heads
78 | self.hidden_dim = hidden_dim
80 |
81 | # Attention block
82 | self.ln_1 = norm_layer(hidden_dim)
83 | self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
84 | self.dropout = nn.Dropout(dropout)
85 |
86 | # MLP block
87 | self.ln_2 = norm_layer(hidden_dim)
88 | self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
89 |
90 | def forward(self, input: torch.Tensor, freqs_cis):
91 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
92 | B, L, C = input.shape
93 | x = self.ln_1(input)
94 | query = x  # fall back to un-rotated queries/keys when no RoPE frequencies are provided
95 | if freqs_cis is not None:
96 | query = query.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2)
97 | query = apply_rotary_emb(query, freqs_cis).transpose(1, 2).reshape(B, L, self.hidden_dim)
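   | # the same rotated tensor is used as query and key; the un-rotated features serve as values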
98 | x, _ = self.self_attention(query, query, x, need_weights=False)
99 | x = self.dropout(x)
100 | x = x + input
101 |
102 | y = self.ln_2(x)
103 | y = self.mlp(y)
104 | return x + y
105 |
106 |
107 | class Encoder(nn.Module):
108 | """Transformer Model Encoder for sequence to sequence translation."""
109 |
110 | def __init__(
111 | self,
112 | seq_length: int,
113 | num_layers: int,
114 | num_heads: int,
115 | hidden_dim: int,
116 | mlp_dim: int,
117 | dropout: float,
118 | attention_dropout: float,
119 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
120 | ):
121 | super().__init__()
122 | # Note that batch_size is on the first dim because
123 | # we have batch_first=True in nn.MultiAttention() by default
124 | # self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
125 | self.dropout = nn.Dropout(dropout)
126 | layers: OrderedDict[str, nn.Module] = OrderedDict()
127 | for i in range(num_layers):
128 | layers[f"encoder_layer_{i}"] = EncoderBlock(
129 | num_heads,
130 | hidden_dim,
131 | mlp_dim,
132 | dropout,
133 | attention_dropout,
134 | norm_layer,
135 | )
136 | self.layers = nn.Sequential(layers)
137 | self.ln = norm_layer(hidden_dim)
138 |
139 | def forward(self, input: torch.Tensor, freqs_cis):
140 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
141 | input = input # + self.pos_embedding
142 | x = self.dropout(input)
143 | for layer in self.layers:
144 | x = checkpoint(layer, x, freqs_cis)
145 | x = self.ln(x)
146 | return x
147 |
148 |
149 | class ViTEncoder(nn.Module):
150 | def __init__(self, arch='vit-b/32'):
151 | super().__init__()
152 | self.arch = arch
153 |
154 | if self.arch == 'vit-b/32':
155 | ch = 768
156 | layers = 12
157 | heads = 12
158 | elif self.arch == 'vit-h/14':
159 | ch = 1280
160 | layers = 32
161 | heads = 16
162 |
163 | self.encoder = Encoder(
164 | seq_length=-1,
165 | num_layers=layers,
166 | num_heads=heads,
167 | hidden_dim=ch,
168 | mlp_dim=ch*4,
169 | dropout=0.0,
170 | attention_dropout=0.0,
171 | )
172 | self.fc_in = nn.Linear(16, ch)
173 | self.fc_out = nn.Linear(ch, 256)
174 |
175 | if self.arch == 'vit-b/32':
176 | from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights
177 | vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT)
178 | elif self.arch == 'vit-h/14':
179 | from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights
180 | vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)
181 |
182 | missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False)
183 | if len(missing_keys) > 0 or len(unexpected_keys) > 0:
184 | print(f"ViT Encoder Missing keys: {missing_keys}")
185 | print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
186 | del vit
187 |
188 | def forward(self, x, freqs_cis):
189 | out = self.fc_in(x)
190 | out = self.encoder(out, freqs_cis)
191 | out = checkpoint(self.fc_out, out)
192 | return out
193 |
194 |
195 | def patchify(x, patch_size=8):
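   | # Fold each patch_size x patch_size spatial patch into the channel dimension: (c, H, W) -> (c*p1*p2, H/p, W/p)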
196 | if len(x.shape) == 4:
197 | bs, c, h, w = x.shape
198 | x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size)
199 | elif len(x.shape) == 3:
200 | c, h, w = x.shape
201 | x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size)
202 | return x
203 |
204 |
205 | def unpatchify(x, patch_size=8):
206 | if len(x.shape) == 4:
207 | bs, c, h, w = x.shape
208 | x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size)
209 | elif len(x.shape) == 3:
210 | c, h, w = x.shape
211 | x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size)
212 | return x
213 |
214 |
215 | def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding):
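   | # Gather the latent tokens inside each layer's box (pixel coords // 8) together with matching RoPE cos/sin tables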
216 | token_list = []
217 | cos_list, sin_list = [], []
218 | for layer_idx in range(hidden_states.shape[1]):
219 | if list_layer_box[layer_idx] is None:
220 | continue
221 | else:
222 | x1, y1, x2, y2 = list_layer_box[layer_idx]
223 | x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
224 | layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2]
225 | c, h, w = layer_token.shape
226 | layer_token = layer_token.reshape(c, -1)
227 | token_list.append(layer_token)
228 | ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype)
229 | ids[:, 0] = use_layers[layer_idx]
230 | image_rotary_emb = pos_embedding(ids)
231 | pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1)
232 | cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64))
233 | sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64))
234 | token_list = torch.cat(token_list, dim=1).permute(1, 0)
235 | cos_list = torch.cat(cos_list, dim=0)
236 | sin_list = torch.cat(sin_list, dim=0)
237 | return token_list, (cos_list, sin_list)
238 |
239 |
240 | def prepare_latent_image_ids(batch_size, height, width, device, dtype):
241 | latent_image_ids = torch.zeros(height // 2, width // 2, 3)
242 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
243 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
244 |
245 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
246 |
247 | latent_image_ids = latent_image_ids.reshape(
248 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
249 | )
250 |
251 | return latent_image_ids.to(device=device, dtype=dtype)
252 |
253 |
254 | class AutoencoderKLTransformerTraining(ModelMixin, ConfigMixin, FromOriginalModelMixin):
255 | def __init__(self):
256 | super().__init__()
257 |
258 | self.decoder_arch = 'vit'
259 | self.layer_embedding = 'rope'
260 |
261 | self.decoder = ViTEncoder()
262 | self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28))
263 | if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
264 | self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.max_layers, 1, 1).normal_(std=0.02), requires_grad=True)
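   | # NOTE: this branch is inactive with the default 'rope' setting; it would also require self.max_layers to be defined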
265 |
266 | def zero_module(module):
267 | """
268 | Zero out the parameters of a module and return it.
269 | """
270 | for p in module.parameters():
271 | p.detach().zero_()
272 | return module
273 |
274 | def encode(self, z_2d, box, use_layers):
275 | B, C, T, H, W = z_2d.shape
276 |
277 | z, freqs_cis = [], []
278 | for b in range(B):
279 | _z = z_2d[b]
280 | if 'vit' in self.decoder_arch:
281 | _use_layers = torch.tensor(use_layers[b], device=z_2d.device)
282 | if 'rel' in self.layer_embedding:
283 | _use_layers[_use_layers > 2] = 2
284 | if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
285 | _z = _z + self.layer_embedding[:, _use_layers] # + self.pos_embedding
286 | if 'rope' not in self.layer_embedding:
287 | use_layers[b] = [0] * len(use_layers[b])
288 | _z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding) ### modified
289 | z.append(_z)
290 | freqs_cis.append(cis)
291 |
292 | return z, freqs_cis
293 |
294 | def decode(self, z, freqs_cis, box, H, W):
295 | B = len(z)
296 | pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype)
297 | pad[3, :, :] = -1
298 | x = []
299 | for b in range(B):
300 | _x = []
301 | _z = self.decoder(z[b].unsqueeze(0), freqs_cis[b]).squeeze(0)
302 | current_index = 0
303 | for layer_idx in range(len(box[b])):
304 | if box[b][layer_idx] is None:
305 | _x.append(pad.clone())
306 | else:
307 | x1, y1, x2, y2 = box[b][layer_idx]
308 | x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8
309 | token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok)
310 | tokens = _z[current_index:current_index + token_length]
311 | pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok)
312 | unpatched = unpatchify(pixels)
313 | pixels = pad.clone()
314 | pixels[:, y1:y2, x1:x2] = unpatched
315 | _x.append(pixels)
316 | current_index += token_length
317 | _x = torch.stack(_x, dim=1)
318 | x.append(_x)
319 | x = torch.stack(x, dim=0)
320 | return x
321 |
322 | def forward(self, z_2d, box, use_layers=None):
323 | z_2d = z_2d.transpose(0, 1).unsqueeze(0)
324 | use_layers = use_layers or [list(range(z_2d.shape[2]))]
325 | z, freqs_cis = self.encode(z_2d, box, use_layers)
326 | H, W = z_2d.shape[-2:]
327 | x_hat = self.decode(z, freqs_cis, box, H * 8, W * 8)
328 | assert x_hat.shape[0] == 1, x_hat.shape
329 | x_hat = einops.rearrange(x_hat[0], "c t h w -> t c h w")
330 | x_hat_rgb, x_hat_alpha = x_hat[:, :3], x_hat[:, 3:]
331 | return x_hat_rgb, x_hat_alpha
332 |
--------------------------------------------------------------------------------
/multi_layer_gen/custom_model_mmdit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Any, Dict, List, Optional, Union, Tuple
4 |
5 | from accelerate.utils import set_module_tensor_to_device
6 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
7 | from diffusers.models.normalization import AdaLayerNormContinuous
8 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
9 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock
10 |
11 | from diffusers.configuration_utils import register_to_config
12 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
13 |
14 |
15 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16 |
17 |
18 | class CustomFluxTransformer2DModel(FluxTransformer2DModel):
19 | """
20 | The Transformer model introduced in Flux.
21 |
22 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
23 |
24 | Parameters:
25 | patch_size (`int`): Patch size to turn the input data into small patches.
26 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
27 | num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
28 | num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
29 | attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
30 | num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
31 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
32 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
33 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
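   | max_layer_num (`int`, *optional*, defaults to 52): Maximum number of image layers supported by the learned layer positional embedding (layer_pe).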
34 | """
35 |
36 | @register_to_config
37 | def __init__(
38 | self,
39 | patch_size: int = 1,
40 | in_channels: int = 64,
41 | num_layers: int = 19,
42 | num_single_layers: int = 38,
43 | attention_head_dim: int = 128,
44 | num_attention_heads: int = 24,
45 | joint_attention_dim: int = 4096,
46 | pooled_projection_dim: int = 768,
47 | guidance_embeds: bool = False,
48 | axes_dims_rope: Tuple[int] = (16, 56, 56),
49 | max_layer_num: int = 52,
50 | ):
51 | super(FluxTransformer2DModel, self).__init__()
52 | self.out_channels = in_channels
53 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
54 |
55 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
56 |
57 | text_time_guidance_cls = (
58 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
59 | )
60 | self.time_text_embed = text_time_guidance_cls(
61 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
62 | )
63 |
64 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
65 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
66 |
67 | self.transformer_blocks = nn.ModuleList(
68 | [
69 | FluxTransformerBlock(
70 | dim=self.inner_dim,
71 | num_attention_heads=self.config.num_attention_heads,
72 | attention_head_dim=self.config.attention_head_dim,
73 | )
74 | for i in range(self.config.num_layers)
75 | ]
76 | )
77 |
78 | self.single_transformer_blocks = nn.ModuleList(
79 | [
80 | FluxSingleTransformerBlock(
81 | dim=self.inner_dim,
82 | num_attention_heads=self.config.num_attention_heads,
83 | attention_head_dim=self.config.attention_head_dim,
84 | )
85 | for i in range(self.config.num_single_layers)
86 | ]
87 | )
88 |
89 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
90 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
91 |
92 | self.gradient_checkpointing = False
93 |
94 | self.max_layer_num = max_layer_num
95 |
96 | # the following process ensures self.layer_pe is not created as a meta tensor
97 | layer_pe_value = nn.init.trunc_normal_(
98 | nn.Parameter(torch.zeros(
99 | 1, self.max_layer_num, 1, 1, self.inner_dim,
100 | )),
101 | mean=0.0, std=0.02, a=-2.0, b=2.0,
102 | ).data.detach()
103 | self.layer_pe = nn.Parameter(layer_pe_value)
104 | set_module_tensor_to_device(
105 | self,
106 | 'layer_pe',
107 | device='cpu',
108 | value=layer_pe_value,
109 | dtype=layer_pe_value.dtype,
110 | )
111 |
112 | @classmethod
113 | def from_pretrained(cls, *args, **kwarg):
114 | model = super().from_pretrained(*args, **kwarg)
115 | for name, para in model.named_parameters():
116 | if name != 'layer_pe':
117 | device = para.device
118 | break
119 | model.layer_pe.to(device)
120 | return model
121 |
122 | def crop_each_layer(self, hidden_states, list_layer_box):
123 | """
124 | hidden_states: [1, n_layers, h, w, inner_dim]
125 | list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
126 | """
127 | token_list = []
128 | for layer_idx in range(hidden_states.shape[1]):
129 | if list_layer_box[layer_idx] is None:
130 | continue
131 | else:
132 | x1, y1, x2, y2 = list_layer_box[layer_idx]
133 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
134 | layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :]
135 | bs, h, w, c = layer_token.shape
136 | layer_token = layer_token.reshape(bs, -1, c)
137 | token_list.append(layer_token)
138 | result = torch.cat(token_list, dim=1)
139 | return result
140 |
141 | def fill_in_processed_tokens(self, hidden_states, full_hidden_states, list_layer_box):
142 | """
143 | hidden_states: [1, h1xw1 + h2xw2 + ... + hlxwl , inner_dim]
144 | full_hidden_states: [1, n_layers, h, w, inner_dim]
145 | list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
146 | """
147 | used_token_len = 0
148 | bs = hidden_states.shape[0]
149 | for layer_idx in range(full_hidden_states.shape[1]):
150 | if list_layer_box[layer_idx] is None:
151 | continue
152 | else:
153 | x1, y1, x2, y2 = list_layer_box[layer_idx]
154 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
155 | full_hidden_states[:, layer_idx, y1:y2, x1:x2, :] = hidden_states[:, used_token_len: used_token_len + (y2-y1) * (x2-x1), :].reshape(bs, y2-y1, x2-x1, -1)
156 | used_token_len = used_token_len + (y2-y1) * (x2-x1)
157 | return full_hidden_states
158 |
159 | def forward(
160 | self,
161 | hidden_states: torch.Tensor,
162 | list_layer_box: List[Tuple] = None,
163 | encoder_hidden_states: torch.Tensor = None,
164 | pooled_projections: torch.Tensor = None,
165 | timestep: torch.LongTensor = None,
166 | img_ids: torch.Tensor = None,
167 | txt_ids: torch.Tensor = None,
168 | guidance: torch.Tensor = None,
169 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
170 | return_dict: bool = True,
171 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
172 | """
173 | The [`FluxTransformer2DModel`] forward method.
174 |
175 | Args:
176 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
177 | Input `hidden_states`.
178 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
179 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
180 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
181 | from the embeddings of input conditions.
182 | timestep ( `torch.LongTensor`):
183 | Used to indicate denoising step.
184 | list_layer_box (`list` of `tuple`, *optional*):
185 | Per-layer bounding boxes (x1, y1, x2, y2) in pixel coordinates; `None` entries mark unused layers.
186 | joint_attention_kwargs (`dict`, *optional*):
187 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
188 | `self.processor` in
189 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
190 | return_dict (`bool`, *optional*, defaults to `True`):
191 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
192 | tuple.
193 |
194 | Returns:
195 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
196 | `tuple` where the first element is the sample tensor.
197 | """
198 | if joint_attention_kwargs is not None:
199 | joint_attention_kwargs = joint_attention_kwargs.copy()
200 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
201 | else:
202 | lora_scale = 1.0
203 |
204 | if USE_PEFT_BACKEND:
205 | # weight the lora layers by setting `lora_scale` for each PEFT layer
206 | scale_lora_layers(self, lora_scale)
207 | else:
208 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
209 | logger.warning(
210 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
211 | )
212 |
213 | bs, n_layers, channel_latent, height, width = hidden_states.shape # [bs, n_layers, c_latent, h, w]
214 |
215 | hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2) # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
216 | hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
217 | hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4) # [bs, n_layers, h/2, w/2, c_latent*4]
218 | hidden_states = self.x_embedder(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
219 |
220 | full_hidden_states = torch.zeros_like(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim]
221 | layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim) # [1, max_n_layers, 1, 1, inner_dim]
222 | hidden_states = hidden_states + layer_pe[:, :n_layers] # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, f, h/2, w/2, inner_dim]
223 | hidden_states = self.crop_each_layer(hidden_states, list_layer_box) # [bs, token_len, inner_dim]
224 |
225 | timestep = timestep.to(hidden_states.dtype) * 1000
226 | if guidance is not None:
227 | guidance = guidance.to(hidden_states.dtype) * 1000
230 | temb = (
231 | self.time_text_embed(timestep, pooled_projections)
232 | if guidance is None
233 | else self.time_text_embed(timestep, guidance, pooled_projections)
234 | )
235 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
236 |
237 | if txt_ids.ndim == 3:
238 | logger.warning(
239 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
240 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
241 | )
242 | txt_ids = txt_ids[0]
243 | if img_ids.ndim == 3:
244 | logger.warning(
245 | "Passing `img_ids` 3d torch.Tensor is deprecated."
246 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
247 | )
248 | img_ids = img_ids[0]
249 | ids = torch.cat((txt_ids, img_ids), dim=0)
250 | image_rotary_emb = self.pos_embed(ids)
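   | # RoPE is computed jointly over text and image ids; for image tokens the first id axis carries the layer index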
251 |
252 | for index_block, block in enumerate(self.transformer_blocks):
253 | if self.training and self.gradient_checkpointing:
254 |
255 | def create_custom_forward(module, return_dict=None):
256 | def custom_forward(*inputs):
257 | if return_dict is not None:
258 | return module(*inputs, return_dict=return_dict)
259 | else:
260 | return module(*inputs)
261 |
262 | return custom_forward
263 |
264 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
265 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
266 | create_custom_forward(block),
267 | hidden_states,
268 | encoder_hidden_states,
269 | temb,
270 | image_rotary_emb,
271 | **ckpt_kwargs,
272 | )
273 |
274 | else:
275 | encoder_hidden_states, hidden_states = block(
276 | hidden_states=hidden_states,
277 | encoder_hidden_states=encoder_hidden_states,
278 | temb=temb,
279 | image_rotary_emb=image_rotary_emb,
280 | )
281 |
282 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
283 |
284 | for index_block, block in enumerate(self.single_transformer_blocks):
285 | if self.training and self.gradient_checkpointing:
286 |
287 | def create_custom_forward(module, return_dict=None):
288 | def custom_forward(*inputs):
289 | if return_dict is not None:
290 | return module(*inputs, return_dict=return_dict)
291 | else:
292 | return module(*inputs)
293 |
294 | return custom_forward
295 |
296 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
297 | hidden_states = torch.utils.checkpoint.checkpoint(
298 | create_custom_forward(block),
299 | hidden_states,
300 | temb,
301 | image_rotary_emb,
302 | **ckpt_kwargs,
303 | )
304 |
305 | else:
306 | hidden_states = block(
307 | hidden_states=hidden_states,
308 | temb=temb,
309 | image_rotary_emb=image_rotary_emb,
310 | )
311 |
312 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
313 |
314 | hidden_states = self.fill_in_processed_tokens(hidden_states, full_hidden_states, list_layer_box) # [bs, n_layers, h/2, w/2, inner_dim]
315 | hidden_states = hidden_states.view(bs, -1, self.inner_dim) # [bs, n_layers * full_len, inner_dim]
316 |
317 | hidden_states = self.norm_out(hidden_states, temb) # [bs, n_layers * full_len, inner_dim]
318 | hidden_states = self.proj_out(hidden_states) # [bs, n_layers * full_len, c_latent*4]
319 |
320 | # unpatchify
321 | hidden_states = hidden_states.view(bs, n_layers, height//2, width//2, channel_latent, 2, 2) # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
322 | hidden_states = hidden_states.permute(0, 1, 4, 2, 5, 3, 6)
323 | output = hidden_states.reshape(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
324 |
325 | if USE_PEFT_BACKEND:
326 | # remove `lora_scale` from each PEFT layer
327 | unscale_lora_layers(self, lora_scale)
328 |
329 | if not return_dict:
330 | return (output,)
331 |
332 | return Transformer2DModelOutput(sample=output)
--------------------------------------------------------------------------------
/multi_layer_gen/custom_pipeline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Any, Callable, Dict, List, Optional, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from diffusers.utils.torch_utils import randn_tensor
8 | from diffusers.utils import is_torch_xla_available, logging
9 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
10 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline
11 |
12 | if is_torch_xla_available():
13 | import torch_xla.core.xla_model as xm # type: ignore
14 | XLA_AVAILABLE = True
15 | else:
16 | XLA_AVAILABLE = False
17 |
18 |
19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20 |
21 |
22 | def _get_clip_prompt_embeds(
23 | tokenizer,
24 | text_encoder,
25 | prompt: Union[str, List[str]],
26 | num_images_per_prompt: int = 1,
27 | device: Optional[torch.device] = None,
28 | ):
29 | device = device or text_encoder.device
30 | dtype = text_encoder.dtype
31 |
32 | prompt = [prompt] if isinstance(prompt, str) else prompt
33 | batch_size = len(prompt)
34 |
35 | text_inputs = tokenizer(
36 | prompt,
37 | padding="max_length",
38 | max_length=text_encoder.config.max_position_embeddings,
39 | truncation=True,
40 | return_overflowing_tokens=False,
41 | return_length=False,
42 | return_tensors="pt",
43 | )
44 |
45 | text_input_ids = text_inputs.input_ids
46 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
47 |
48 | # Use pooled output of CLIPTextModel
49 | prompt_embeds = prompt_embeds.pooler_output
50 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
51 |
52 | # duplicate text embeddings for each generation per prompt, using mps friendly method
53 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
54 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
55 |
56 | return prompt_embeds
57 |
58 |
59 | def _get_t5_prompt_embeds(
60 | tokenizer,
61 | text_encoder,
62 | prompt: Union[str, List[str]] = None,
63 | num_images_per_prompt: int = 1,
64 | max_sequence_length: int = 512,
65 | device: Optional[torch.device] = None,
66 | dtype: Optional[torch.dtype] = None,
67 | ):
68 | device = device or text_encoder.device
69 | dtype = dtype or text_encoder.dtype
70 |
71 | prompt = [prompt] if isinstance(prompt, str) else prompt
72 | batch_size = len(prompt)
73 |
74 | text_inputs = tokenizer(
75 | prompt,
76 | padding="max_length",
77 | max_length=max_sequence_length,
78 | truncation=True,
79 | return_length=False,
80 | return_overflowing_tokens=False,
81 | return_tensors="pt",
82 | )
83 | text_input_ids = text_inputs.input_ids
84 |
85 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
86 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87 |
88 | _, seq_len, _ = prompt_embeds.shape
89 |
90 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93 |
94 | return prompt_embeds
95 |
96 |
97 | def encode_prompt(
98 | tokenizers,
99 | text_encoders,
100 | prompt: Union[str, List[str]],
101 | prompt_2: Union[str, List[str]] = None,
102 | num_images_per_prompt: int = 1,
103 | max_sequence_length: int = 512,
104 | ):
105 |
106 | tokenizer_1, tokenizer_2 = tokenizers
107 | text_encoder_1, text_encoder_2 = text_encoders
108 | device = text_encoder_1.device
109 | dtype = text_encoder_1.dtype
110 |
111 | prompt = [prompt] if isinstance(prompt, str) else prompt
112 | prompt_2 = prompt_2 or prompt
113 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
114 |
115 | # We only use the pooled prompt output from the CLIPTextModel
116 | pooled_prompt_embeds = _get_clip_prompt_embeds(
117 | tokenizer=tokenizer_1,
118 | text_encoder=text_encoder_1,
119 | prompt=prompt,
120 | num_images_per_prompt=num_images_per_prompt,
121 | )
122 | prompt_embeds = _get_t5_prompt_embeds(
123 | tokenizer=tokenizer_2,
124 | text_encoder=text_encoder_2,
125 | prompt=prompt_2,
126 | num_images_per_prompt=num_images_per_prompt,
127 | max_sequence_length=max_sequence_length,
128 | )
129 |
130 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
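   | # Flux uses all-zero position ids for text tokens; only image tokens carry (layer, row, col) ids in this pipeline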
131 |
132 | return prompt_embeds, pooled_prompt_embeds, text_ids
133 |
134 |
135 | class CustomFluxPipeline(FluxPipeline):
136 |
137 | @staticmethod
138 | def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
139 |
140 | latent_image_ids_list = []
141 | for layer_idx in range(len(list_layer_box)):
142 | if list_layer_box[layer_idx] is None:
143 | continue
144 | else:
145 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
146 | latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
147 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
148 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
149 |
150 | x1, y1, x2, y2 = list_layer_box[layer_idx]
151 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
152 | latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
153 |
154 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
155 | latent_image_ids = latent_image_ids.reshape(
156 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
157 | )
158 |
159 | latent_image_ids_list.append(latent_image_ids)
160 |
161 | full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
162 |
163 | return full_latent_image_ids.to(device=device, dtype=dtype)
164 |
165 | def prepare_latents(
166 | self,
167 | batch_size,
168 | num_layers,
169 | num_channels_latents,
170 | height,
171 | width,
172 | list_layer_box,
173 | dtype,
174 | device,
175 | generator,
176 | latents=None,
177 | ):
178 | height = 2 * (int(height) // self.vae_scale_factor)
179 | width = 2 * (int(width) // self.vae_scale_factor)
180 |
181 | shape = (batch_size, num_layers, num_channels_latents, height, width)
182 |
183 | if latents is not None:
184 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
185 | return latents.to(device=device, dtype=dtype), latent_image_ids
186 |
187 | if isinstance(generator, list) and len(generator) != batch_size:
188 | raise ValueError(
189 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
190 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
191 | )
192 |
193 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, f, c_latent, h, w]
194 |
195 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
196 |
197 | return latents, latent_image_ids
198 |
199 | @torch.no_grad()
200 | def __call__(
201 | self,
202 | prompt: Union[str, List[str]] = None,
203 | prompt_2: Optional[Union[str, List[str]]] = None,
204 | validation_box: List[tuple] = None,
205 | height: Optional[int] = None,
206 | width: Optional[int] = None,
207 | num_inference_steps: int = 28,
208 | timesteps: List[int] = None,
209 | guidance_scale: float = 3.5,
210 | num_images_per_prompt: Optional[int] = 1,
211 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
212 | latents: Optional[torch.FloatTensor] = None,
213 | prompt_embeds: Optional[torch.FloatTensor] = None,
214 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
215 | output_type: Optional[str] = "pil",
216 | return_dict: bool = True,
217 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
218 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
219 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
220 | max_sequence_length: int = 512,
221 | num_layers: int = 5,
222 | sdxl_vae: nn.Module = None,
223 | transparent_decoder: nn.Module = None,
224 | ):
225 | r"""
226 | Function invoked when calling the pipeline for generation.
227 |
228 | Args:
229 | prompt (`str` or `List[str]`, *optional*):
230 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
231 | instead.
232 | prompt_2 (`str` or `List[str]`, *optional*):
233 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
234 | will be used instead.
235 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
236 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
237 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
238 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
239 | num_inference_steps (`int`, *optional*, defaults to 28):
240 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
241 | expense of slower inference.
242 | timesteps (`List[int]`, *optional*):
243 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
244 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
245 | passed will be used. Must be in descending order.
246 | guidance_scale (`float`, *optional*, defaults to 3.5):
247 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
248 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
249 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
250 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
251 | usually at the expense of lower image quality.
252 | num_images_per_prompt (`int`, *optional*, defaults to 1):
253 | The number of images to generate per prompt.
254 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
255 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
256 | to make generation deterministic.
257 | latents (`torch.FloatTensor`, *optional*):
258 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
259 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
260 | tensor will be generated by sampling using the supplied random `generator`.
261 | prompt_embeds (`torch.FloatTensor`, *optional*):
262 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263 | provided, text embeddings will be generated from `prompt` input argument.
264 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
265 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
266 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
267 | output_type (`str`, *optional*, defaults to `"pil"`):
268 | The output format of the generate image. Choose between
269 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
270 | return_dict (`bool`, *optional*, defaults to `True`):
271 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
272 | joint_attention_kwargs (`dict`, *optional*):
273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274 | `self.processor` in
275 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276 | callback_on_step_end (`Callable`, *optional*):
277 | A function that calls at the end of each denoising steps during the inference. The function is called
278 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
279 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
280 | `callback_on_step_end_tensor_inputs`.
281 | callback_on_step_end_tensor_inputs (`List`, *optional*):
282 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
283 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
284 | `._callback_tensor_inputs` attribute of your pipeline class.
285 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
286 |
287 | Examples:
288 |
289 | Returns:
290 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
291 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
292 | images.
293 | """
294 |
295 | height = height or self.default_sample_size * self.vae_scale_factor
296 | width = width or self.default_sample_size * self.vae_scale_factor
297 |
298 | # 1. Check inputs. Raise error if not correct
299 | self.check_inputs(
300 | prompt,
301 | prompt_2,
302 | height,
303 | width,
304 | prompt_embeds=prompt_embeds,
305 | pooled_prompt_embeds=pooled_prompt_embeds,
306 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
307 | max_sequence_length=max_sequence_length,
308 | )
309 |
310 | self._guidance_scale = guidance_scale
311 | self._joint_attention_kwargs = joint_attention_kwargs
312 | self._interrupt = False
313 |
314 | # 2. Define call parameters
315 | if prompt is not None and isinstance(prompt, str):
316 | batch_size = 1
317 | elif prompt is not None and isinstance(prompt, list):
318 | batch_size = len(prompt)
319 | else:
320 | batch_size = prompt_embeds.shape[0]
321 |
322 | device = self._execution_device
323 |
324 | lora_scale = (
325 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
326 | )
327 | (
328 | prompt_embeds,
329 | pooled_prompt_embeds,
330 | text_ids,
331 | ) = self.encode_prompt(
332 | prompt=prompt,
333 | prompt_2=prompt_2,
334 | prompt_embeds=prompt_embeds,
335 | pooled_prompt_embeds=pooled_prompt_embeds,
336 | device=device,
337 | num_images_per_prompt=num_images_per_prompt,
338 | max_sequence_length=max_sequence_length,
339 | lora_scale=lora_scale,
340 | )
341 |
342 | # 4. Prepare latent variables
343 | num_channels_latents = self.transformer.config.in_channels // 4
344 | latents, latent_image_ids = self.prepare_latents(
345 | batch_size * num_images_per_prompt,
346 | num_layers,
347 | num_channels_latents,
348 | height,
349 | width,
350 | validation_box,
351 | prompt_embeds.dtype,
352 | device,
353 | generator,
354 | latents,
355 | )
356 |
357 | # 5. Prepare timesteps
358 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
359 | image_seq_len = latent_image_ids.shape[0]  # total number of image tokens across all layer crops
360 | mu = calculate_shift(
361 | image_seq_len,
362 | self.scheduler.config.base_image_seq_len,
363 | self.scheduler.config.max_image_seq_len,
364 | self.scheduler.config.base_shift,
365 | self.scheduler.config.max_shift,
366 | )
367 | timesteps, num_inference_steps = retrieve_timesteps(
368 | self.scheduler,
369 | num_inference_steps,
370 | device,
371 | timesteps,
372 | sigmas,
373 | mu=mu,
374 | )
375 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
376 | self._num_timesteps = len(timesteps)
377 |
378 | # handle guidance
379 | if self.transformer.config.guidance_embeds:
380 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
381 | guidance = guidance.expand(latents.shape[0])
382 | else:
383 | guidance = None
384 |
385 | # 6. Denoising loop
386 | with self.progress_bar(total=num_inference_steps) as progress_bar:
387 | for i, t in enumerate(timesteps):
388 | if self.interrupt:
389 | continue
390 |
391 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
392 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
393 |
394 | noise_pred = self.transformer(
395 | hidden_states=latents,
396 | list_layer_box=validation_box,
397 | timestep=timestep / 1000,
398 | guidance=guidance,
399 | pooled_projections=pooled_prompt_embeds,
400 | encoder_hidden_states=prompt_embeds,
401 | txt_ids=text_ids,
402 | img_ids=latent_image_ids,
403 | joint_attention_kwargs=self.joint_attention_kwargs,
404 | return_dict=False,
405 | )[0]
406 |
407 | # compute the previous noisy sample x_t -> x_t-1
408 | latents_dtype = latents.dtype
409 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
410 |
411 | if latents.dtype != latents_dtype:
412 | if torch.backends.mps.is_available():
413 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
414 | latents = latents.to(latents_dtype)
415 |
416 | if callback_on_step_end is not None:
417 | callback_kwargs = {}
418 | for k in callback_on_step_end_tensor_inputs:
419 | callback_kwargs[k] = locals()[k]
420 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
421 |
422 | latents = callback_outputs.pop("latents", latents)
423 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
424 |
425 | # call the callback, if provided
426 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
427 | progress_bar.update()
428 |
429 | if XLA_AVAILABLE:
430 | xm.mark_step()
431 |
432 | # create a grey latent
433 | bs, n_frames, channel_latent, height, width = latents.shape
434 |
435 | pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
436 | latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
437 | latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
438 | latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w]
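   | # zeros in [-1, 1] pixel space encode to a mid-grey image, so regions outside each layer's box decode to grey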
439 |
440 | # fill in the latents
441 | for layer_idx in range(latent_grey.shape[1]):
442 | if validation_box[layer_idx] is None: continue  # keep the grey fill for layers without a box
443 | x1, y1, x2, y2 = [v // 8 for v in validation_box[layer_idx]]
444 | latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
445 | latents = latent_grey
446 |
447 | if output_type == "latent":
448 | image = latents
449 |
450 | else:
451 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
452 | latents = latents.reshape(bs * n_frames, channel_latent, height, width)
453 | image = self.vae.decode(latents, return_dict=False)[0]
454 | if sdxl_vae is not None:
455 | sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
456 | sdxl_latents = sdxl_vae.encode(image).latent_dist.sample()
457 | transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
458 | result_list, vis_list = transparent_decoder(sdxl_vae, sdxl_latents)
459 | else:
460 | result_list, vis_list = None, None
461 | image = self.image_processor.postprocess(image, output_type=output_type)
462 |
463 | # Offload all models
464 | self.maybe_free_model_hooks()
465 |
466 | if not return_dict:
467 | return (image, result_list, vis_list)
468 |
469 | return FluxPipelineOutput(images=image), result_list, vis_list
470 |
471 |
472 | class CustomFluxPipelineCfg(FluxPipeline):
473 |
474 | @staticmethod
475 | def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
476 |
477 | latent_image_ids_list = []
478 | for layer_idx in range(len(list_layer_box)):
479 | if list_layer_box[layer_idx] is None:
480 | continue
481 | else:
482 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
483 | latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
484 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
485 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
486 |
487 | x1, y1, x2, y2 = list_layer_box[layer_idx]
488 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
489 | latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
490 |
491 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
492 | latent_image_ids = latent_image_ids.reshape(
493 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
494 | )
495 |
496 | latent_image_ids_list.append(latent_image_ids)
497 |
498 | full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
499 |
500 | return full_latent_image_ids.to(device=device, dtype=dtype)
501 |
502 | def prepare_latents(
503 | self,
504 | batch_size,
505 | num_layers,
506 | num_channels_latents,
507 | height,
508 | width,
509 | list_layer_box,
510 | dtype,
511 | device,
512 | generator,
513 | latents=None,
514 | ):
515 | height = 2 * (int(height) // self.vae_scale_factor)
516 | width = 2 * (int(width) // self.vae_scale_factor)
517 |
518 | shape = (batch_size, num_layers, num_channels_latents, height, width)
519 |
520 | if latents is not None:
521 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
522 | return latents.to(device=device, dtype=dtype), latent_image_ids
523 |
524 | if isinstance(generator, list) and len(generator) != batch_size:
525 | raise ValueError(
526 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
527 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
528 | )
529 |
530 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, n_layers, c_latent, h, w]
531 |
532 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
533 |
534 | return latents, latent_image_ids
535 |
536 | @torch.no_grad()
537 | def __call__(
538 | self,
539 | prompt: Union[str, List[str]] = None,
540 | prompt_2: Optional[Union[str, List[str]]] = None,
541 | validation_box: List[tuple] = None,
542 | height: Optional[int] = None,
543 | width: Optional[int] = None,
544 | num_inference_steps: int = 28,
545 | timesteps: List[int] = None,
546 | guidance_scale: float = 3.5,
547 | true_gs: float = 3.5,
548 | num_images_per_prompt: Optional[int] = 1,
549 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
550 | latents: Optional[torch.FloatTensor] = None,
551 | prompt_embeds: Optional[torch.FloatTensor] = None,
552 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
553 | output_type: Optional[str] = "pil",
554 | return_dict: bool = True,
555 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
556 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
557 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
558 | max_sequence_length: int = 512,
559 | num_layers: int = 5,
560 | transparent_decoder: nn.Module = None,
561 | ):
562 | r"""
563 | Function invoked when calling the pipeline for generation.
564 |
565 | Args:
566 | prompt (`str` or `List[str]`, *optional*):
567 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
568 | instead.
569 | prompt_2 (`str` or `List[str]`, *optional*):
570 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
571 | will be used instead.
572 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
573 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
574 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
575 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
576 | num_inference_steps (`int`, *optional*, defaults to 28):
577 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
578 | expense of slower inference.
579 | timesteps (`List[int]`, *optional*):
580 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
581 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
582 | passed will be used. Must be in descending order.
583 | guidance_scale (`float`, *optional*, defaults to 3.5):
584 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
585 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
586 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
587 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
588 | usually at the expense of lower image quality.
589 | num_images_per_prompt (`int`, *optional*, defaults to 1):
590 | The number of images to generate per prompt.
591 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
592 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
593 | to make generation deterministic.
594 | latents (`torch.FloatTensor`, *optional*):
595 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
596 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
597 | tensor will be generated by sampling using the supplied random `generator`.
598 | prompt_embeds (`torch.FloatTensor`, *optional*):
599 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
600 | provided, text embeddings will be generated from `prompt` input argument.
601 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
602 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
603 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
604 | output_type (`str`, *optional*, defaults to `"pil"`):
605 | The output format of the generated image. Choose between
606 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
607 | return_dict (`bool`, *optional*, defaults to `True`):
608 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
609 | joint_attention_kwargs (`dict`, *optional*):
610 | A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
611 | `self.processor` in
612 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
613 | callback_on_step_end (`Callable`, *optional*):
614 | A function that is called at the end of each denoising step during inference. The function is called
615 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
616 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
617 | `callback_on_step_end_tensor_inputs`.
618 | callback_on_step_end_tensor_inputs (`List`, *optional*):
619 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
620 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
621 | `._callback_tensor_inputs` attribute of your pipeline class.
622 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
623 |
624 | Examples:
625 |
626 | Returns:
627 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
628 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
629 | images.
630 | """
631 |
632 | height = height or self.default_sample_size * self.vae_scale_factor
633 | width = width or self.default_sample_size * self.vae_scale_factor
634 |
635 | # 1. Check inputs. Raise error if not correct
636 | self.check_inputs(
637 | prompt,
638 | prompt_2,
639 | height,
640 | width,
641 | prompt_embeds=prompt_embeds,
642 | pooled_prompt_embeds=pooled_prompt_embeds,
643 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
644 | max_sequence_length=max_sequence_length,
645 | )
646 |
647 | self._guidance_scale = guidance_scale
648 | self._joint_attention_kwargs = joint_attention_kwargs
649 | self._interrupt = False
650 |
651 | # 2. Define call parameters
652 | if prompt is not None and isinstance(prompt, str):
653 | batch_size = 1
654 | elif prompt is not None and isinstance(prompt, list):
655 | batch_size = len(prompt)
656 | else:
657 | batch_size = prompt_embeds.shape[0]
658 |
659 | device = self._execution_device
660 | # 3. Encode the prompt (and an empty negative prompt for true CFG)
661 | lora_scale = (
662 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
663 | )
664 | (
665 | prompt_embeds,
666 | pooled_prompt_embeds,
667 | text_ids,
668 | ) = self.encode_prompt(
669 | prompt=prompt,
670 | prompt_2=prompt_2,
671 | prompt_embeds=prompt_embeds,
672 | pooled_prompt_embeds=pooled_prompt_embeds,
673 | device=device,
674 | num_images_per_prompt=num_images_per_prompt,
675 | max_sequence_length=max_sequence_length,
676 | lora_scale=lora_scale,
677 | )
678 | (
679 | neg_prompt_embeds,
680 | neg_pooled_prompt_embeds,
681 | neg_text_ids,
682 | ) = self.encode_prompt(
683 | prompt="",
684 | prompt_2=None,
685 | device=device,
686 | num_images_per_prompt=num_images_per_prompt,
687 | max_sequence_length=max_sequence_length,
688 | lora_scale=lora_scale,
689 | )
690 |
691 | # 4. Prepare latent variables
692 | num_channels_latents = self.transformer.config.in_channels // 4
693 | latents, latent_image_ids = self.prepare_latents(
694 | batch_size * num_images_per_prompt,
695 | num_layers,
696 | num_channels_latents,
697 | height,
698 | width,
699 | validation_box,
700 | prompt_embeds.dtype,
701 | device,
702 | generator,
703 | latents,
704 | )
705 |
706 | # 5. Prepare timesteps
707 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
708 | image_seq_len = latent_image_ids.shape[0]
709 | mu = calculate_shift(
710 | image_seq_len,
711 | self.scheduler.config.base_image_seq_len,
712 | self.scheduler.config.max_image_seq_len,
713 | self.scheduler.config.base_shift,
714 | self.scheduler.config.max_shift,
715 | )
716 | timesteps, num_inference_steps = retrieve_timesteps(
717 | self.scheduler,
718 | num_inference_steps,
719 | device,
720 | timesteps,
721 | sigmas,
722 | mu=mu,
723 | )
724 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
725 | self._num_timesteps = len(timesteps)
726 |
727 | # handle guidance
728 | if self.transformer.config.guidance_embeds:
729 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
730 | guidance = guidance.expand(latents.shape[0])
731 | else:
732 | guidance = None
733 |
734 | # 6. Denoising loop
735 | with self.progress_bar(total=num_inference_steps) as progress_bar:
736 | for i, t in enumerate(timesteps):
737 | if self.interrupt:
738 | continue
739 |
740 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
741 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
742 |
743 | noise_pred = self.transformer(
744 | hidden_states=latents,
745 | list_layer_box=validation_box,
746 | timestep=timestep / 1000,
747 | guidance=guidance,
748 | pooled_projections=pooled_prompt_embeds,
749 | encoder_hidden_states=prompt_embeds,
750 | txt_ids=text_ids,
751 | img_ids=latent_image_ids,
752 | joint_attention_kwargs=self.joint_attention_kwargs,
753 | return_dict=False,
754 | )[0]
755 |
756 | neg_noise_pred = self.transformer(
757 | hidden_states=latents,
758 | list_layer_box=validation_box,
759 | timestep=timestep / 1000,
760 | guidance=guidance,
761 | pooled_projections=neg_pooled_prompt_embeds,
762 | encoder_hidden_states=neg_prompt_embeds,
763 | txt_ids=neg_text_ids,
764 | img_ids=latent_image_ids,
765 | joint_attention_kwargs=self.joint_attention_kwargs,
766 | return_dict=False,
767 | )[0]
768 |
769 | noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)  # classifier-free guidance against the empty-prompt prediction, weighted by true_gs
770 |
771 | # compute the previous noisy sample x_t -> x_t-1
772 | latents_dtype = latents.dtype
773 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
774 |
775 | if latents.dtype != latents_dtype:
776 | if torch.backends.mps.is_available():
777 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
778 | latents = latents.to(latents_dtype)
779 |
780 | if callback_on_step_end is not None:
781 | callback_kwargs = {}
782 | for k in callback_on_step_end_tensor_inputs:
783 | callback_kwargs[k] = locals()[k]
784 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
785 |
786 | latents = callback_outputs.pop("latents", latents)
787 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
788 |
789 | # call the callback, if provided
790 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
791 | progress_bar.update()
792 |
793 | if XLA_AVAILABLE:
794 | xm.mark_step()
795 |
796 | # create a grey latent
797 | bs, n_layers, channel_latent, height, width = latents.shape  # note: height/width are now latent-space sizes
798 |
799 | pixel_grey = torch.zeros(size=(bs*n_layers, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)  # zeros map to mid-grey in the VAE's [-1, 1] pixel range
800 | latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
801 | latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
802 | latent_grey = latent_grey.view(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
803 |
804 | # fill in the latents
805 | for layer_idx in range(latent_grey.shape[1]):
806 | if validation_box[layer_idx] is None:
807 | continue
808 | x1, y1, x2, y2 = validation_box[layer_idx]
809 | x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8  # pixel coords -> latent coords (VAE downsamples by 8)
810 | latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
811 | latents = latent_grey
812 |
813 | if output_type == "latent":
814 | image = latents
815 |
816 | else:
817 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
818 | latents = latents.reshape(bs * n_layers, channel_latent, height, width)
819 | latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
820 | image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
821 | image = torch.cat(image_segs, dim=0)
822 | if transparent_decoder is not None:
823 | transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
824 |
825 | decoded_fg, decoded_alpha = transparent_decoder(latents, [validation_box])
826 | decoded_alpha = (decoded_alpha + 1.0) / 2.0
827 | decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1)
828 |
829 | decoded_fg = (decoded_fg + 1.0) / 2.0
830 | decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)
831 |
832 | vis_list = None
833 | png = torch.cat([decoded_fg, decoded_alpha], dim=3)
834 | result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
835 | else:
836 | result_list, vis_list = None, None
837 | image = self.image_processor.postprocess(image, output_type=output_type)
838 |
839 | # Offload all models
840 | self.maybe_free_model_hooks()
841 |
842 | if not return_dict:
843 | return (image, result_list, vis_list, latents)
844 |
845 | return FluxPipelineOutput(images=image), result_list, vis_list, latents
--------------------------------------------------------------------------------
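A minimal usage sketch for the `__call__` defined above. The pipeline class name `CustomFluxPipeline`, the checkpoint path, and the `load_transparent_vae` helper are placeholders, not names confirmed by this file; only the call arguments and the returned four-tuple follow the code above.

import torch
from PIL import Image

# Placeholders/assumptions: substitute the actual pipeline class and weights used by this repo.
pipe = CustomFluxPipeline.from_pretrained(
    "path/to/flux-checkpoint", torch_dtype=torch.bfloat16
).to("cuda")
transp_vae = load_transparent_vae("path/to/transp-vae")  # hypothetical loader for the RGBA decoder

# One (x1, y1, x2, y2) pixel box per layer; None leaves that layer as the grey latent.
validation_box = [
    (0, 0, 512, 512),     # background layer
    (64, 64, 320, 320),   # foreground layer
    None,                 # unused layer slot
]

output, result_list, vis_list, latents = pipe(
    prompt="a poster with a red car and bold title text",
    validation_box=validation_box,
    height=512,
    width=512,
    num_inference_steps=28,
    guidance_scale=3.5,   # embedded guidance fed to the transformer
    true_gs=3.5,          # weight of the explicit CFG mix with the empty-prompt branch
    num_layers=len(validation_box),
    generator=torch.Generator("cuda").manual_seed(0),
    transparent_decoder=transp_vae,
)

# output.images: per-layer RGB decodes (PIL by default); result_list: uint8 RGBA layers
# produced by the transparent decoder (None when no decoder is passed).
Image.fromarray(result_list[0]).save("layer_0.png")

Each `validation_box` entry is given in pixel coordinates and is converted to latent coordinates by dividing by 8, so boxes should be multiples of 8; layers whose box is `None` stay filled with the grey latent computed after the denoising loop.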