├── layout_planner ├── training │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── layout_dataset.py │ │ ├── layoutnew_dataset.py │ │ └── quantizer.py │ ├── utils.py │ └── trainer_layout.py ├── requirements_part1.txt ├── requirements_part2.txt ├── scripts │ └── inference_template.sh ├── config │ └── 13b_z3_no_offload.json ├── models │ └── modeling_layout.py └── inference_layout.py ├── assets ├── teaser.png ├── natural_image_data │ ├── case_1_seed_0 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_1_seed_1 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_1_seed_8 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_2_seed_0 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_2_seed_3 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_2_seed_6 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_3_seed_0 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_3_seed_1 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_3_seed_2 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_4_seed_0 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ ├── case_4_seed_3 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png │ └── case_4_seed_9 │ │ ├── bg.png │ │ ├── whole.png │ │ ├── layer_0.png │ │ ├── layer_1.png │ │ ├── layer_2.png │ │ ├── layer_3.png │ │ └── composite.png └── semi_transparent_image_data │ ├── bg │ ├── case_1_output_1.png │ ├── case_2_output_0.png │ ├── case_2_output_2.png │ ├── case_3_output_0.png │ ├── case_4_output_2.png │ ├── case_5_output_3.png │ ├── case_6_output_3.png │ ├── case_7_output_3.png │ ├── case_8_output_2.png │ ├── case_9_output_1.png │ ├── case_10_output_1.png │ └── case_11_output_1.png │ ├── fg │ ├── case_1_output_1.png │ ├── case_2_output_0.png │ ├── case_2_output_2.png │ ├── case_3_output_0.png │ ├── case_4_output_2.png │ ├── case_5_output_3.png │ ├── case_6_output_3.png │ ├── case_7_output_3.png │ ├── case_8_output_2.png │ ├── case_9_output_1.png │ ├── case_10_output_1.png │ └── case_11_output_1.png │ └── merged │ ├── case_10_output_1.png │ ├── case_11_output_1.png │ ├── case_1_output_1.png │ ├── case_2_output_0.png │ ├── case_2_output_2.png │ ├── case_3_output_0.png │ ├── case_4_output_2.png │ ├── case_5_output_3.png │ ├── case_6_output_3.png │ ├── case_7_output_3.png │ ├── case_8_output_2.png │ └── case_9_output_1.png ├── CODE_OF_CONDUCT.md ├── multi_layer_gen ├── configs │ ├── base.py │ ├── multi_layer_resolution1024_test.py │ └── multi_layer_resolution512_test.py ├── test.py ├── custom_model_transp_vae.py ├── custom_model_mmdit.py └── custom_pipeline.py ├── 
LICENSE ├── SUPPORT.md ├── README.md ├── SECURITY.md ├── example.py └── .gitignore /layout_planner/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /layout_planner/training/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/bg.png 
-------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/bg.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/whole.png 
-------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/whole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/whole.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/layer_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/layer_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/layer_3.png -------------------------------------------------------------------------------- 
/assets/natural_image_data/case_3_seed_1/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_2.png 
-------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_0.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_1.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_2.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/layer_3.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_0/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_0/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_1/composite.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_1/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_1_seed_8/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_1_seed_8/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_0/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_0/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_3/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_3/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_2_seed_6/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_2_seed_6/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_0/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_0/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_1/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_1/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_3_seed_2/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_3_seed_2/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_0/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_0/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_3/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_3/composite.png -------------------------------------------------------------------------------- /assets/natural_image_data/case_4_seed_9/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/natural_image_data/case_4_seed_9/composite.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_1_output_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_1_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_2_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_2_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_2_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_2_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_3_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_3_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_4_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_4_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_5_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_5_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_6_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_6_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_7_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_7_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_8_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_8_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_9_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_9_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_1_output_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_1_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_2_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_2_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_2_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_2_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_3_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_3_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_4_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_4_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_5_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_5_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_6_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_6_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_7_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_7_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_8_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_8_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_9_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_9_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_10_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_10_output_1.png 
-------------------------------------------------------------------------------- /assets/semi_transparent_image_data/bg/case_11_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/bg/case_11_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_10_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_10_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/fg/case_11_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/fg/case_11_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_10_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_10_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_11_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_11_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_1_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_1_output_1.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_2_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_2_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_2_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_2_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_3_output_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_3_output_0.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_4_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_4_output_2.png -------------------------------------------------------------------------------- 
/assets/semi_transparent_image_data/merged/case_5_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_5_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_6_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_6_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_7_output_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_7_output_3.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_8_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_8_output_2.png -------------------------------------------------------------------------------- /assets/semi_transparent_image_data/merged/case_9_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/art-msra/HEAD/assets/semi_transparent_image_data/merged/case_9_output_1.png -------------------------------------------------------------------------------- /layout_planner/requirements_part1.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | matplotlib 3 | diffusers 4 | accelerate==0.27.2 5 | warmup_scheduler 6 | webdataset 7 | tensorboardx 8 | protobuf==3.20.0 9 | tensorboard 10 | torchmetrics 11 | clean-fid 12 | pynvml 13 | open_clip_torch 14 | datasets 15 | skia-python 16 | colorama 17 | opencv-python 18 | deepspeed==0.14.2 -------------------------------------------------------------------------------- /layout_planner/requirements_part2.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | click>=8.0.4,<9 3 | datasets>=2.10.0,<3 4 | transformers==4.39.1 5 | langchain>=0.0.139 6 | wandb 7 | ninja 8 | tensorboard 9 | tensorboardx 10 | peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b 11 | webdataset 12 | tiktoken 13 | einops 14 | bitsandbytes 15 | evaluate 16 | bert_score -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /multi_layer_gen/configs/base.py: -------------------------------------------------------------------------------- 1 | ### Model Settings 2 | pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" 3 | revision = None 4 | variant = None 5 | cache_dir = None 6 | 7 | ### Training Settings 8 | seed = 42 9 | report_to = "wandb" 10 | tracker_project_name = "multilayer" 11 | wandb_job_name = "YOU_FORGET_TO_SET" 12 | logging_dir = "logs" 13 | max_train_steps = None 14 | checkpoints_total_limit = None 15 | 16 | # gpu 17 | allow_tf32 = True 18 | gradient_checkpointing = True 19 | mixed_precision = "bf16" 20 | -------------------------------------------------------------------------------- /layout_planner/scripts/inference_template.sh: -------------------------------------------------------------------------------- 1 | python inference_layout.py \ 2 | --input_model "Your base model path" \ 3 | --resume "Your checkpoint path" \ 4 | --width Your_data_width --height Your_data_height \ 5 | --save_path "Your output save path" \ 6 | --do_sample Whether_do_sample_when_decoding \ 7 | --temperature The_temperature_when_decoding \ 8 | --inference_caption "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team." 9 | 10 | -------------------------------------------------------------------------------- /layout_planner/training/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target, padding, topk=(1,)): 5 | """Computes the accuracy over the k top predictions for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | if output.shape[-1] < maxk: 9 | print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.") 10 | 11 | maxk = min(maxk, output.shape[-1]) 12 | batch_size = target.size(0) 13 | 14 | # Take topk along the last dimension. 
15 | _, pred = output.topk(maxk, -1, True, True) # (N, T, topk) 16 | 17 | mask = (target != padding).type(target.dtype) 18 | target_expand = target[..., None].expand_as(pred) 19 | correct = pred.eq(target_expand) 20 | correct = correct * mask[..., None].expand_as(correct) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True) 25 | res.append(correct_k.mul_(100.0 / mask.sum())) 26 | return res 27 | -------------------------------------------------------------------------------- /layout_planner/training/datasets/layout_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | 5 | class LayoutDataset(torch.utils.data.Dataset): 6 | def __init__(self, layout_data_path, split_name): 7 | self.split_name = split_name 8 | self.layout_data_path = layout_data_path 9 | with open(self.layout_data_path, 'r') as file: 10 | self.json_datas = json.load(file) 11 | 12 | self.ids = [x for x in range(len(self.json_datas))] 13 | if self.split_name == 'train': 14 | self.ids = self.ids[:-200] 15 | self.json_datas = self.json_datas[:-200] 16 | elif self.split_name == 'val': 17 | self.ids = self.ids[-200:-100] 18 | self.json_datas = self.json_datas[-200:-100] 19 | elif self.split_name == 'test': 20 | self.ids = self.ids[-100:] 21 | self.json_datas = self.json_datas[-100:] 22 | else: 23 | raise NotImplementedError 24 | 25 | def __getitem__(self, idx): 26 | return self.json_datas[idx]['train'] 27 | 28 | def __len__(self): 29 | return len(self.json_datas) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /layout_planner/config/13b_z3_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "steps_per_print": 2000, 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /multi_layer_gen/configs/multi_layer_resolution1024_test.py: -------------------------------------------------------------------------------- 1 | _base_ = "./base.py" 2 | 3 | ### path & device settings 4 | 5 | output_path_base = "./output/" 6 | cache_dir = None 7 | 8 | 9 | ### wandb settings 10 | wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}' 11 | 12 | resolution = 1024 13 | 14 | ### Model Settings 15 | rank = 64 16 | text_encoder_rank = 64 17 | train_text_encoder = False 18 | max_layer_num = 50 + 2 19 | learnable_proj = True 20 | 21 | ### Training Settings 22 | weighting_scheme = "none" 23 | logit_mean = 0.0 24 | logit_std = 1.0 25 | mode_scale = 1.29 26 | guidance_scale = 1.0 ###IMPORTANT 27 | layer_weighting = 5.0 28 | 29 | # steps 30 | train_batch_size = 1 31 | num_train_epochs = 1 32 | max_train_steps = None 33 | checkpointing_steps = 2000 34 | resume_from_checkpoint = "latest" 35 | gradient_accumulation_steps = 1 36 | 37 | # lr 38 | optimizer = "prodigy" 39 | learning_rate = 1.0 40 | scale_lr = False 41 | lr_scheduler = "constant" 42 | lr_warmup_steps = 0 43 | lr_num_cycles = 1 44 | lr_power = 1.0 45 | 46 | # optim 47 | adam_beta1 = 0.9 48 | adam_beta2 = 0.999 49 | adam_weight_decay = 1e-3 50 | adam_epsilon = 1e-8 51 | prodigy_beta3 = None 52 | prodigy_decouple = True 53 | prodigy_use_bias_correction = True 54 | prodigy_safeguard_warmup = True 55 | max_grad_norm = 1.0 56 | 57 | # logging 58 | tracker_task_name = '{{fileBasenameNoExtension}}' 59 | output_dir = output_path_base + "{{fileBasenameNoExtension}}" 60 | 61 | ### Validation Settings 62 | num_validation_images = 1 63 | validation_steps = 2000 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

ART: Anonymous Region Transformer for Variable Multi-Layer Transparent Image Generation 2 | 3 | arXiv 4 | 5 | 6 |
7 | 8 | This repository supports [generating multi-layer transparent images](#multi-layer-generation) (constructed with multiple RGBA image layers) based on a global text prompt and an anonymous region layout (bounding boxes without layer captions). The anonymous region layout can be either [predicted by LLM](#llm-for-layout-planning) or manually specified by users. 9 | 10 | 11 | ## 🌟 Features 12 | - **Anonymous Layout**: Requires only a single global caption to generate multiple layers, eliminating the need for individual captions for each layer. 13 | - **High Layer Capacity**: Supports the generation of 50+ layers, enabling complex multi-layer outputs. 14 | - **Efficiency**: Maintains high efficiency compared to full attention and spatial-temporal attention mechanisms. 15 | 16 | ## 🛑 Important Notice (updated 2025/07/23) 17 | 18 | This repository previously contained code and pretrained model weights for generating multi-layer transparent images using a global text prompt and anonymous region layout. However, since the model was trained using data that may have come from illegal sources, we have removed the model weights and inference checkpoints from this repository, along with all associated download links. If you have any questions, please contact the original authors through official channels. 19 | 20 | As a result: 21 | 22 | - The pretrained models and associated checkpoints are no longer available for download. 23 | - No runnable or usable code for inference or training is provided. 24 | - We do not provide any means to use or reproduce the model at this time. 25 | - The training code, which relied on this data, has also been removed. 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /layout_planner/training/trainer_layout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import Queue 3 | from collections import defaultdict 4 | 5 | import torch 6 | from transformers import Trainer 7 | 8 | from .utils import accuracy 9 | 10 | 11 | def l2_loss(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 12 | """ 13 | Args: 14 | u: (N, T, D) tensor. 15 | v: (N, T, D) tensor. 16 | Returns: 17 | l1_loss: (N,) tensor of summed L1 loss. 
18 | """ 19 | assert u.shape == v.shape, (u.shape, v.shape) 20 | return ((u - v) ** 2).sum(dim=-1) ** 0.5 21 | 22 | 23 | def batch_purity(input_ids, tokenizer): 24 | strings = tokenizer.batch_decode(input_ids) 25 | strings = [ 26 | string.replace(" ", "").replace('> "', '>"').replace("", "") 27 | for string in strings 28 | ] 29 | return strings 30 | 31 | 32 | class Meter: 33 | def __init__(self,size): 34 | self.size = size 35 | self.reset() 36 | 37 | def reset(self): 38 | self.bins = defaultdict(Queue) 39 | 40 | def update(self,metrics): 41 | for k,v in metrics.items(): 42 | self.bins[k].put(v) 43 | 44 | def get(self): 45 | metrics = {} 46 | for k,v in self.bins.items(): 47 | metrics[k] = np.mean(list(v.queue)) 48 | return metrics 49 | 50 | 51 | class TrainerLayout(Trainer): 52 | def __init__(self, extra_args, **kwargs): 53 | self.quantizer = kwargs.pop("quantizer", None) 54 | super().__init__(**kwargs) 55 | self.extra_args = extra_args 56 | weight = torch.ones(self.extra_args.vocab_size) 57 | weight[self.extra_args.old_vocab_size :] = self.extra_args.new_token_weight 58 | self.weighted_ce_loss = torch.nn.CrossEntropyLoss(weight=weight).cuda() 59 | if 'opt-' in self.extra_args.opt_version: 60 | if self.args.fp16: 61 | self.weighted_ce_loss = self.weighted_ce_loss.half() 62 | elif self.args.bf16: 63 | self.weighted_ce_loss = self.weighted_ce_loss.bfloat16() 64 | 65 | self.meter = Meter(self.args.logging_steps) 66 | 67 | def compute_loss(self, model, inputs, return_outputs=False): 68 | labels = inputs.pop("labels").long() 69 | 70 | ( 71 | model_output, 72 | full_labels, 73 | input_embs_norm, 74 | ) = model(labels=labels) 75 | 76 | output = model_output.logits 77 | 78 | masked_full_labels = full_labels.clone() 79 | masked_full_labels[masked_full_labels < self.extra_args.old_vocab_size] = -100 80 | 81 | weighted_ce_loss = self.weighted_ce_loss( 82 | output[:, :-1, :].reshape(-1, output.shape[-1]), 83 | full_labels[:, 1:].reshape(-1), 84 | ) 85 | ce_loss = weighted_ce_loss * self.extra_args.ce_loss_scale 86 | 87 | loss = ce_loss 88 | acc1, acc5 = accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5)) 89 | masked_acc1, masked_acc5 = accuracy( 90 | output[:, :-1, :], masked_full_labels[:, 1:], -100, topk=(1, 5) 91 | ) 92 | 93 | metrics = { 94 | "loss": loss.item(), 95 | "ce_loss": ce_loss.item(), 96 | "top1": float(acc1), 97 | "top5": float(acc5), 98 | "masked_top1": float(masked_acc1), 99 | "masked_top5": float(masked_acc5), 100 | "inp_emb_norm": input_embs_norm.item(), 101 | } 102 | self.meter.update(metrics) 103 | if self.state.global_step % self.args.logging_steps == 0: 104 | metrics = self.meter.get() 105 | self.meter.reset() 106 | self.log(metrics) 107 | 108 | outputs = model_output 109 | 110 | return (loss, outputs) if return_outputs else loss 111 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import argparse 5 | from PIL import Image 6 | from multi_layer_gen.custom_model_mmdit import CustomFluxTransformer2DModel 7 | from multi_layer_gen.custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE 8 | from multi_layer_gen.custom_pipeline import CustomFluxPipelineCfg 9 | 10 | def test_sample(pipeline, transp_vae, batch, args): 11 | 12 | def adjust_coordinate(value, floor_or_ceil, k=16, min_val=0, max_val=1024): 13 | # Round the value to the nearest multiple of k 14 | if floor_or_ceil 
== "floor": 15 | rounded_value = math.floor(value / k) * k 16 | else: 17 | rounded_value = math.ceil(value / k) * k 18 | # Clamp the value between min_val and max_val 19 | return max(min_val, min(rounded_value, max_val)) 20 | 21 | validation_prompt = batch["wholecaption"] 22 | validation_box_raw = batch["layout"] 23 | validation_box = [ 24 | ( 25 | adjust_coordinate(rect[0], floor_or_ceil="floor"), 26 | adjust_coordinate(rect[1], floor_or_ceil="floor"), 27 | adjust_coordinate(rect[2], floor_or_ceil="ceil"), 28 | adjust_coordinate(rect[3], floor_or_ceil="ceil"), 29 | ) 30 | for rect in validation_box_raw 31 | ] 32 | if len(validation_box) > 52: 33 | validation_box = validation_box[:52] 34 | 35 | generator = torch.Generator(device=torch.device("cuda", index=args.gpu_id)).manual_seed(args.seed) if args.seed else None 36 | output, rgba_output, _, _ = pipeline( 37 | prompt=validation_prompt, 38 | validation_box=validation_box, 39 | generator=generator, 40 | height=args.resolution, 41 | width=args.resolution, 42 | num_layers=len(validation_box), 43 | guidance_scale=args.cfg, 44 | num_inference_steps=args.steps, 45 | transparent_decoder=transp_vae, 46 | ) 47 | images = output.images # list of PIL, len=layers 48 | rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output] 49 | 50 | os.makedirs(os.path.join(args.save_dir, this_index), exist_ok=True) 51 | for frame_idx, frame_pil in enumerate(images): 52 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}.png")) 53 | if frame_idx == 0: 54 | frame_pil.save(os.path.join(args.save_dir, this_index, "merged.png")) 55 | merged_pil = images[1].convert('RGBA') 56 | for frame_idx, frame_pil in enumerate(rgba_images): 57 | if frame_idx < 2: 58 | frame_pil = images[frame_idx].convert('RGBA') # merged and background 59 | else: 60 | merged_pil = Image.alpha_composite(merged_pil, frame_pil) 61 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}_rgba.png")) 62 | 63 | merged_pil = merged_pil.convert('RGB') 64 | merged_pil.save(os.path.join(args.save_dir, this_index, "merged_rgba.png")) 65 | 66 | 67 | args = dict( 68 | save_dir="output/", 69 | resolution=512, 70 | cfg=4.0, 71 | steps=28, 72 | seed=41, 73 | gpu_id=0, 74 | ) 75 | args = argparse.Namespace(**args) 76 | 77 | transformer = CustomFluxTransformer2DModel.from_pretrained("ART-Release/ART_v1.0", subfolder="transformer", torch_dtype=torch.bfloat16) 78 | transp_vae = CustomVAE.from_pretrained("ART-Release/ART_v1.0", subfolder="transp_vae", torch_dtype=torch.float32) 79 | pipeline = CustomFluxPipelineCfg.from_pretrained( 80 | "black-forest-labs/FLUX.1-dev", 81 | transformer=transformer, 82 | torch_dtype=torch.bfloat16, 83 | ).to(torch.device("cuda", index=args.gpu_id)) 84 | pipeline.enable_model_cpu_offload(gpu_id=args.gpu_id) # Save GPU memory 85 | 86 | sample = { 87 | "index": "reso512_3", 88 | "wholecaption": 'Floral wedding invitation: green leaves, white flowers; circular border. Center: "JOIN US CELEBRATING OUR WEDDING" (cursive), "DONNA AND HARPER" (bold), "03 JUNE 2023" (small bold). 
White, green color scheme, elegant, natural.', 89 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)], 90 | } 91 | 92 | test_sample(pipeline=pipeline, transp_vae=transp_vae, batch=sample, args=args) 93 | 94 | del pipeline 95 | if torch.cuda.is_available(): 96 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | diffusers 165 | *.png 166 | *__pycache__* 167 | wandb 168 | *.safetensors 169 | work_dirs 170 | evaluation 171 | .gradio 172 | .pdf 173 | *.ttf 174 | tools 175 | BACKUP 176 | exps 177 | 178 | output* 179 | multi_layer_gen/output* 180 | multi_layer_gen/pretrained* 181 | test_semi.ipynb 182 | 183 | # not ignore the following files 184 | !assets/teaser.png 185 | !assets/natural_image_data/*/*.png 186 | !assets/semi_transparent_image_data/*/*.png -------------------------------------------------------------------------------- /layout_planner/training/datasets/layoutnew_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, List, Union 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from transformers import DataCollatorForLanguageModeling 8 | 9 | from .layout_dataset import LayoutDataset 10 | 11 | 12 | def get_dataset(args, split: str, quantizer, tokenizer, **kwargs) -> Dataset: 13 | dataset = LayoutNew(quantizer, tokenizer, split=split, max_len=args.max_len, **kwargs) 14 | return dataset 15 | 16 | 17 | class CenterWrapperNew(torch.utils.data.Dataset): 18 | def __init__(self, dataset): 19 | self.__inner_dataset = dataset 20 | 21 | def __getitem__(self, idx): 22 | example = self.__inner_dataset[idx] 23 | return example 24 | 25 | def __len__(self): 26 | return len(self.__inner_dataset) 27 | 28 | 29 | class LayoutNew(torch.utils.data.Dataset): 30 | def __init__(self, quantizer, tokenizer, 31 | split='train', max_len: int = 32, 32 | return_index=False, inf_length=False, 33 | with_layout=True, 34 | split_version='default', 35 | layout_path = None, 36 | **kwargs): 37 | self.tokenizer = tokenizer 38 | self.max_len = max_len 39 | self.quantizer = quantizer 40 | poster_datasets = [] 41 | self.split_version = split_version 42 | self.layout_path = layout_path 43 | if with_layout: 44 | if split_version == 'layout': 45 | print(f"with {split_version} data: {self.layout_path}") 46 | inner_dataset = LayoutDataset(self.layout_path, split) 47 | else: 48 | raise NotImplementedError 49 | 50 | poster_datasets.append(inner_dataset) 51 | 52 | self.inner_dataset = torch.utils.data.ConcatDataset(poster_datasets) 53 | self.inner_dataset = CenterWrapperNew(self.inner_dataset) 54 | self.split = split 55 | self.size = len(self.inner_dataset) 56 | self.return_index = 
return_index 57 | self.inf_length = inf_length 58 | if self.inf_length: 59 | self.max_len = 1000000 60 | 61 | def __getitem__(self, idx): 62 | example = self.inner_dataset[idx] 63 | json_example = self.quantizer.convert2layout(example) 64 | max_len = self.max_len 65 | 66 | content = self.quantizer.dump2json(json_example) 67 | 68 | raw_input_ids = self.tokenizer(content, return_tensors="pt").input_ids[0] 69 | 70 | if len(raw_input_ids) > max_len and self.split == 'train': 71 | start = random.randint(0, len(raw_input_ids) - max_len) 72 | end = start + max_len 73 | input_ids = raw_input_ids[start:end] 74 | 75 | else: 76 | input_ids = raw_input_ids 77 | if not self.inf_length: 78 | if input_ids.shape[0] > max_len: 79 | input_ids = input_ids[:max_len] 80 | 81 | if not self.inf_length: 82 | if input_ids.shape[0] > self.max_len: 83 | input_ids = input_ids[:self.max_len] 84 | elif input_ids.shape[0] < self.max_len: 85 | padding_1 = torch.zeros((1,), dtype=torch.long) + self.tokenizer.bos_token_id 86 | padding_2 = torch.zeros((self.max_len - input_ids.shape[0] - 1,), dtype=torch.long) + self.tokenizer.pad_token_id 87 | input_ids = torch.cat([input_ids, padding_1, padding_2], dim=0) 88 | assert input_ids.shape[0] == self.max_len 89 | 90 | return_dict = { 91 | 'labels': input_ids, 92 | } 93 | if self.return_index: 94 | return_dict['index'] = idx 95 | 96 | return return_dict 97 | 98 | def __len__(self): 99 | return self.size 100 | 101 | 102 | def layout_collate_fn(batch): 103 | input_ids = [item['labels'] for item in batch] 104 | input_ids = torch.stack(input_ids, dim=0) 105 | return_dict = { 106 | 'labels': input_ids, 107 | } 108 | if 'index' in batch[0]: 109 | index = [item['index'] for item in batch] 110 | return_dict['index'] = index 111 | 112 | return return_dict 113 | 114 | 115 | class DataCollatorForLayout(DataCollatorForLanguageModeling): 116 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: 117 | return layout_collate_fn(examples) -------------------------------------------------------------------------------- /layout_planner/training/datasets/quantizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from functools import lru_cache 5 | from collections import OrderedDict 6 | 7 | 8 | class BaseQuantizer: 9 | @property 10 | def ignore_tokens(self): 11 | return [] 12 | 13 | def __init__(self, simplify_json=False, **kwargs): 14 | self.simplify_json=simplify_json 15 | self.io_ignore_replace_tokens = [''] 16 | 17 | def dump2json(self, json_example): 18 | if self.simplify_json: 19 | content = json.dumps(json_example, separators=(',',':')) 20 | for token in self.additional_special_tokens: 21 | content = content.replace(f'"{token}"', token) 22 | else: 23 | content = json.dumps(json_example) 24 | return content 25 | 26 | def load_json(self, content): 27 | replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens) # sirui change 28 | if self.simplify_json: 29 | for token in replace_tokens: 30 | content = content.replace(token, f'"{token}"') 31 | return json.loads(content) 32 | 33 | 34 | specs={ 35 | "width":"size", 36 | "height":"size", 37 | "x":"pos", # center x 38 | "y":"pos", # center y 39 | "color":"color", 40 | "font":"font" 41 | } 42 | 43 | 44 | min_max_bins = { 45 | 'size': (0,1,256), 46 | 'pos': (0,1,256), 47 | 'color': (0,137,138), 48 | 'font': (0,511,512) 49 | } 50 | 51 | 52 | class QuantizerV4(BaseQuantizer): 53 | def 
__init__(self, quant=True, **kwargs): 54 | super().__init__(**kwargs) 55 | self.min = min 56 | self.max = max 57 | self.quant = quant 58 | self.text_split_token = '' 59 | self.set_min_max_bins(min_max_bins) 60 | self.width = kwargs.get('width', 1024) 61 | self.height = kwargs.get('height', 1024) 62 | self.width = int(self.width) 63 | self.height = int(self.height) 64 | 65 | def set_min_max_bins(self, min_max_bins): 66 | min_max_bins = copy.deepcopy(min_max_bins) 67 | # adjust the bins to plus one 68 | for type_name, (min_val, max_val, n_bins) in min_max_bins.items(): 69 | assert n_bins % 2 == 0 70 | min_max_bins[type_name] = (min_val, max_val, n_bins+1) 71 | self.min_max_bins = min_max_bins 72 | 73 | def setup_tokenizer(self, tokenizer): 74 | additional_special_tokens = [self.text_split_token] 75 | rest_types = [key for key in self.min_max_bins.keys()] 76 | for type_name in rest_types: 77 | min_val, max_val,n_bins = self.min_max_bins[type_name] 78 | additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)] 79 | 80 | print('additional_special_tokens', additional_special_tokens) 81 | 82 | tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens}) 83 | self.additional_special_tokens = set(additional_special_tokens) 84 | return tokenizer 85 | 86 | 87 | @lru_cache(maxsize=128) 88 | def get_bins(self, real_type): 89 | min_val, max_val, n_bins = self.min_max_bins[real_type] 90 | return min_val, max_val, np.linspace(min_val, max_val, n_bins) 91 | 92 | def quantize(self, x, type): 93 | if not self.quant: 94 | return x 95 | """Quantize a float array x into n_bins discrete values.""" 96 | real_type = specs[type] 97 | min_val, max_val, bins = self.get_bins(real_type) 98 | x = np.clip(float(x), min_val, max_val) 99 | val = np.digitize(x, bins) - 1 100 | n_bins = len(bins) 101 | assert val >= 0 and val < n_bins 102 | return f'<{real_type}-{val}>' 103 | 104 | def dequantize(self, x): 105 | # ->1 106 | val = x.split('-')[1].strip('>') 107 | # ->pos 108 | real_type = x.split('-')[0][1:] 109 | min_val, max_val, bins = self.get_bins(real_type) 110 | return bins[int(val)] 111 | 112 | def construct_map_dict(self): 113 | map_dict = {} 114 | for i in range(self.min_max_bins['size'][2]): 115 | name = "" % i 116 | value = self.dequantize(name) 117 | map_dict[name] = str(value) 118 | for i in range(self.min_max_bins['pos'][2]): 119 | name = "" % i 120 | value = self.dequantize(name) 121 | map_dict[name] = str(value) 122 | return map_dict 123 | 124 | def postprocess_colorandfont(self, json_example): 125 | import re 126 | json_example = re.sub(r'()', r'"\1"', json_example) 127 | json_example = re.sub(r'()', r'"\1"', json_example) 128 | return json_example 129 | 130 | def convert2layout(self, example): 131 | new_example = OrderedDict() 132 | new_example['wholecaption'] = example['wholecaption'] 133 | new_layout = [] 134 | for meta_layer in example['layout']: 135 | new_layout.append({ 136 | "layer": meta_layer["layer"], 137 | "x": self.quantize(meta_layer["x"]/self.width, 'x'), 138 | "y": self.quantize(meta_layer["y"]/self.height, 'y'), 139 | "width": self.quantize(meta_layer["width"]/self.width, 'width'), 140 | "height": self.quantize(meta_layer["height"]/self.height, 'height') 141 | }) 142 | new_example['layout'] = new_layout 143 | return new_example 144 | 145 | 146 | def get_quantizer(version='v1', **kwargs): 147 | if version == 'v4': 148 | quantizer = QuantizerV4(**kwargs) 149 | else: 150 | raise NotImplementedError 151 | 152 | return quantizer 153 | 154 | 
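
Note: several angle-bracketed special tokens inside `quantizer.py` above appear to have been stripped during text extraction, leaving empty strings (e.g. `self.io_ignore_replace_tokens = ['']`, `self.text_split_token = ''`, `name = "" % i` in `construct_map_dict`, and the empty patterns in `postprocess_colorandfont`). The sketch below is a best-effort reconstruction of the two methods whose token format can be inferred from `quantize` (which emits `f'<{real_type}-{val}>'`) and `dequantize` (which parses that format); the exact values of `text_split_token` and `io_ignore_replace_tokens` cannot be recovered from this file and are left as-is.

```python
# Reconstruction sketch (assumption, not the original source): the '<size-i>',
# '<pos-i>', '<color-i>', '<font-i>' token strings are inferred from the
# f'<{real_type}-{val}>' format used by QuantizerV4.quantize / dequantize.

def construct_map_dict(self):
    map_dict = {}
    for i in range(self.min_max_bins['size'][2]):
        name = "<size-%d>" % i          # rendered as "" % i after extraction
        map_dict[name] = str(self.dequantize(name))
    for i in range(self.min_max_bins['pos'][2]):
        name = "<pos-%d>" % i           # rendered as "" % i after extraction
        map_dict[name] = str(self.dequantize(name))
    return map_dict

def postprocess_colorandfont(self, json_example):
    import re
    # Quote bare <color-N> / <font-N> tokens so the generated string parses as JSON.
    json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
    json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
    return json_example
```

A quick round-trip check of the quantization itself (hypothetical usage; the import path is the one used by `layout_planner/inference_layout.py`, and the numbers follow from the `'pos'` range `(0, 1, 256)` being widened to 257 bin edges):

```python
from training.datasets.quantizer import get_quantizer

q = get_quantizer('v4', simplify_json=True, width=1024, height=1024)
tok = q.quantize(512 / 1024, 'x')   # 'x' uses the 'pos' bins -> '<pos-128>'
val = q.dequantize(tok)             # -> 0.5, i.e. 512 px at width = 1024
```
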
-------------------------------------------------------------------------------- /layout_planner/models/modeling_layout.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import Optional, List 4 | from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM, BitsAndBytesConfig 5 | 6 | 7 | def kmp_preprocess(pattern): 8 | pattern_len = len(pattern) 9 | prefix_suffix = [0] * pattern_len 10 | j = 0 11 | 12 | for i in range(1, pattern_len): 13 | while j > 0 and pattern[i] != pattern[j]: 14 | j = prefix_suffix[j - 1] 15 | 16 | if pattern[i] == pattern[j]: 17 | j += 1 18 | 19 | prefix_suffix[i] = j 20 | 21 | return prefix_suffix 22 | 23 | 24 | def kmp_search(text, pattern): 25 | text_len = len(text) 26 | pattern_len = len(pattern) 27 | prefix_suffix = kmp_preprocess(pattern) 28 | matches = [] 29 | 30 | j = 0 31 | for i in range(text_len): 32 | while j > 0 and text[i] != pattern[j]: 33 | j = prefix_suffix[j - 1] 34 | 35 | if text[i] == pattern[j]: 36 | j += 1 37 | 38 | if j == pattern_len: 39 | matches.append(i - j + 1) 40 | j = prefix_suffix[j - 1] 41 | 42 | return matches 43 | 44 | 45 | class ModelWrapper: 46 | def __init__(self, model): 47 | self.model = model 48 | 49 | def __getattr__(self, name): 50 | return getattr(self.model, name) 51 | 52 | @torch.no_grad() 53 | def __call__(self, pixel_values): 54 | return self.model(pixel_values) 55 | 56 | def eval(self): 57 | pass 58 | 59 | def train(self): 60 | pass 61 | 62 | def parameters(self): 63 | return self.model.parameters() 64 | 65 | 66 | class LayoutModelConfig(PretrainedConfig): 67 | def __init__( 68 | self, 69 | old_vocab_size: int = 32000, 70 | vocab_size: int = 32000, 71 | pad_token_id: int = 2, 72 | freeze_lm: bool = True, 73 | opt_version: str = 'facebook/opt-6.7b', 74 | hidden_size: int = -1, 75 | load_in_4bit: Optional[bool] = False, 76 | ignore_ids: List[int] = [], 77 | **kwargs, 78 | ): 79 | super().__init__(**kwargs) 80 | assert old_vocab_size > 0, 'old_vocab_size must be positive' 81 | assert vocab_size > 0, 'vocab_size must be positive' 82 | 83 | self.old_vocab_size = old_vocab_size 84 | self.vocab_size = vocab_size 85 | self.pad_token_id = pad_token_id 86 | self.freeze_lm = freeze_lm 87 | self.opt_version = opt_version 88 | self.hidden_size = hidden_size 89 | self.load_in_4bit = load_in_4bit 90 | self.ignore_ids = ignore_ids 91 | 92 | 93 | class LayoutModel(PreTrainedModel): 94 | config_class = LayoutModelConfig 95 | supports_gradient_checkpointing = True 96 | 97 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 98 | self.lm.gradient_checkpointing_enable() 99 | 100 | def __init__(self, config: LayoutModelConfig): 101 | super().__init__(config) 102 | self.pad_token_id = config.pad_token_id 103 | 104 | self.args = config 105 | 106 | opt_version = config.opt_version 107 | 108 | print(f"Using {opt_version} for the language model.") 109 | 110 | if config.load_in_4bit: 111 | print("\n would load_in_4bit") 112 | quantization_config = BitsAndBytesConfig( 113 | load_in_4bit=config.load_in_4bit 114 | ) 115 | # This means: fit the entire model on the GPU:0 116 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 117 | device_map = {"": local_rank} 118 | torch_dtype = torch.bfloat16 119 | else: 120 | print("\n wouldn't load_in_4bit") 121 | device_map = None 122 | quantization_config = None 123 | torch_dtype = None 124 | 125 | self.lm = AutoModelForCausalLM.from_pretrained( 126 | 
opt_version, 127 | quantization_config=quantization_config, 128 | device_map=device_map, 129 | trust_remote_code=True, 130 | attn_implementation="flash_attention_2", 131 | torch_dtype=torch.bfloat16, 132 | ) 133 | self.config.hidden_size = self.lm.config.hidden_size 134 | self.opt_version = opt_version 135 | 136 | if self.args.freeze_lm: 137 | self.lm.eval() 138 | print("Freezing the LM.") 139 | for param in self.lm.parameters(): 140 | param.requires_grad = False 141 | else: 142 | print("\n no freeze lm, so to train lm") 143 | self.lm.train() 144 | self.lm.config.gradient_checkpointing = True 145 | 146 | print('resize token embeddings to match the tokenizer', config.vocab_size) 147 | self.lm.resize_token_embeddings(config.vocab_size) 148 | self.input_embeddings = self.lm.get_input_embeddings() 149 | 150 | def train(self, mode=True): 151 | super().train(mode=mode) 152 | # Overwrite train() to ensure frozen models remain frozen. 153 | if self.args.freeze_lm: 154 | self.lm.eval() 155 | 156 | def forward( 157 | self, 158 | labels: torch.LongTensor, 159 | ): 160 | batch_size = labels.shape[0] 161 | full_labels = labels.detach().clone() 162 | 163 | input_embs = self.input_embeddings(labels) # (N, T, D) 164 | input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean() 165 | 166 | for ignore_id in self.config.ignore_ids: 167 | full_labels[full_labels == ignore_id] = -100 168 | 169 | pad_idx = [] 170 | # -100 is the ignore index for cross entropy loss. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html 171 | for label in full_labels: 172 | for k, token in enumerate(label): 173 | # Mask out pad tokens if they exist. 174 | if token in [self.pad_token_id]: 175 | label[k:] = -100 176 | pad_idx.append(k) 177 | break 178 | if k == len(label) - 1: # No padding found. 
179 | pad_idx.append(k + 1) 180 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size) 181 | 182 | output = self.lm( inputs_embeds=input_embs, 183 | labels=full_labels, 184 | output_hidden_states=True) 185 | 186 | return output, full_labels, input_embs_norm -------------------------------------------------------------------------------- /layout_planner/inference_layout.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | from typing import List 5 | 6 | import torch 7 | from transformers import AutoTokenizer, set_seed 8 | from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \ 9 | STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings 10 | 11 | from models.modeling_layout import LayoutModel, LayoutModelConfig 12 | from training.datasets.quantizer import get_quantizer 13 | 14 | 15 | class StopAtSpecificTokenCriteria(StoppingCriteria): 16 | def __init__(self, token_id_list: List[int] = None): 17 | self.token_id_list = token_id_list 18 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 19 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 20 | return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list 21 | 22 | 23 | # build model and tokenizer 24 | def buildmodel(device='cuda:0',**kwargs): 25 | # seed / input model / resume 26 | resume = kwargs.get('resume', None) 27 | seed = kwargs.get('seed', None) 28 | input_model = kwargs.get('input_model', None) 29 | quantizer_version = kwargs.get('quantizer_version', 'v4') 30 | 31 | set_seed(seed) 32 | old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True) 33 | old_vocab_size = len(old_tokenizer) 34 | print(f"Old vocab size: {old_vocab_size}") 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(resume, trust_remote_code=True) 37 | 38 | new_vocab_size = len(tokenizer) 39 | print(f"New vocab size: {new_vocab_size}") 40 | quantizer = get_quantizer(quantizer_version, 41 | simplify_json = True, 42 | width = kwargs['width'], 43 | height = kwargs['height'] 44 | ) 45 | quantizer.setup_tokenizer(tokenizer) 46 | print(f"latest tokenzier size: {len(tokenizer)}") 47 | 48 | model_args = LayoutModelConfig( 49 | old_vocab_size = old_vocab_size, 50 | vocab_size=len(tokenizer), 51 | pad_token_id=tokenizer.pad_token_id, 52 | ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens), 53 | ) 54 | 55 | model_args.opt_version = input_model 56 | model_args.freeze_lm = False 57 | model_args.load_in_4bit = kwargs.get('load_in_4bit', False) 58 | 59 | print(f"Resuming from checkpoint {resume}, Waiting to ready") 60 | model = LayoutModel.from_pretrained(resume, config=model_args).to(device) 61 | 62 | return model, quantizer, tokenizer 63 | 64 | 65 | def preprocess_Input(intention: str): 66 | 67 | intention = intention.replace('\n', '').replace('\r', '').replace('\\', '') 68 | intention = re.sub(r'\.\s*', '. 
', intention) 69 | 70 | return intention 71 | 72 | 73 | # build data 74 | def FormulateInput(intention: str): 75 | ''' 76 | Formulate user input string to Dict Object 77 | ''' 78 | resdict = {} 79 | resdict["wholecaption"] = intention 80 | resdict["layout"] = [] 81 | 82 | return resdict 83 | 84 | 85 | @torch.no_grad() 86 | def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50): 87 | json_example = inputs 88 | input_intention = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":' 89 | print("input_intention:\n", input_intention) 90 | 91 | inputs = tokenizer( 92 | input_intention, return_tensors="pt" 93 | ).to(device) 94 | 95 | stopping_criteria = StoppingCriteriaList() 96 | stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[128000])) 97 | 98 | outputs = model.lm.generate(**inputs, use_cache=True, max_length=8000, stopping_criteria=stopping_criteria, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k) 99 | inputs_length = inputs['input_ids'].shape[1] 100 | outputs = outputs[:, inputs_length:] 101 | 102 | outputs_word = tokenizer.batch_decode(outputs)[0] 103 | split_word = outputs_word.split('}]}')[0]+"}]}" 104 | split_word = '{"wholecaption":"' + json_example["wholecaption"].replace('\n', '\\n').replace('"', '\\"') + '","layout":[{"layer":' + split_word 105 | 106 | map_dict = quantizer.construct_map_dict() 107 | for key ,value in map_dict.items(): 108 | split_word = split_word.replace(key, value) 109 | 110 | try: 111 | pred_json_example = json.loads(split_word) 112 | for layer in pred_json_example["layout"]: 113 | layer['x'] = round(int(width)*layer['x']) 114 | layer['y'] = round(int(height)*layer['y']) 115 | layer['width'] = round(int(width)*layer['width']) 116 | layer['height'] = round(int(height)*layer['height']) 117 | except Exception as e: 118 | print(e) 119 | pred_json_example = None 120 | return pred_json_example 121 | 122 | 123 | if __name__ == '__main__': 124 | 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--inference_caption', type=str, help='User input whole caption') 127 | parser.add_argument('--save_path', type=str, help='Path to save data') 128 | parser.add_argument('--device', type=str, default='cuda:0') 129 | parser.add_argument('--width', type=int, default=1024, help='Width of the layout') 130 | parser.add_argument('--height', type=int, default=1024, help='Height of the layout') 131 | parser.add_argument('--input_model', type=str, help='Path to input base model') 132 | parser.add_argument('--resume', type=str, help='Path to test model checlpoint') 133 | parser.add_argument('--do_sample', type=bool, default=False) 134 | parser.add_argument('--temperature', type=float, default=0.5) 135 | 136 | args = parser.parse_args() 137 | 138 | inference_caption = args.inference_caption 139 | save_path = args.save_path 140 | device = args.device 141 | width = args.width 142 | height = args.height 143 | input_model = args.input_model 144 | resume = args.resume 145 | do_sample = args.do_sample 146 | temperature = args.temperature 147 | 148 | params_dict = { 149 | "input_model": input_model, 150 | "resume": resume, 151 | "seed": 0, 152 | "quantizer_version": 'v4', 153 | "width": width, 154 | "height": height, 155 | } 156 | 157 | # Init model 158 | model, quantizer, tokenizer = buildmodel(device=device, **params_dict) 159 | model = model.to(device) 160 | model = model.bfloat16() 161 | model.eval() 162 | 163 | intention = 
preprocess_Input(inference_caption) 164 | rawdata = FormulateInput(intention) 165 | preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature) 166 | max_try_time = 3 167 | while preddata is None and max_try_time > 0: 168 | preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature) 169 | max_try_time -= 1 170 | 171 | print("output : ", preddata) 172 | 173 | with open(save_path,'w') as file: 174 | json.dump(preddata, file, indent=4) 175 | -------------------------------------------------------------------------------- /multi_layer_gen/configs/multi_layer_resolution512_test.py: -------------------------------------------------------------------------------- 1 | _base_ = "./base.py" 2 | 3 | ### path & device settings 4 | 5 | output_path_base = "./output/" 6 | cache_dir = None 7 | 8 | 9 | ### wandb settings 10 | wandb_job_name = "flux_" + '{{fileBasenameNoExtension}}' 11 | 12 | resolution = 512 13 | 14 | ### Model Settings 15 | rank = 64 16 | text_encoder_rank = 64 17 | train_text_encoder = False 18 | max_layer_num = 50 + 2 19 | learnable_proj = True 20 | 21 | ### Training Settings 22 | weighting_scheme = "none" 23 | logit_mean = 0.0 24 | logit_std = 1.0 25 | mode_scale = 1.29 26 | guidance_scale = 1.0 ###IMPORTANT 27 | layer_weighting = 5.0 28 | 29 | # steps 30 | train_batch_size = 1 31 | num_train_epochs = 1 32 | max_train_steps = None 33 | checkpointing_steps = 2000 34 | resume_from_checkpoint = "latest" 35 | gradient_accumulation_steps = 1 36 | 37 | # lr 38 | optimizer = "prodigy" 39 | learning_rate = 1.0 40 | scale_lr = False 41 | lr_scheduler = "constant" 42 | lr_warmup_steps = 0 43 | lr_num_cycles = 1 44 | lr_power = 1.0 45 | 46 | # optim 47 | adam_beta1 = 0.9 48 | adam_beta2 = 0.999 49 | adam_weight_decay = 1e-3 50 | adam_epsilon = 1e-8 51 | prodigy_beta3 = None 52 | prodigy_decouple = True 53 | prodigy_use_bias_correction = True 54 | prodigy_safeguard_warmup = True 55 | max_grad_norm = 1.0 56 | 57 | # logging 58 | tracker_task_name = '{{fileBasenameNoExtension}}' 59 | output_dir = output_path_base + "{{fileBasenameNoExtension}}" 60 | 61 | ### Validation Settings 62 | num_validation_images = 1 63 | validation_steps = 2000 64 | validation_prompts = [ 65 | 'The image features a background with a soft, pastel color gradient that transitions from pink to purple. There are abstract floral elements scattered throughout the background, with some appearing to be in full bloom and others in a more delicate, bud-like state. The flowers have a watercolor effect, with soft edges that blend into the background.\n\nCentered in the image is a quote in a serif font that reads, "You\'re free to be different." The text is black, which stands out against the lighter background. The overall style of the image is artistic and inspirational, with a motivational message that encourages individuality and self-expression. The image could be used for motivational purposes, as a background for a blog or social media post, or as part of a personal development or self-help theme.', 66 | 'The image features a logo for a company named "Bull Head Party Adventure." The logo is stylized with a cartoon-like depiction of a bull\'s head, which is the central element of the design. The bull has prominent horns and a fierce expression, with its mouth slightly open as if it\'s snarling or roaring. 
The color scheme of the bull is a mix of brown and beige tones, with the horns highlighted in a lighter shade.\n\nBelow the bull\'s head, the company name is written in a bold, sans-serif font. The text is arranged in two lines, with "Bull Head" on the top line and "Party Adventure" on the bottom line. The font color matches the color of the bull, creating a cohesive look. The overall style of the image is playful and energetic, suggesting that the company may offer exciting or adventurous party experiences.', 67 | 'The image features a festive and colorful illustration with a theme related to the Islamic holiday of Eid al-Fitr. At the center of the image is a large, ornate crescent moon with intricate patterns and decorations. Surrounding the moon are several smaller stars and crescents, also adorned with decorative elements. These smaller celestial motifs are suspended from the moon, creating a sense of depth and dimension.\n\nBelow the central moon, there is a banner with the text "Eid Mubarak" in a stylized, elegant font. The text is in a bold, dark color that stands out against the lighter background. The background itself is a gradient of light to dark green, which complements the golden and white hues of the celestial motifs.\n\nThe overall style of the image is celebratory and decorative, with a focus on the traditional symbols associated with Eid al-Fitr. The use of gold and white gives the image a luxurious and festive feel, while the green background is a color often associated with Islam. The image appears to be a digital artwork or graphic design, possibly intended for use as a greeting card or a festive decoration.', 68 | 'The image is a festive graphic with a dark background. At the center, there is a large, bold text that reads "Happy New Year 2023" in a combination of white and gold colors. The text is surrounded by numerous white balloons with gold ribbons, giving the impression of a celebratory atmosphere. The balloons are scattered around the text, creating a sense of depth and movement. Additionally, there are small gold sparkles and confetti-like elements that add to the celebratory theme. The overall design suggests a New Year\'s celebration, with the year 2023 being the focal point.', 69 | 'The image is a stylized illustration with a flat design aesthetic. It depicts a scene related to healthcare or medical care. In the center, there is a hospital bed with a patient lying down, appearing to be resting or possibly receiving treatment. The patient is surrounded by three individuals who seem to be healthcare professionals or caregivers. They are standing around the bed, with one on each side and one at the foot of the bed. The person at the foot of the bed is holding a clipboard, suggesting they might be taking notes or reviewing medical records.\n\nThe room has a window with curtains partially drawn, allowing some light to enter. The color palette is soft, with pastel tones dominating the scene. The text "INTERNATIONAL CANCER DAY" is prominently displayed at the top of the image, indicating that the illustration is related to this event. The overall impression is one of care and support, with a focus on the patient\'s well-being.', 70 | 'The image features a stylized illustration of a man with a beard and a tank top, drinking from a can. The man is depicted in a simplified, cartoon-like style with a limited color palette. Above him, there is a text that reads "Happy Eating, Friends" in a bold, friendly font. 
Below the illustration, there is another line of text that states "Food is a Necessity That is Not Prioritized," which is also in a bold, sans-serif font. The background of the image is a gradient of light to dark blue, giving the impression of a sky or a calm, serene environment. The overall style of the image is casual and approachable, with a focus on the message conveyed by the text.', 71 | 'The image is a digital illustration with a pastel pink background. At the top, there is a text that reads "Sending you my Easter wishes" in a simple, sans-serif font. Below this, a larger text states "May Your Heart be Happy!" in a more decorative, serif font. Underneath this main message, there is a smaller text that says "Let the miracle of the season fill you with hope and love."\n\nThe illustration features three stylized flowers with smiling faces. On the left, there is a purple flower with a yellow center. In the center, there is a blue flower with a green center. On the right, there is a pink flower with a yellow center. Each flower has a pair of eyes and a mouth, giving them a friendly appearance. The flowers are drawn with a cartoon-like style, using solid colors and simple shapes.\n\nThe overall style of the image is cheerful and whimsical, with a clear Easter theme suggested by the text and the presence of flowers, which are often associated with spring and new beginnings.', 72 | 'The image is a vibrant and colorful graphic with a pink background. In the center, there is a photograph of a man and a woman embracing each other. The man is wearing a white shirt, and the woman is wearing a patterned top. They are both smiling and appear to be in a joyful mood.\n\nSurrounding the photograph are various elements that suggest a festive or celebratory theme. There are three hot air balloons in the background, each with a different design: one with a heart, one with a gift box, and one with a basket. These balloons are floating against a clear sky.\n\nAdditionally, there are two gift boxes with ribbons, one on the left and one on the right side of the image. These gift boxes are stylized with a glossy finish and are placed at different heights, creating a sense of depth.\n\nAt the bottom of the image, there is a large red heart, which is a common symbol associated with love and Valentine\'s Day.\n\nFinally, at the very bottom of the image, there is a text that reads "Happy Valentine\'s Day," which confirms the theme of the image as a Valentine\'s Day greeting. The text is in a playful, cursive font that matches the overall cheerful and romantic tone of the image.', 73 | 'The image depicts a stylized illustration of two women sitting on stools, engaged in conversation. They are wearing traditional attire, with headscarves and patterned dresses. The woman on the left is wearing a brown dress with a purple pattern, while the woman on the right is wearing a purple dress with a brown pattern. Between them is a purple flower. Above the women, the text "INTERNATIONAL WOMEN\'S DAY" is written in bold, uppercase letters. The background is a soft, pastel pink, and there are abstract, swirling lines in a darker shade of pink above the women. The overall style of the image is simplistic and cartoonish, with a warm and friendly tone.', 74 | 'The image is a stylized illustration with a warm, peach-colored background. At the center, there is a vintage-style radio with a prominent dial and antenna. The radio is emitting a blue, star-like burst of light or energy from its top. 
Surrounding the radio are various objects and elements that seem to be floating or suspended in the air. These include a brown, cone-shaped object, a blue, star-like shape, and a brown, wavy, abstract shape that could be interpreted as a flower or a wave.\n\nAt the top of the image, there is text that reads "World Radio Day" in a bold, serif font. Below this, in a smaller, sans-serif font, is the date "13 February 2022." The overall style of the image is playful and cartoonish, with a clear focus on celebrating World Radio Day.', 75 | 'The image is a graphic design of a baby shower invitation. The central focus is a cute, cartoon-style teddy bear with a friendly expression, sitting upright. The bear is colored in a soft, light brown hue. Above the bear, there is a bold text that reads "YOU\'RE INVITED" in a playful, sans-serif font. Below this, the words "BABY SHOWER" are prominently displayed in a larger, more decorative font, suggesting the theme of the event.\n\nThe background of the invitation is a soft, light pink color, which adds to the gentle and welcoming atmosphere of the design. At the bottom of the image, there is additional text providing specific details about the event. It reads "27 January, 2022 - 8:00 PM" followed by "FAUGET INDUSTRIES CAFE," indicating the date, time, and location of the baby shower.\n\nThe overall style of the image is warm, inviting, and child-friendly, with a clear focus on the theme of a baby shower celebration. The use of a teddy bear as the central image reinforces the baby-related theme. The design is simple yet effective, with a clear hierarchy of information that guides the viewer\'s attention from the top to the bottom of the invitation.', 76 | ] 77 | 78 | validation_boxes = [ 79 | [(0, 0, 512, 512), (0, 0, 512, 512), (368, 0, 512, 272), (0, 272, 112, 512), (160, 208, 352, 304)], 80 | [(0, 0, 512, 512), (0, 0, 512, 512), (128, 128, 384, 304), (96, 288, 416, 336), (128, 336, 384, 368)], 81 | [(0, 0, 512, 512), (0, 0, 512, 512), (112, 48, 400, 368), (0, 48, 96, 176), (128, 336, 384, 384), (240, 384, 384, 432)], 82 | [(0, 0, 512, 512), (0, 0, 512, 512), (32, 32, 480, 480), (80, 176, 432, 368), (64, 176, 448, 224), (144, 96, 368, 224)], 83 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 64, 176, 272), (0, 400, 512, 512), (16, 160, 496, 512), (224, 48, 464, 112), (208, 96, 464, 160)], 84 | [(0, 0, 512, 512), (0, 0, 512, 512), (112, 224, 512, 512), (0, 0, 240, 160), (144, 144, 512, 512), (48, 64, 432, 208), (48, 400, 256, 448)], 85 | [(0, 0, 512, 512), (0, 0, 512, 512), (160, 48, 352, 80), (64, 80, 448, 192), (128, 208, 384, 240), (320, 240, 512, 512), (80, 272, 368, 512), (0, 224, 192, 512)], 86 | [(0, 0, 512, 512), (0, 0, 512, 512), (48, 0, 464, 304), (128, 144, 384, 400), (288, 288, 384, 368), (336, 304, 400, 368), (176, 432, 336, 480), (224, 400, 288, 432)], 87 | [(0, 0, 512, 512), (0, 0, 512, 512), (32, 288, 448, 512), (144, 176, 336, 400), (224, 208, 272, 256), (160, 128, 336, 192), (192, 368, 304, 400), (368, 80, 448, 224), (48, 160, 128, 256)], 88 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 352, 512, 512), (112, 176, 368, 432), (48, 176, 128, 256), (48, 368, 128, 448), (384, 192, 480, 272), (384, 336, 432, 384), (80, 80, 432, 128), (176, 128, 336, 160)], 89 | [(0, 0, 512, 512), (0, 0, 512, 512), (0, 0, 512, 352), (144, 384, 368, 448), (160, 192, 352, 432), (368, 0, 512, 144), (0, 0, 144, 144), (128, 80, 384, 208), (128, 448, 384, 496), (176, 48, 336, 80)], 90 | ] -------------------------------------------------------------------------------- 
/multi_layer_gen/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import os 17 | import sys 18 | import math 19 | import random 20 | import argparse 21 | import numpy as np 22 | from PIL import Image 23 | from tqdm import tqdm 24 | from mmengine.config import Config 25 | 26 | import torch 27 | import torch.utils.checkpoint 28 | from torchvision.utils import save_image 29 | 30 | from diffusers import FluxTransformer2DModel 31 | from diffusers.utils import check_min_version 32 | from diffusers.configuration_utils import FrozenDict 33 | from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING 34 | 35 | from custom_model_mmdit import CustomFluxTransformer2DModel 36 | from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE 37 | from custom_pipeline import CustomFluxPipelineCfg 38 | 39 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 40 | check_min_version("0.31.0.dev0") 41 | 42 | 43 | def seed_everything(seed): 44 | random.seed(seed) 45 | np.random.seed(seed) 46 | torch.manual_seed(seed) 47 | if torch.cuda.is_available(): 48 | torch.cuda.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | torch.backends.cudnn.deterministic = True 51 | 52 | 53 | def parse_config(path=None): 54 | 55 | if path is None: 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('config_dir', type=str) 58 | args = parser.parse_args() 59 | path = args.config_dir 60 | config = Config.fromfile(path) 61 | 62 | config.config_dir = path 63 | 64 | if "LOCAL_RANK" in os.environ: 65 | config.local_rank = int(os.environ["LOCAL_RANK"]) 66 | elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: 67 | config.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 68 | else: 69 | config.local_rank = -1 70 | 71 | return config 72 | 73 | 74 | def initialize_pipeline(config, args): 75 | 76 | # Load the original Transformer model from the pretrained model 77 | transformer_orig = FluxTransformer2DModel.from_pretrained( 78 | config.transformer_varient if hasattr(config, "transformer_varient") else config.pretrained_model_name_or_path, 79 | subfolder="" if hasattr(config, "transformer_varient") else "transformer", 80 | revision=config.revision, 81 | variant=config.variant, 82 | torch_dtype=torch.bfloat16, 83 | cache_dir=config.get("cache_dir", None), 84 | ) 85 | 86 | # Configure the custom Transformer model 87 | mmdit_config = dict(transformer_orig.config) 88 | mmdit_config["_class_name"] = "CustomSD3Transformer2DModel" 89 | mmdit_config["max_layer_num"] = config.max_layer_num 90 | mmdit_config = FrozenDict(mmdit_config) 91 | transformer = CustomFluxTransformer2DModel.from_config(mmdit_config).to(dtype=torch.bfloat16) 92 | missing_keys, unexpected_keys = transformer.load_state_dict(transformer_orig.state_dict(), strict=False) 93 | 94 | # 
Fuse initial LoRA weights 95 | if args.pre_fuse_lora_dir is not None: 96 | lora_state_dict = CustomFluxPipelineCfg.lora_state_dict(args.pre_fuse_lora_dir) 97 | CustomFluxPipelineCfg.load_lora_into_transformer(lora_state_dict, None, transformer) 98 | transformer.fuse_lora(safe_fusing=True) 99 | transformer.unload_lora() # Unload LoRA parameters 100 | 101 | # Load layer_pe weights 102 | layer_pe_path = os.path.join(args.ckpt_dir, "layer_pe.pth") 103 | layer_pe = torch.load(layer_pe_path) 104 | missing_keys, unexpected_keys = transformer.load_state_dict(layer_pe, strict=False) 105 | 106 | # Initialize the custom pipeline 107 | pipeline_type = CustomFluxPipelineCfg 108 | pipeline = pipeline_type.from_pretrained( 109 | config.pretrained_model_name_or_path, 110 | transformer=transformer, 111 | revision=config.revision, 112 | variant=config.variant, 113 | torch_dtype=torch.bfloat16, 114 | cache_dir=config.get("cache_dir", None), 115 | ).to(torch.device("cuda", index=args.gpu_id)) 116 | pipeline.enable_model_cpu_offload(gpu_id=args.gpu_id) # Save GPU memory 117 | 118 | # Load LoRA weights 119 | pipeline.load_lora_weights(args.ckpt_dir, adapter_name="layer") 120 | 121 | # Load additional LoRA weights 122 | if args.extra_lora_dir is not None: 123 | _SET_ADAPTER_SCALE_FN_MAPPING["CustomFluxTransformer2DModel"] = _SET_ADAPTER_SCALE_FN_MAPPING["FluxTransformer2DModel"] 124 | pipeline.load_lora_weights(args.extra_lora_dir, adapter_name="extra") 125 | pipeline.set_adapters(["layer", "extra"], adapter_weights=[1.0, 0.5]) 126 | 127 | return pipeline 128 | 129 | def get_fg_layer_box(list_layer_pt): 130 | list_layer_box = [] 131 | for layer_pt in list_layer_pt: 132 | alpha_channel = layer_pt[:, 3:4] 133 | 134 | if layer_pt.shape[1] == 3: 135 | list_layer_box.append( 136 | (0, 0, layer_pt.shape[3], layer_pt.shape[2]) 137 | ) 138 | continue 139 | 140 | # Step 1: Find the non-zero indices 141 | _, _, rows, cols = torch.nonzero(alpha_channel + 1, as_tuple=True) 142 | 143 | if (rows.numel() == 0) or (cols.numel() == 0): 144 | # If there are no non-zero indices, we can skip this layer 145 | list_layer_box.append(None) 146 | continue 147 | 148 | # Step 2: Get the minimum and maximum indices for rows and columns 149 | min_row, max_row = rows.min().item(), rows.max().item() 150 | min_col, max_col = cols.min().item(), cols.max().item() 151 | 152 | # Step 3: Quantize the minimum values down to the nearest multiple of 16 153 | quantized_min_row = (min_row // 16) * 16 154 | quantized_min_col = (min_col // 16) * 16 155 | 156 | # Step 4: Quantize the maximum values up to the nearest multiple of 16 outside of the max 157 | quantized_max_row = ((max_row // 16) + 1) * 16 158 | quantized_max_col = ((max_col // 16) + 1) * 16 159 | list_layer_box.append( 160 | (quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row) 161 | ) 162 | return list_layer_box 163 | 164 | 165 | def adjust_coordinate(value, floor_or_ceil, k=16, min_val=0, max_val=1024): 166 | # Round the value to the nearest multiple of k 167 | if floor_or_ceil == "floor": 168 | rounded_value = math.floor(value / k) * k 169 | else: 170 | rounded_value = math.ceil(value / k) * k 171 | # Clamp the value between min_val and max_val 172 | return max(min_val, min(rounded_value, max_val)) 173 | 174 | 175 | def test(args): 176 | 177 | if args.seed is not None: 178 | seed_everything(args.seed) 179 | 180 | cfg_path = args.cfg_path 181 | config = parse_config(cfg_path) 182 | 183 | if args.variant is not None: args.save_dir += '_' + args.variant 184 | 185 | # 
Initialize pipeline 186 | pipeline = initialize_pipeline(config, args) 187 | 188 | # load multi-layer-transparent-vae-decoder 189 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 190 | transp_vae = CustomVAE() 191 | transp_vae_path = args.transp_vae_ckpt 192 | missing, unexpected = transp_vae.load_state_dict(torch.load(transp_vae_path)['model'], strict=False) 193 | transp_vae.eval() 194 | 195 | test_samples = [ 196 | { 197 | "index": "reso512_1", 198 | "wholecaption": "The image shows a collection of luggage items on a carpeted floor. There are three main pieces of luggage: a large suitcase, a smaller suitcase, and a duffel bag. The large suitcase is positioned in the center, with the smaller suitcase to its left and the duffel bag to its right. The luggage appears to be packed and ready for travel. In the foreground, there is a plastic bag containing what looks like a pair of shoes. The background features a white curtain, suggesting that the setting might be indoors, possibly a hotel room or a similar temporary accommodation. The image is in black and white, which gives it a timeless or classic feel.", 199 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (281, 203, 474, 397), (94, 22, 294, 406), (190, 327, 379, 471)], 200 | }, 201 | { 202 | "index": "reso512_2", 203 | "wholecaption": "The image features a logo for a flower shop named ”Estelle Darcy Flower Shop.” The logo is designed with a stylized flower, which appears to be a rose, in shades of pink and green. The flower is positioned to the left of the text, which is written in a cursive font. The text is in a brown color, and the overall style of the image is simple and elegant, with a clean, light background that does not distract from the logo itself. The logo conveys a sense of freshness and natural beauty, which is fitting for a flower shop.", 204 | "layout": [(0, 0, 512, 512), (0, 0, 512, 512), (320, 160, 432, 352), (128, 240, 368, 320), (128, 304, 352, 336)], 205 | }, 206 | ] 207 | 208 | for idx, batch in tqdm(enumerate(test_samples)): 209 | 210 | generator = torch.Generator(device=torch.device("cuda", index=args.gpu_id)).manual_seed(args.seed) if args.seed else None 211 | 212 | this_index = batch["index"] 213 | 214 | validation_prompt = batch["wholecaption"] 215 | validation_box_raw = batch["layout"] 216 | validation_box = [ 217 | ( 218 | adjust_coordinate(rect[0], floor_or_ceil="floor"), 219 | adjust_coordinate(rect[1], floor_or_ceil="floor"), 220 | adjust_coordinate(rect[2], floor_or_ceil="ceil"), 221 | adjust_coordinate(rect[3], floor_or_ceil="ceil"), 222 | ) 223 | for rect in validation_box_raw 224 | ] 225 | if len(validation_box) > 52: 226 | validation_box = validation_box[:52] 227 | 228 | output, rgba_output, _, _ = pipeline( 229 | prompt=validation_prompt, 230 | validation_box=validation_box, 231 | generator=generator, 232 | height=config.resolution, 233 | width=config.resolution, 234 | num_layers=len(validation_box), 235 | guidance_scale=args.cfg, 236 | num_inference_steps=args.steps, 237 | transparent_decoder=transp_vae, 238 | ) 239 | images = output.images # list of PIL, len=layers 240 | rgba_images = [Image.fromarray(arr, 'RGBA') for arr in rgba_output] 241 | 242 | os.makedirs(os.path.join(args.save_dir, this_index), exist_ok=True) 243 | os.system(f"rm -rf {os.path.join(args.save_dir, this_index)}/*") 244 | for frame_idx, frame_pil in enumerate(images): 245 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}.png")) 246 | if frame_idx == 0: 247 | 
frame_pil.save(os.path.join(args.save_dir, this_index, "merged.png")) 248 | merged_pil = images[1].convert('RGBA') 249 | for frame_idx, frame_pil in enumerate(rgba_images): 250 | if frame_idx < 2: 251 | frame_pil = images[frame_idx].convert('RGBA') # merged and background 252 | else: 253 | merged_pil = Image.alpha_composite(merged_pil, frame_pil) 254 | frame_pil.save(os.path.join(args.save_dir, this_index, f"layer_{frame_idx}_rgba.png")) 255 | 256 | merged_pil = merged_pil.convert('RGB') 257 | merged_pil.save(os.path.join(args.save_dir, this_index, "merged_rgba.png")) 258 | 259 | del pipeline 260 | if torch.cuda.is_available(): 261 | torch.cuda.empty_cache() 262 | 263 | 264 | if __name__ == "__main__": 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument("--cfg_path", type=str) 267 | parser.add_argument("--ckpt_dir", type=str) 268 | parser.add_argument("--transp_vae_ckpt", type=str) 269 | parser.add_argument("--pre_fuse_lora_dir", type=str) 270 | parser.add_argument("--extra_lora_dir", type=str, default=None) 271 | parser.add_argument("--save_dir", type=str) 272 | parser.add_argument("--variant", type=str, default="None") 273 | parser.add_argument("--cfg", type=float, default=4.0) 274 | parser.add_argument("--steps", type=int, default=28) 275 | parser.add_argument("--seed", type=int, default=45) 276 | parser.add_argument("--gpu_id", type=int, default=0) 277 | 278 | args = parser.parse_args() 279 | 280 | test(args) -------------------------------------------------------------------------------- /multi_layer_gen/custom_model_transp_vae.py: -------------------------------------------------------------------------------- 1 | import einops 2 | from collections import OrderedDict 3 | from functools import partial 4 | from typing import Callable 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from torch.utils.checkpoint import checkpoint 10 | 11 | from accelerate.utils import set_module_tensor_to_device 12 | from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed 13 | from diffusers.models.modeling_utils import ModelMixin 14 | from diffusers.configuration_utils import ConfigMixin 15 | from diffusers.loaders import FromOriginalModelMixin 16 | 17 | 18 | class MLPBlock(torchvision.ops.misc.MLP): 19 | """Transformer MLP block.""" 20 | 21 | _version = 2 22 | 23 | def __init__(self, in_dim: int, mlp_dim: int, dropout: float): 24 | super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) 25 | 26 | for m in self.modules(): 27 | if isinstance(m, nn.Linear): 28 | nn.init.xavier_uniform_(m.weight) 29 | if m.bias is not None: 30 | nn.init.normal_(m.bias, std=1e-6) 31 | 32 | def _load_from_state_dict( 33 | self, 34 | state_dict, 35 | prefix, 36 | local_metadata, 37 | strict, 38 | missing_keys, 39 | unexpected_keys, 40 | error_msgs, 41 | ): 42 | version = local_metadata.get("version", None) 43 | 44 | if version is None or version < 2: 45 | # Replacing legacy MLPBlock with MLP. 
See https://github.com/pytorch/vision/pull/6053 46 | for i in range(2): 47 | for type in ["weight", "bias"]: 48 | old_key = f"{prefix}linear_{i+1}.{type}" 49 | new_key = f"{prefix}{3*i}.{type}" 50 | if old_key in state_dict: 51 | state_dict[new_key] = state_dict.pop(old_key) 52 | 53 | super()._load_from_state_dict( 54 | state_dict, 55 | prefix, 56 | local_metadata, 57 | strict, 58 | missing_keys, 59 | unexpected_keys, 60 | error_msgs, 61 | ) 62 | 63 | 64 | class EncoderBlock(nn.Module): 65 | """Transformer encoder block.""" 66 | 67 | def __init__( 68 | self, 69 | num_heads: int, 70 | hidden_dim: int, 71 | mlp_dim: int, 72 | dropout: float, 73 | attention_dropout: float, 74 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 75 | ): 76 | super().__init__() 77 | self.num_heads = num_heads 78 | self.hidden_dim = hidden_dim 79 | self.num_heads = num_heads 80 | 81 | # Attention block 82 | self.ln_1 = norm_layer(hidden_dim) 83 | self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) 84 | self.dropout = nn.Dropout(dropout) 85 | 86 | # MLP block 87 | self.ln_2 = norm_layer(hidden_dim) 88 | self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) 89 | 90 | def forward(self, input: torch.Tensor, freqs_cis): 91 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") 92 | B, L, C = input.shape 93 | x = self.ln_1(input) 94 | if freqs_cis is not None: 95 | query = x.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2) 96 | query = apply_rotary_emb(query, freqs_cis) 97 | query = query.transpose(1, 2).reshape(B, L, self.hidden_dim) 98 | x, _ = self.self_attention(query, query, x, need_weights=False) 99 | x = self.dropout(x) 100 | x = x + input 101 | 102 | y = self.ln_2(x) 103 | y = self.mlp(y) 104 | return x + y 105 | 106 | 107 | class Encoder(nn.Module): 108 | """Transformer Model Encoder for sequence to sequence translation.""" 109 | 110 | def __init__( 111 | self, 112 | seq_length: int, 113 | num_layers: int, 114 | num_heads: int, 115 | hidden_dim: int, 116 | mlp_dim: int, 117 | dropout: float, 118 | attention_dropout: float, 119 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 120 | ): 121 | super().__init__() 122 | # Note that batch_size is on the first dim because 123 | # we have batch_first=True in nn.MultiAttention() by default 124 | # self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT 125 | self.dropout = nn.Dropout(dropout) 126 | layers: OrderedDict[str, nn.Module] = OrderedDict() 127 | for i in range(num_layers): 128 | layers[f"encoder_layer_{i}"] = EncoderBlock( 129 | num_heads, 130 | hidden_dim, 131 | mlp_dim, 132 | dropout, 133 | attention_dropout, 134 | norm_layer, 135 | ) 136 | self.layers = nn.Sequential(layers) 137 | self.ln = norm_layer(hidden_dim) 138 | 139 | def forward(self, input: torch.Tensor, freqs_cis): 140 | torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") 141 | input = input # + self.pos_embedding 142 | x = self.dropout(input) 143 | for l in self.layers: 144 | x = checkpoint(l, x, freqs_cis) 145 | x = self.ln(x) 146 | return x 147 | 148 | 149 | class ViTEncoder(nn.Module): 150 | def __init__(self, arch='vit-b/32'): 151 | super().__init__() 152 | self.arch = arch 153 | 154 | if self.arch == 'vit-b/32': 155 | ch = 768 156 | layers = 12 157 | heads = 12 158 | elif self.arch == 'vit-h/14': 159 | ch = 
1280 160 | layers = 32 161 | heads = 16 162 | 163 | self.encoder = Encoder( 164 | seq_length=-1, 165 | num_layers=layers, 166 | num_heads=heads, 167 | hidden_dim=ch, 168 | mlp_dim=ch*4, 169 | dropout=0.0, 170 | attention_dropout=0.0, 171 | ) 172 | self.fc_in = nn.Linear(16, ch) 173 | self.fc_out = nn.Linear(ch, 256) 174 | 175 | if self.arch == 'vit-b/32': 176 | from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights 177 | vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT) 178 | elif self.arch == 'vit-h/14': 179 | from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights 180 | vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 181 | 182 | missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False) 183 | if len(missing_keys) > 0 or len(unexpected_keys) > 0: 184 | print(f"ViT Encoder Missing keys: {missing_keys}") 185 | print(f"ViT Encoder Unexpected keys: {unexpected_keys}") 186 | del vit 187 | 188 | def forward(self, x, freqs_cis): 189 | out = self.fc_in(x) 190 | out = self.encoder(out, freqs_cis) 191 | out = checkpoint(self.fc_out, out) 192 | return out 193 | 194 | 195 | def patchify(x, patch_size=8): 196 | if len(x.shape) == 4: 197 | bs, c, h, w = x.shape 198 | x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size) 199 | elif len(x.shape) == 3: 200 | c, h, w = x.shape 201 | x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size) 202 | return x 203 | 204 | 205 | def unpatchify(x, patch_size=8): 206 | if len(x.shape) == 4: 207 | bs, c, h, w = x.shape 208 | x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size) 209 | elif len(x.shape) == 3: 210 | c, h, w = x.shape 211 | x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size) 212 | return x 213 | 214 | 215 | def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding): 216 | token_list = [] 217 | cos_list, sin_list = [], [] 218 | for layer_idx in range(hidden_states.shape[1]): 219 | if list_layer_box[layer_idx] is None: 220 | continue 221 | else: 222 | x1, y1, x2, y2 = list_layer_box[layer_idx] 223 | x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8 224 | layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2] 225 | c, h, w = layer_token.shape 226 | layer_token = layer_token.reshape(c, -1) 227 | token_list.append(layer_token) 228 | ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype) 229 | ids[:, 0] = use_layers[layer_idx] 230 | image_rotary_emb = pos_embedding(ids) 231 | pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1) 232 | cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64)) 233 | sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64)) 234 | token_list = torch.cat(token_list, dim=1).permute(1, 0) 235 | cos_list = torch.cat(cos_list, dim=0) 236 | sin_list = torch.cat(sin_list, dim=0) 237 | return token_list, (cos_list, sin_list) 238 | 239 | 240 | def prepare_latent_image_ids(batch_size, height, width, device, dtype): 241 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 242 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 243 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 244 | 245 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 246 | 247 | 
latent_image_ids = latent_image_ids.reshape( 248 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 249 | ) 250 | 251 | return latent_image_ids.to(device=device, dtype=dtype) 252 | 253 | 254 | class AutoencoderKLTransformerTraining(ModelMixin, ConfigMixin, FromOriginalModelMixin): 255 | def __init__(self): 256 | super().__init__() 257 | 258 | self.decoder_arch = 'vit' 259 | self.layer_embedding = 'rope' 260 | 261 | self.decoder = ViTEncoder() 262 | self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28)) 263 | if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding: 264 | self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.max_layers, 1, 1).normal_(std=0.02), requires_grad=True) 265 | 266 | def zero_module(module): 267 | """ 268 | Zero out the parameters of a module and return it. 269 | """ 270 | for p in module.parameters(): 271 | p.detach().zero_() 272 | return module 273 | 274 | def encode(self, z_2d, box, use_layers): 275 | B, C, T, H, W = z_2d.shape 276 | 277 | z, freqs_cis = [], [] 278 | for b in range(B): 279 | _z = z_2d[b] 280 | if 'vit' in self.decoder_arch: 281 | _use_layers = torch.tensor(use_layers[b], device=z_2d.device) 282 | if 'rel' in self.layer_embedding: 283 | _use_layers[_use_layers > 2] = 2 284 | if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding: 285 | _z = _z + self.layer_embedding[:, _use_layers] # + self.pos_embedding 286 | if 'rope' not in self.layer_embedding: 287 | use_layers[b] = [0] * len(use_layers[b]) 288 | _z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding) ### modified 289 | z.append(_z) 290 | freqs_cis.append(cis) 291 | 292 | return z, freqs_cis 293 | 294 | def decode(self, z, freqs_cis, box, H, W): 295 | B = len(z) 296 | pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype) 297 | pad[3, :, :] = -1 298 | x = [] 299 | for b in range(B): 300 | _x = [] 301 | _z = self.decoder(z[b].unsqueeze(0), freqs_cis[b]).squeeze(0) 302 | current_index = 0 303 | for layer_idx in range(len(box[b])): 304 | if box[b][layer_idx] == None: 305 | _x.append(pad.clone()) 306 | else: 307 | x1, y1, x2, y2 = box[b][layer_idx] 308 | x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8 309 | token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok) 310 | tokens = _z[current_index:current_index + token_length] 311 | pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok) 312 | unpatched = unpatchify(pixels) 313 | pixels = pad.clone() 314 | pixels[:, y1:y2, x1:x2] = unpatched 315 | _x.append(pixels) 316 | current_index += token_length 317 | _x = torch.stack(_x, dim=1) 318 | x.append(_x) 319 | x = torch.stack(x, dim=0) 320 | return x 321 | 322 | def forward(self, z_2d, box, use_layers=None): 323 | z_2d = z_2d.transpose(0, 1).unsqueeze(0) 324 | use_layers = use_layers or [list(range(z_2d.shape[2]))] 325 | z, freqs_cis = self.encode(z_2d, box, use_layers) 326 | H, W = z_2d.shape[-2:] 327 | x_hat = self.decode(z, freqs_cis, box, H * 8, W * 8) 328 | assert x_hat.shape[0] == 1, x_hat.shape 329 | x_hat = einops.rearrange(x_hat[0], "c t h w -> t c h w") 330 | x_hat_rgb, x_hat_alpha = x_hat[:, :3], x_hat[:, 3:] 331 | return x_hat_rgb, x_hat_alpha 332 | -------------------------------------------------------------------------------- /multi_layer_gen/custom_model_mmdit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Any, Dict, List, Optional, 
Union, Tuple 4 | 5 | from accelerate.utils import set_module_tensor_to_device 6 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 7 | from diffusers.models.normalization import AdaLayerNormContinuous 8 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 9 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock 10 | 11 | from diffusers.configuration_utils import register_to_config 12 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 13 | 14 | 15 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 16 | 17 | 18 | class CustomFluxTransformer2DModel(FluxTransformer2DModel): 19 | """ 20 | The Transformer model introduced in Flux. 21 | 22 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 23 | 24 | Parameters: 25 | patch_size (`int`): Patch size to turn the input data into small patches. 26 | in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. 27 | num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use. 28 | num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use. 29 | attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head. 30 | num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention. 31 | joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use. 32 | pooled_projection_dim (`int`, defaults to 768): Number of dimensions to use when projecting the `pooled_projections`. 33 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 
34 | """ 35 | 36 | @register_to_config 37 | def __init__( 38 | self, 39 | patch_size: int = 1, 40 | in_channels: int = 64, 41 | num_layers: int = 19, 42 | num_single_layers: int = 38, 43 | attention_head_dim: int = 128, 44 | num_attention_heads: int = 24, 45 | joint_attention_dim: int = 4096, 46 | pooled_projection_dim: int = 768, 47 | guidance_embeds: bool = False, 48 | axes_dims_rope: Tuple[int] = (16, 56, 56), 49 | max_layer_num: int = 52, 50 | ): 51 | super(FluxTransformer2DModel, self).__init__() 52 | self.out_channels = in_channels 53 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 54 | 55 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 56 | 57 | text_time_guidance_cls = ( 58 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 59 | ) 60 | self.time_text_embed = text_time_guidance_cls( 61 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 62 | ) 63 | 64 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 65 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 66 | 67 | self.transformer_blocks = nn.ModuleList( 68 | [ 69 | FluxTransformerBlock( 70 | dim=self.inner_dim, 71 | num_attention_heads=self.config.num_attention_heads, 72 | attention_head_dim=self.config.attention_head_dim, 73 | ) 74 | for i in range(self.config.num_layers) 75 | ] 76 | ) 77 | 78 | self.single_transformer_blocks = nn.ModuleList( 79 | [ 80 | FluxSingleTransformerBlock( 81 | dim=self.inner_dim, 82 | num_attention_heads=self.config.num_attention_heads, 83 | attention_head_dim=self.config.attention_head_dim, 84 | ) 85 | for i in range(self.config.num_single_layers) 86 | ] 87 | ) 88 | 89 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 90 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 91 | 92 | self.gradient_checkpointing = False 93 | 94 | self.max_layer_num = max_layer_num 95 | 96 | # the following process ensures self.layer_pe is not created as a meta tensor 97 | layer_pe_value = nn.init.trunc_normal_( 98 | nn.Parameter(torch.zeros( 99 | 1, self.max_layer_num, 1, 1, self.inner_dim, 100 | )), 101 | mean=0.0, std=0.02, a=-2.0, b=2.0, 102 | ).data.detach() 103 | self.layer_pe = nn.Parameter(layer_pe_value) 104 | set_module_tensor_to_device( 105 | self, 106 | 'layer_pe', 107 | device='cpu', 108 | value=layer_pe_value, 109 | dtype=layer_pe_value.dtype, 110 | ) 111 | 112 | @classmethod 113 | def from_pretrained(cls, *args, **kwarg): 114 | model = super().from_pretrained(*args, **kwarg) 115 | for name, para in model.named_parameters(): 116 | if name != 'layer_pe': 117 | device = para.device 118 | break 119 | model.layer_pe.to(device) 120 | return model 121 | 122 | def crop_each_layer(self, hidden_states, list_layer_box): 123 | """ 124 | hidden_states: [1, n_layers, h, w, inner_dim] 125 | list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2) 126 | """ 127 | token_list = [] 128 | for layer_idx in range(hidden_states.shape[1]): 129 | if list_layer_box[layer_idx] == None: 130 | continue 131 | else: 132 | x1, y1, x2, y2 = list_layer_box[layer_idx] 133 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16 134 | layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :] 135 | bs, h, w, c = layer_token.shape 136 | layer_token = layer_token.reshape(bs, -1, c) 137 | 
token_list.append(layer_token) 138 | result = torch.cat(token_list, dim=1) 139 | return result 140 | 141 | def fill_in_processed_tokens(self, hidden_states, full_hidden_states, list_layer_box): 142 | """ 143 | hidden_states: [1, h1xw1 + h2xw2 + ... + hlxwl , inner_dim] 144 | full_hidden_states: [1, n_layers, h, w, inner_dim] 145 | list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2) 146 | """ 147 | used_token_len = 0 148 | bs = hidden_states.shape[0] 149 | for layer_idx in range(full_hidden_states.shape[1]): 150 | if list_layer_box[layer_idx] == None: 151 | continue 152 | else: 153 | x1, y1, x2, y2 = list_layer_box[layer_idx] 154 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16 155 | full_hidden_states[:, layer_idx, y1:y2, x1:x2, :] = hidden_states[:, used_token_len: used_token_len + (y2-y1) * (x2-x1), :].reshape(bs, y2-y1, x2-x1, -1) 156 | used_token_len = used_token_len + (y2-y1) * (x2-x1) 157 | return full_hidden_states 158 | 159 | def forward( 160 | self, 161 | hidden_states: torch.Tensor, 162 | list_layer_box: List[Tuple] = None, 163 | encoder_hidden_states: torch.Tensor = None, 164 | pooled_projections: torch.Tensor = None, 165 | timestep: torch.LongTensor = None, 166 | img_ids: torch.Tensor = None, 167 | txt_ids: torch.Tensor = None, 168 | guidance: torch.Tensor = None, 169 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 170 | return_dict: bool = True, 171 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 172 | """ 173 | The [`FluxTransformer2DModel`] forward method. 174 | 175 | Args: 176 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 177 | Input `hidden_states`. 178 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 179 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 180 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 181 | from the embeddings of input conditions. 182 | timestep ( `torch.LongTensor`): 183 | Used to indicate denoising step. 184 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 185 | A list of tensors that if specified are added to the residuals of transformer blocks. 186 | joint_attention_kwargs (`dict`, *optional*): 187 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 188 | `self.processor` in 189 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 190 | return_dict (`bool`, *optional*, defaults to `True`): 191 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 192 | tuple. 193 | 194 | Returns: 195 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 196 | `tuple` where the first element is the sample tensor. 
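        Note:
            Compared with the stock `FluxTransformer2DModel`, `hidden_states` here carries an extra layer axis,
            i.e. shape `(batch_size, n_layers, channel_latent, height, width)`, and `list_layer_box` supplies one
            `(x1, y1, x2, y2)` pixel box per layer (`None` entries are skipped). Only the tokens inside each box
            are processed by the transformer blocks; see `crop_each_layer` and `fill_in_processed_tokens`.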
197 | """ 198 | if joint_attention_kwargs is not None: 199 | joint_attention_kwargs = joint_attention_kwargs.copy() 200 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 201 | else: 202 | lora_scale = 1.0 203 | 204 | if USE_PEFT_BACKEND: 205 | # weight the lora layers by setting `lora_scale` for each PEFT layer 206 | scale_lora_layers(self, lora_scale) 207 | else: 208 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 209 | logger.warning( 210 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 211 | ) 212 | 213 | bs, n_layers, channel_latent, height, width = hidden_states.shape # [bs, n_layers, c_latent, h, w] 214 | 215 | hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2) # [bs, n_layers, c_latent, h/2, 2, w/2, 2] 216 | hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6) # [bs, n_layers, h/2, w/2, c_latent, 2, 2] 217 | hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4) # [bs, n_layers, h/2, w/2, c_latent*4] 218 | hidden_states = self.x_embedder(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim] 219 | 220 | full_hidden_states = torch.zeros_like(hidden_states) # [bs, n_layers, h/2, w/2, inner_dim] 221 | layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim) # [1, max_n_layers, 1, 1, inner_dim] 222 | hidden_states = hidden_states + layer_pe[:, :n_layers] # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, f, h/2, w/2, inner_dim] 223 | hidden_states = self.crop_each_layer(hidden_states, list_layer_box) # [bs, token_len, inner_dim] 224 | 225 | timestep = timestep.to(hidden_states.dtype) * 1000 226 | if guidance is not None: 227 | guidance = guidance.to(hidden_states.dtype) * 1000 228 | else: 229 | guidance = None 230 | temb = ( 231 | self.time_text_embed(timestep, pooled_projections) 232 | if guidance is None 233 | else self.time_text_embed(timestep, guidance, pooled_projections) 234 | ) 235 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 236 | 237 | if txt_ids.ndim == 3: 238 | logger.warning( 239 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 240 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 241 | ) 242 | txt_ids = txt_ids[0] 243 | if img_ids.ndim == 3: 244 | logger.warning( 245 | "Passing `img_ids` 3d torch.Tensor is deprecated." 
246 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 247 | ) 248 | img_ids = img_ids[0] 249 | ids = torch.cat((txt_ids, img_ids), dim=0) 250 | image_rotary_emb = self.pos_embed(ids) 251 | 252 | for index_block, block in enumerate(self.transformer_blocks): 253 | if self.training and self.gradient_checkpointing: 254 | 255 | def create_custom_forward(module, return_dict=None): 256 | def custom_forward(*inputs): 257 | if return_dict is not None: 258 | return module(*inputs, return_dict=return_dict) 259 | else: 260 | return module(*inputs) 261 | 262 | return custom_forward 263 | 264 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 265 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 266 | create_custom_forward(block), 267 | hidden_states, 268 | encoder_hidden_states, 269 | temb, 270 | image_rotary_emb, 271 | **ckpt_kwargs, 272 | ) 273 | 274 | else: 275 | encoder_hidden_states, hidden_states = block( 276 | hidden_states=hidden_states, 277 | encoder_hidden_states=encoder_hidden_states, 278 | temb=temb, 279 | image_rotary_emb=image_rotary_emb, 280 | ) 281 | 282 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 283 | 284 | for index_block, block in enumerate(self.single_transformer_blocks): 285 | if self.training and self.gradient_checkpointing: 286 | 287 | def create_custom_forward(module, return_dict=None): 288 | def custom_forward(*inputs): 289 | if return_dict is not None: 290 | return module(*inputs, return_dict=return_dict) 291 | else: 292 | return module(*inputs) 293 | 294 | return custom_forward 295 | 296 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 297 | hidden_states = torch.utils.checkpoint.checkpoint( 298 | create_custom_forward(block), 299 | hidden_states, 300 | temb, 301 | image_rotary_emb, 302 | **ckpt_kwargs, 303 | ) 304 | 305 | else: 306 | hidden_states = block( 307 | hidden_states=hidden_states, 308 | temb=temb, 309 | image_rotary_emb=image_rotary_emb, 310 | ) 311 | 312 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
313 | 314 | hidden_states = self.fill_in_processed_tokens(hidden_states, full_hidden_states, list_layer_box) # [bs, n_layers, h/2, w/2, inner_dim] 315 | hidden_states = hidden_states.view(bs, -1, self.inner_dim) # [bs, n_layers * full_len, inner_dim] 316 | 317 | hidden_states = self.norm_out(hidden_states, temb) # [bs, n_layers * full_len, inner_dim] 318 | hidden_states = self.proj_out(hidden_states) # [bs, n_layers * full_len, c_latent*4] 319 | 320 | # unpatchify 321 | hidden_states = hidden_states.view(bs, n_layers, height//2, width//2, channel_latent, 2, 2) # [bs, n_layers, h/2, w/2, c_latent, 2, 2] 322 | hidden_states = hidden_states.permute(0, 1, 4, 2, 5, 3, 6) 323 | output = hidden_states.reshape(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w] 324 | 325 | if USE_PEFT_BACKEND: 326 | # remove `lora_scale` from each PEFT layer 327 | unscale_lora_layers(self, lora_scale) 328 | 329 | if not return_dict: 330 | return (output,) 331 | 332 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /multi_layer_gen/custom_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.utils.torch_utils import randn_tensor 8 | from diffusers.utils import is_torch_xla_available, logging 9 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 10 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline 11 | 12 | if is_torch_xla_available(): 13 | import torch_xla.core.xla_model as xm # type: ignore 14 | XLA_AVAILABLE = True 15 | else: 16 | XLA_AVAILABLE = False 17 | 18 | 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | def _get_clip_prompt_embeds( 23 | tokenizer, 24 | text_encoder, 25 | prompt: Union[str, List[str]], 26 | num_images_per_prompt: int = 1, 27 | device: Optional[torch.device] = None, 28 | ): 29 | device = device or text_encoder.device 30 | dtype = text_encoder.dtype 31 | 32 | prompt = [prompt] if isinstance(prompt, str) else prompt 33 | batch_size = len(prompt) 34 | 35 | text_inputs = tokenizer( 36 | prompt, 37 | padding="max_length", 38 | max_length=text_encoder.config.max_position_embeddings, 39 | truncation=True, 40 | return_overflowing_tokens=False, 41 | return_length=False, 42 | return_tensors="pt", 43 | ) 44 | 45 | text_input_ids = text_inputs.input_ids 46 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 47 | 48 | # Use pooled output of CLIPTextModel 49 | prompt_embeds = prompt_embeds.pooler_output 50 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 51 | 52 | # duplicate text embeddings for each generation per prompt, using mps friendly method 53 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 54 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 55 | 56 | return prompt_embeds 57 | 58 | 59 | def _get_t5_prompt_embeds( 60 | tokenizer, 61 | text_encoder, 62 | prompt: Union[str, List[str]] = None, 63 | num_images_per_prompt: int = 1, 64 | max_sequence_length: int = 512, 65 | device: Optional[torch.device] = None, 66 | dtype: Optional[torch.dtype] = None, 67 | ): 68 | device = device or text_encoder.device 69 | dtype = dtype or text_encoder.dtype 70 | 71 | prompt = [prompt] if isinstance(prompt, str) 
else prompt 72 | batch_size = len(prompt) 73 | 74 | text_inputs = tokenizer( 75 | prompt, 76 | padding="max_length", 77 | max_length=max_sequence_length, 78 | truncation=True, 79 | return_length=False, 80 | return_overflowing_tokens=False, 81 | return_tensors="pt", 82 | ) 83 | text_input_ids = text_inputs.input_ids 84 | 85 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0] 86 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 87 | 88 | _, seq_len, _ = prompt_embeds.shape 89 | 90 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 91 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 92 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 93 | 94 | return prompt_embeds 95 | 96 | 97 | def encode_prompt( 98 | tokenizers, 99 | text_encoders, 100 | prompt: Union[str, List[str]], 101 | prompt_2: Union[str, List[str]] = None, 102 | num_images_per_prompt: int = 1, 103 | max_sequence_length: int = 512, 104 | ): 105 | 106 | tokenizer_1, tokenizer_2 = tokenizers 107 | text_encoder_1, text_encoder_2 = text_encoders 108 | device = text_encoder_1.device 109 | dtype = text_encoder_1.dtype 110 | 111 | prompt = [prompt] if isinstance(prompt, str) else prompt 112 | prompt_2 = prompt_2 or prompt 113 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 114 | 115 | # We only use the pooled prompt output from the CLIPTextModel 116 | pooled_prompt_embeds = _get_clip_prompt_embeds( 117 | tokenizer=tokenizer_1, 118 | text_encoder=text_encoder_1, 119 | prompt=prompt, 120 | num_images_per_prompt=num_images_per_prompt, 121 | ) 122 | prompt_embeds = _get_t5_prompt_embeds( 123 | tokenizer=tokenizer_2, 124 | text_encoder=text_encoder_2, 125 | prompt=prompt_2, 126 | num_images_per_prompt=num_images_per_prompt, 127 | max_sequence_length=max_sequence_length, 128 | ) 129 | 130 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 131 | 132 | return prompt_embeds, pooled_prompt_embeds, text_ids 133 | 134 | 135 | class CustomFluxPipeline(FluxPipeline): 136 | 137 | @staticmethod 138 | def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype): 139 | 140 | latent_image_ids_list = [] 141 | for layer_idx in range(len(list_layer_box)): 142 | if list_layer_box[layer_idx] == None: 143 | continue 144 | else: 145 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3] 146 | latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation 147 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 148 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 149 | 150 | x1, y1, x2, y2 = list_layer_box[layer_idx] 151 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16 152 | latent_image_ids = latent_image_ids[y1:y2, x1:x2, :] 153 | 154 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 155 | latent_image_ids = latent_image_ids.reshape( 156 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 157 | ) 158 | 159 | latent_image_ids_list.append(latent_image_ids) 160 | 161 | full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0) 162 | 163 | return full_latent_image_ids.to(device=device, dtype=dtype) 164 | 165 | def prepare_latents( 166 | self, 167 | batch_size, 168 | num_layers, 169 | num_channels_latents, 170 | height, 171 | width, 172 | 
list_layer_box, 173 | dtype, 174 | device, 175 | generator, 176 | latents=None, 177 | ): 178 | height = 2 * (int(height) // self.vae_scale_factor) 179 | width = 2 * (int(width) // self.vae_scale_factor) 180 | 181 | shape = (batch_size, num_layers, num_channels_latents, height, width) 182 | 183 | if latents is not None: 184 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 185 | return latents.to(device=device, dtype=dtype), latent_image_ids 186 | 187 | if isinstance(generator, list) and len(generator) != batch_size: 188 | raise ValueError( 189 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 190 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 191 | ) 192 | 193 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, f, c_latent, h, w] 194 | 195 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype) 196 | 197 | return latents, latent_image_ids 198 | 199 | @torch.no_grad() 200 | def __call__( 201 | self, 202 | prompt: Union[str, List[str]] = None, 203 | prompt_2: Optional[Union[str, List[str]]] = None, 204 | validation_box: List[tuple] = None, 205 | height: Optional[int] = None, 206 | width: Optional[int] = None, 207 | num_inference_steps: int = 28, 208 | timesteps: List[int] = None, 209 | guidance_scale: float = 3.5, 210 | num_images_per_prompt: Optional[int] = 1, 211 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 212 | latents: Optional[torch.FloatTensor] = None, 213 | prompt_embeds: Optional[torch.FloatTensor] = None, 214 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 215 | output_type: Optional[str] = "pil", 216 | return_dict: bool = True, 217 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 218 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 219 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 220 | max_sequence_length: int = 512, 221 | num_layers: int = 5, 222 | sdxl_vae: nn.Module = None, 223 | transparent_decoder: nn.Module = None, 224 | ): 225 | r""" 226 | Function invoked when calling the pipeline for generation. 227 | 228 | Args: 229 | prompt (`str` or `List[str]`, *optional*): 230 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 231 | instead. 232 | prompt_2 (`str` or `List[str]`, *optional*): 233 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 234 | will be used instead 235 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 236 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 237 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 238 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 239 | num_inference_steps (`int`, *optional*, defaults to 50): 240 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 241 | expense of slower inference. 242 | timesteps (`List[int]`, *optional*): 243 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 244 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 245 | passed will be used. 
Must be in descending order. 246 | guidance_scale (`float`, *optional*, defaults to 7.0): 247 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 248 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 249 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 250 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 251 | usually at the expense of lower image quality. 252 | num_images_per_prompt (`int`, *optional*, defaults to 1): 253 | The number of images to generate per prompt. 254 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 255 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 256 | to make generation deterministic. 257 | latents (`torch.FloatTensor`, *optional*): 258 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 259 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 260 | tensor will ge generated by sampling using the supplied random `generator`. 261 | prompt_embeds (`torch.FloatTensor`, *optional*): 262 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 263 | provided, text embeddings will be generated from `prompt` input argument. 264 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 265 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 266 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 267 | output_type (`str`, *optional*, defaults to `"pil"`): 268 | The output format of the generate image. Choose between 269 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 270 | return_dict (`bool`, *optional*, defaults to `True`): 271 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 272 | joint_attention_kwargs (`dict`, *optional*): 273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 274 | `self.processor` in 275 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 276 | callback_on_step_end (`Callable`, *optional*): 277 | A function that calls at the end of each denoising steps during the inference. The function is called 278 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 279 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 280 | `callback_on_step_end_tensor_inputs`. 281 | callback_on_step_end_tensor_inputs (`List`, *optional*): 282 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 283 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 284 | `._callback_tensor_inputs` attribute of your pipeline class. 285 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 286 | 287 | Examples: 288 | 289 | Returns: 290 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 291 | is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated 292 | images. 293 | """ 294 | 295 | height = height or self.default_sample_size * self.vae_scale_factor 296 | width = width or self.default_sample_size * self.vae_scale_factor 297 | 298 | # 1. Check inputs. Raise error if not correct 299 | self.check_inputs( 300 | prompt, 301 | prompt_2, 302 | height, 303 | width, 304 | prompt_embeds=prompt_embeds, 305 | pooled_prompt_embeds=pooled_prompt_embeds, 306 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 307 | max_sequence_length=max_sequence_length, 308 | ) 309 | 310 | self._guidance_scale = guidance_scale 311 | self._joint_attention_kwargs = joint_attention_kwargs 312 | self._interrupt = False 313 | 314 | # 2. Define call parameters 315 | if prompt is not None and isinstance(prompt, str): 316 | batch_size = 1 317 | elif prompt is not None and isinstance(prompt, list): 318 | batch_size = len(prompt) 319 | else: 320 | batch_size = prompt_embeds.shape[0] 321 | 322 | device = self._execution_device 323 | 324 | lora_scale = ( 325 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 326 | ) 327 | ( 328 | prompt_embeds, 329 | pooled_prompt_embeds, 330 | text_ids, 331 | ) = self.encode_prompt( 332 | prompt=prompt, 333 | prompt_2=prompt_2, 334 | prompt_embeds=prompt_embeds, 335 | pooled_prompt_embeds=pooled_prompt_embeds, 336 | device=device, 337 | num_images_per_prompt=num_images_per_prompt, 338 | max_sequence_length=max_sequence_length, 339 | lora_scale=lora_scale, 340 | ) 341 | 342 | # 4. Prepare latent variables 343 | num_channels_latents = self.transformer.config.in_channels // 4 344 | latents, latent_image_ids = self.prepare_latents( 345 | batch_size * num_images_per_prompt, 346 | num_layers, 347 | num_channels_latents, 348 | height, 349 | width, 350 | validation_box, 351 | prompt_embeds.dtype, 352 | device, 353 | generator, 354 | latents, 355 | ) 356 | 357 | # 5. Prepare timesteps 358 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 359 | image_seq_len = latent_image_ids.shape[0] # ??? 360 | mu = calculate_shift( 361 | image_seq_len, 362 | self.scheduler.config.base_image_seq_len, 363 | self.scheduler.config.max_image_seq_len, 364 | self.scheduler.config.base_shift, 365 | self.scheduler.config.max_shift, 366 | ) 367 | timesteps, num_inference_steps = retrieve_timesteps( 368 | self.scheduler, 369 | num_inference_steps, 370 | device, 371 | timesteps, 372 | sigmas, 373 | mu=mu, 374 | ) 375 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 376 | self._num_timesteps = len(timesteps) 377 | 378 | # handle guidance 379 | if self.transformer.config.guidance_embeds: 380 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 381 | guidance = guidance.expand(latents.shape[0]) 382 | else: 383 | guidance = None 384 | 385 | # 6. 
Denoising loop 386 | with self.progress_bar(total=num_inference_steps) as progress_bar: 387 | for i, t in enumerate(timesteps): 388 | if self.interrupt: 389 | continue 390 | 391 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 392 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 393 | 394 | noise_pred = self.transformer( 395 | hidden_states=latents, 396 | list_layer_box=validation_box, 397 | timestep=timestep / 1000, 398 | guidance=guidance, 399 | pooled_projections=pooled_prompt_embeds, 400 | encoder_hidden_states=prompt_embeds, 401 | txt_ids=text_ids, 402 | img_ids=latent_image_ids, 403 | joint_attention_kwargs=self.joint_attention_kwargs, 404 | return_dict=False, 405 | )[0] 406 | 407 | # compute the previous noisy sample x_t -> x_t-1 408 | latents_dtype = latents.dtype 409 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 410 | 411 | if latents.dtype != latents_dtype: 412 | if torch.backends.mps.is_available(): 413 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 414 | latents = latents.to(latents_dtype) 415 | 416 | if callback_on_step_end is not None: 417 | callback_kwargs = {} 418 | for k in callback_on_step_end_tensor_inputs: 419 | callback_kwargs[k] = locals()[k] 420 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 421 | 422 | latents = callback_outputs.pop("latents", latents) 423 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 424 | 425 | # call the callback, if provided 426 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 427 | progress_bar.update() 428 | 429 | if XLA_AVAILABLE: 430 | xm.mark_step() 431 | 432 | # create a grey latent 433 | bs, n_frames, channel_latent, height, width = latents.shape 434 | 435 | pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype) 436 | latent_grey = self.vae.encode(pixel_grey).latent_dist.sample() 437 | latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor 438 | latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w] 439 | 440 | # fill in the latents 441 | for layer_idx in range(latent_grey.shape[1]): 442 | x1, y1, x2, y2 = validation_box[layer_idx] 443 | x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8 444 | latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2] 445 | latents = latent_grey 446 | 447 | if output_type == "latent": 448 | image = latents 449 | 450 | else: 451 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 452 | latents = latents.reshape(bs * n_frames, channel_latent, height, width) 453 | image = self.vae.decode(latents, return_dict=False)[0] 454 | if sdxl_vae is not None: 455 | sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device) 456 | sdxl_latents = sdxl_vae.encode(image).latent_dist.sample() 457 | transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device) 458 | result_list, vis_list = transparent_decoder(sdxl_vae, sdxl_latents) 459 | else: 460 | result_list, vis_list = None, None 461 | image = self.image_processor.postprocess(image, output_type=output_type) 462 | 463 | # Offload all models 464 | self.maybe_free_model_hooks() 465 | 466 | if not return_dict: 467 | return (image, result_list, vis_list) 468 | 469 | return FluxPipelineOutput(images=image), result_list, 
vis_list 470 | 471 | 472 | class CustomFluxPipelineCfg(FluxPipeline): 473 | 474 | @staticmethod 475 | def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype): 476 | 477 | latent_image_ids_list = [] 478 | for layer_idx in range(len(list_layer_box)): 479 | if list_layer_box[layer_idx] == None: 480 | continue 481 | else: 482 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3] 483 | latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation 484 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 485 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 486 | 487 | x1, y1, x2, y2 = list_layer_box[layer_idx] 488 | x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16 489 | latent_image_ids = latent_image_ids[y1:y2, x1:x2, :] 490 | 491 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 492 | latent_image_ids = latent_image_ids.reshape( 493 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 494 | ) 495 | 496 | latent_image_ids_list.append(latent_image_ids) 497 | 498 | full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0) 499 | 500 | return full_latent_image_ids.to(device=device, dtype=dtype) 501 | 502 | def prepare_latents( 503 | self, 504 | batch_size, 505 | num_layers, 506 | num_channels_latents, 507 | height, 508 | width, 509 | list_layer_box, 510 | dtype, 511 | device, 512 | generator, 513 | latents=None, 514 | ): 515 | height = 2 * (int(height) // self.vae_scale_factor) 516 | width = 2 * (int(width) // self.vae_scale_factor) 517 | 518 | shape = (batch_size, num_layers, num_channels_latents, height, width) 519 | 520 | if latents is not None: 521 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 522 | return latents.to(device=device, dtype=dtype), latent_image_ids 523 | 524 | if isinstance(generator, list) and len(generator) != batch_size: 525 | raise ValueError( 526 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 527 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
528 | ) 529 | 530 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, n_layers, c_latent, h, w] 531 | 532 | latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype) 533 | 534 | return latents, latent_image_ids 535 | 536 | @torch.no_grad() 537 | def __call__( 538 | self, 539 | prompt: Union[str, List[str]] = None, 540 | prompt_2: Optional[Union[str, List[str]]] = None, 541 | validation_box: List[tuple] = None, 542 | height: Optional[int] = None, 543 | width: Optional[int] = None, 544 | num_inference_steps: int = 28, 545 | timesteps: List[int] = None, 546 | guidance_scale: float = 3.5, 547 | true_gs: float = 3.5, 548 | num_images_per_prompt: Optional[int] = 1, 549 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 550 | latents: Optional[torch.FloatTensor] = None, 551 | prompt_embeds: Optional[torch.FloatTensor] = None, 552 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 553 | output_type: Optional[str] = "pil", 554 | return_dict: bool = True, 555 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 556 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 557 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 558 | max_sequence_length: int = 512, 559 | num_layers: int = 5, 560 | transparent_decoder: nn.Module = None, 561 | ): 562 | r""" 563 | Function invoked when calling the pipeline for generation. 564 | 565 | Args: 566 | prompt (`str` or `List[str]`, *optional*): 567 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 568 | instead. 569 | prompt_2 (`str` or `List[str]`, *optional*): 570 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 571 | will be used instead 572 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 573 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 574 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 575 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 576 | num_inference_steps (`int`, *optional*, defaults to 50): 577 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 578 | expense of slower inference. 579 | timesteps (`List[int]`, *optional*): 580 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 581 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 582 | passed will be used. Must be in descending order. 583 | guidance_scale (`float`, *optional*, defaults to 7.0): 584 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 585 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 586 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 587 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 588 | usually at the expense of lower image quality. 589 | num_images_per_prompt (`int`, *optional*, defaults to 1): 590 | The number of images to generate per prompt. 
591 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 592 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 593 | to make generation deterministic. 594 | latents (`torch.FloatTensor`, *optional*): 595 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 596 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 597 | tensor will ge generated by sampling using the supplied random `generator`. 598 | prompt_embeds (`torch.FloatTensor`, *optional*): 599 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 600 | provided, text embeddings will be generated from `prompt` input argument. 601 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 602 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 603 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 604 | output_type (`str`, *optional*, defaults to `"pil"`): 605 | The output format of the generate image. Choose between 606 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 607 | return_dict (`bool`, *optional*, defaults to `True`): 608 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 609 | joint_attention_kwargs (`dict`, *optional*): 610 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 611 | `self.processor` in 612 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 613 | callback_on_step_end (`Callable`, *optional*): 614 | A function that calls at the end of each denoising steps during the inference. The function is called 615 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 616 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 617 | `callback_on_step_end_tensor_inputs`. 618 | callback_on_step_end_tensor_inputs (`List`, *optional*): 619 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 620 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 621 | `._callback_tensor_inputs` attribute of your pipeline class. 622 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 623 | 624 | Examples: 625 | 626 | Returns: 627 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 628 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 629 | images. 630 | """ 631 | 632 | height = height or self.default_sample_size * self.vae_scale_factor 633 | width = width or self.default_sample_size * self.vae_scale_factor 634 | 635 | # 1. Check inputs. 
Raise error if not correct 636 | self.check_inputs( 637 | prompt, 638 | prompt_2, 639 | height, 640 | width, 641 | prompt_embeds=prompt_embeds, 642 | pooled_prompt_embeds=pooled_prompt_embeds, 643 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 644 | max_sequence_length=max_sequence_length, 645 | ) 646 | 647 | self._guidance_scale = guidance_scale 648 | self._joint_attention_kwargs = joint_attention_kwargs 649 | self._interrupt = False 650 | 651 | # 2. Define call parameters 652 | if prompt is not None and isinstance(prompt, str): 653 | batch_size = 1 654 | elif prompt is not None and isinstance(prompt, list): 655 | batch_size = len(prompt) 656 | else: 657 | batch_size = prompt_embeds.shape[0] 658 | 659 | device = self._execution_device 660 | 661 | lora_scale = ( 662 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 663 | ) 664 | ( 665 | prompt_embeds, 666 | pooled_prompt_embeds, 667 | text_ids, 668 | ) = self.encode_prompt( 669 | prompt=prompt, 670 | prompt_2=prompt_2, 671 | prompt_embeds=prompt_embeds, 672 | pooled_prompt_embeds=pooled_prompt_embeds, 673 | device=device, 674 | num_images_per_prompt=num_images_per_prompt, 675 | max_sequence_length=max_sequence_length, 676 | lora_scale=lora_scale, 677 | ) 678 | ( 679 | neg_prompt_embeds, 680 | neg_pooled_prompt_embeds, 681 | neg_text_ids, 682 | ) = self.encode_prompt( 683 | prompt="", 684 | prompt_2=None, 685 | device=device, 686 | num_images_per_prompt=num_images_per_prompt, 687 | max_sequence_length=max_sequence_length, 688 | lora_scale=lora_scale, 689 | ) 690 | 691 | # 4. Prepare latent variables 692 | num_channels_latents = self.transformer.config.in_channels // 4 693 | latents, latent_image_ids = self.prepare_latents( 694 | batch_size * num_images_per_prompt, 695 | num_layers, 696 | num_channels_latents, 697 | height, 698 | width, 699 | validation_box, 700 | prompt_embeds.dtype, 701 | device, 702 | generator, 703 | latents, 704 | ) 705 | 706 | # 5. Prepare timesteps 707 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 708 | image_seq_len = latent_image_ids.shape[0] 709 | mu = calculate_shift( 710 | image_seq_len, 711 | self.scheduler.config.base_image_seq_len, 712 | self.scheduler.config.max_image_seq_len, 713 | self.scheduler.config.base_shift, 714 | self.scheduler.config.max_shift, 715 | ) 716 | timesteps, num_inference_steps = retrieve_timesteps( 717 | self.scheduler, 718 | num_inference_steps, 719 | device, 720 | timesteps, 721 | sigmas, 722 | mu=mu, 723 | ) 724 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 725 | self._num_timesteps = len(timesteps) 726 | 727 | # handle guidance 728 | if self.transformer.config.guidance_embeds: 729 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 730 | guidance = guidance.expand(latents.shape[0]) 731 | else: 732 | guidance = None 733 | 734 | # 6. 
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=neg_pooled_prompt_embeds,
                    encoder_hidden_states=neg_prompt_embeds,
                    txt_ids=neg_text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # create a grey latent
        bs, n_layers, channel_latent, height, width = latents.shape

        pixel_grey = torch.zeros(size=(bs * n_layers, 3, height * 8, width * 8), device=latents.device, dtype=latents.dtype)
        latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
        latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        latent_grey = latent_grey.view(bs, n_layers, channel_latent, height, width)  # [bs, n_layers, c_latent, h, w]

        # fill in the latents: keep the denoised latent inside each layer's box, grey elsewhere
        for layer_idx in range(latent_grey.shape[1]):
            if validation_box[layer_idx] is None:
                continue
            x1, y1, x2, y2 = validation_box[layer_idx]
            x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
            latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
        latents = latent_grey

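        # Illustrative note (the canvas size is an assumed example): the VAE downsamples by a
        # factor of 8, so a layer box given in pixel coordinates on a 1024 x 1024 canvas, e.g.
        # (128, 128, 896, 896), selects the latent region (16, 16, 112, 112) in the loop above;
        # everything outside each layer's box keeps the encoded grey background latent.
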
        if output_type == "latent":
            image = latents

        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            latents = latents.reshape(bs * n_layers, channel_latent, height, width)
            latents_segs = torch.split(latents, 16, dim=0)  # split latents into chunks of 16 to avoid odd purple output
            image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
            image = torch.cat(image_segs, dim=0)
            if transparent_decoder is not None:
                transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)

                decoded_fg, decoded_alpha = transparent_decoder(latents, [validation_box])
                decoded_alpha = (decoded_alpha + 1.0) / 2.0
                decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1)

                decoded_fg = (decoded_fg + 1.0) / 2.0
                decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)

                vis_list = None
                # concatenate foreground and alpha along the channel axis to form per-layer RGBA in [0, 1]
                png = torch.cat([decoded_fg, decoded_alpha], dim=3)
                result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
            else:
                result_list, vis_list = None, None
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, result_list, vis_list, latents)

        return FluxPipelineOutput(images=image), result_list, vis_list, latents
--------------------------------------------------------------------------------
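# Illustrative usage sketch (not part of the repository): how the extra return values of the
# call above might be consumed. `pipeline`, `boxes`, `transp_vae`, and the output paths are
# hypothetical names; `result_list` is the batch of per-layer H x W x 4 uint8 RGBA arrays
# returned alongside the FluxPipelineOutput when `transparent_decoder` is supplied.
#
#   from PIL import Image
#
#   output, result_list, vis_list, latents = pipeline(
#       prompt="a multi-layer illustration",
#       validation_box=boxes,            # one (x1, y1, x2, y2) box or None per layer
#       transparent_decoder=transp_vae,  # the custom transparent VAE decoder
#   )
#   for idx, rgba in enumerate(result_list):
#       Image.fromarray(rgba, mode="RGBA").save(f"layer_{idx}.png")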