├── requirements.txt
├── LICENSE
├── utils.py
├── .gitignore
├── readme.md
├── demos_opacity.py
├── demos_occlusion.py
└── larender_sdxl.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | diffusers
4 | spacy
5 | xformers
6 | transformers==4.49
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Xiaohang Zhan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import random
4 | import torch
5 |
6 |
7 | class DependencyParsing():
8 | def __init__(self):
9 | import spacy
10 | self.nlp = spacy.load("en_core_web_sm")
11 |         # if spaCy cannot find "en_core_web_sm", try loading it from an explicit path as below:
12 | # self.nlp = spacy.load("/usr/local/lib/python3.10/site-packages/en_core_web_sm/en_core_web_sm-3.8.0") # the path is an example
13 |
14 | def parse(self, text, never_empty):
15 | doc = self.nlp(text)
16 | indices = []
17 | subjects = []
18 | for token in doc:
19 | # noun of a sentence or noun in a phrase
20 | if (token.dep_ in ["nsubj", "nsubjpass"]) or (token.dep_ == "ROOT" and token.pos_ in ["NOUN", "PROPN"]):
21 | indices.append(token.i)
22 | subjects.append(token.text)
23 | if never_empty and len(indices) == 0:
24 | indices.append(0)
25 | subjects.append(doc[0].text)
26 | return indices, subjects
27 |
28 |
29 | def torch_fix_seed(seed=42):
30 | # Python random
31 | random.seed(seed)
32 | # Numpy
33 | np.random.seed(seed)
34 | # Pytorch
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed(seed)
37 | torch.backends.cudnn.deterministic = True
38 |     torch.use_deterministic_algorithms(True, warn_only=True)  # warn_only avoids hard errors for ops without deterministic implementations
39 |
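40 | 
41 | # Usage sketch (illustrative only; actual parses depend on the spaCy model):
42 | #     DP = DependencyParsing()
43 | #     indices, subjects = DP.parse("a ginger cat sitting", never_empty=True)
44 | #     # 'cat' is typically detected as the subject; the demo scripts shift the
45 | #     # returned word indices by 1 to account for the tokenizer's start token.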
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # input data, saved log, checkpoints
104 | input/
105 | saved/
106 | datasets/
107 |
108 | # editor, os cache directory
109 | .vscode/
110 | .idea/
111 | __MACOSX/
112 |
113 | .DS_Store
114 | results/
115 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## LaRender: Training-Free Occlusion Control in Image Generation via Latent Rendering
2 |
3 | This repository is the official implementation of our paper [LaRender](https://xiaohangzhan.github.io/projects/larender/), accepted to ICCV 2025 as an oral presentation.
4 |
5 | > [**LaRender: Training-Free Occlusion Control in Image Generation via Latent Rendering**](https://xiaohangzhan.github.io/projects/larender/)
6 |
7 | > [Xiaohang Zhan](https://xiaohangzhan.github.io/),
8 | > [Dingming Liu](https://scholar.google.com/citations?user=C0ubzjQAAAAJ)
9 | > **Tencent**
10 |
11 | > Project page: https://xiaohangzhan.github.io/projects/larender/
12 |
13 | > Paper: https://arxiv.org/pdf/2508.07647
14 |
15 | ## Introduction
16 |
17 | This paper proposes a plug-and-play, training-free method to control object occlusion relationships and the strength of visual effects in pre-trained text-to-image models, by simply replacing the cross-attention layers with Latent Rendering layers.
18 |
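For intuition, below is a minimal NumPy sketch of the latent rendering composition implemented in `larender_sdxl.py`: per-object opacities are converted to densities via `-log(1 - alpha)`, and each object's features are weighted by the accumulated transmittance of the objects in front of it (objects later in the list are composited in front). The sketch is illustrative only; the actual method applies this composition to per-object cross-attention outputs inside the SDXL UNet.

```python
import numpy as np

def latent_render(R, M, alpha):
    """R: per-object feature maps (H, W, C); M: per-object masks (H, W) in [0, 1];
    alpha: per-object opacities in [0, 1). Later objects occlude earlier ones."""
    sigma = [-np.log(1.0 - a) for a in alpha]            # opacity -> density
    out = np.zeros_like(R[0])
    S = np.zeros(M[0].shape + (1,))
    for i in range(len(R)):
        T = np.ones_like(M[i])                           # accumulated transmittance of object i
        for j in range(i + 1, len(R)):                   # attenuated by every object in front
            T *= np.exp(-sigma[j] * M[j])
        contrib = (T * (1.0 - np.exp(-sigma[i])) * M[i])[..., None]
        out += contrib * R[i]
        S += contrib
    return out / np.clip(S, 1e-6, None)                  # normalize, as in the repository code

# toy check: two overlapping squares in an 8x8, 4-channel "latent"
H, W, C = 8, 8, 4
R = [np.random.rand(H, W, C) for _ in range(2)]
M = [np.zeros((H, W)), np.zeros((H, W))]
M[0][1:6, 1:6] = 1.0   # back object
M[1][3:8, 3:8] = 1.0   # front object
rendered = latent_render(R, M, alpha=[0.8, 0.8])   # shape (8, 8, 4)
```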
19 |
20 | ## Get Started
21 |
22 | ### Environments
23 |
24 | ```bash
25 | git clone git@github.com:XiaohangZhan/larender.git
26 | cd larender
27 | conda create -n larender python==3.10
28 | conda activate larender
29 | pip install -r requirements.txt
30 | python -m spacy download en_core_web_sm
31 | ```
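
To confirm that the spaCy model was installed correctly, an optional quick check (it raises an `OSError` if the model is missing):

```python
import spacy
spacy.load("en_core_web_sm")
```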
32 |
33 | ### Run
34 |
35 | 1. Object occlusion control demos.
36 | ```shell
37 | python demos_occlusion.py
38 | ```
39 |
40 | 2. Visual effect strength control demos.
41 | ```shell
42 | python demos_opacity.py
43 | ```
44 |
45 | Results are saved in ``results/``.
46 |
47 | ### Tips
48 |
49 | 1. If multiple objects have similar semantics or appearance (e.g., girl and boy, cat and dog), or are positioned close together, the results might suffer from subject neglect or mixing, as discussed in this paper: https://arxiv.org/pdf/2411.18301 .
50 | Simple workarounds include:
51 | - **Modify the prompt** to describe the missing object as precisely as possible.
52 | For example:
53 | ❌ *"a cat"* → ✅ *"a white cat standing"*
54 |
55 | - **Mention the missing object** in the prompt of the other object.
56 | For example:
57 | ❌ `["a girl in yellow dress", "a boy in blue T-shirt"]`
58 | ✅ `["a girl in yellow dress next to a boy", "a boy in blue T-shirt"]`
59 |
60 | 2. Unreasonable bounding boxes will affect the occlusion accuracy; if the occlusion is wrong, please adjust the bounding boxes (the expected format is illustrated below). Note that the bounding boxes are rough positioning hints; we are not pursuing accurate bounding-box control in this paper.
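
For reference, the bounding boxes (``locations``) used in the demo scripts follow the ``[y1, y2, x1, x2]`` format (top, bottom, left, right), with each value normalized to the 0~1 range; objects later in the list are composited in front (see `larender_sdxl.py`). A minimal example taken from `demos_occlusion.py`:

```python
objects = ['a cyan vase', 'a yellow clock']
locations = [[0.1, 0.8, 0.2, 0.5],   # vase:  [y1, y2, x1, x2], each in 0~1
             [0.4, 0.8, 0.3, 0.7]]   # clock: listed later, so rendered in front
```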
--------------------------------------------------------------------------------
/demos_opacity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PIL import ImageDraw, ImageFont
4 | from utils import DependencyParsing
5 | from larender_sdxl import LaRender
6 |
7 |
8 | if __name__ == "__main__":
9 | # basic params
10 | negative_prompt = ''
11 | pretrain = 'IterComp' # IterComp or GLIGEN
12 | num_images_per_prompt = 10
13 | height = 1024
14 | width = 1024
15 | save_bbox_image = False
16 | seed = 0
17 |
18 | # dependency parsing params
19 | use_dependency_parsing = True
20 | ignore_dp_failure = True
21 |
22 | # larender params
23 | transmittance_use_bbox = True
24 | transmittance_use_attn_map = True
25 | density_scheduler = 'inverse_proportional' # other choices: unchanged, opaque
26 |
27 | #### Reproducible experiments, uncomment each block to proceed. ####
28 |
29 | dual_elements = False
30 |
31 | objects = ['the entrance of a convenience store', 'a clear glass door']
32 | locations = [[0.1, 0.9, 0.1, 0.9], [0.3, 0.8, 0.25, 0.75]]
33 | opacity = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
34 |
35 | # objects = ['sky', 'forest', 'lake with crystal clear water', 'reflection']
36 | # locations = [[0, 0.5, 0, 1], [0.25, 0.5, 0, 1],
37 | # [0.5, 1, 0, 1], [0.5, 1, 0, 1]]
38 | # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
39 |
40 | # objects = ['mountains with forest', 'fog']
41 | # locations = [[0.2, 1, 0, 1], [0, 0.8, 0, 1]]
42 | # opacity = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
43 |
44 | # objects = ['a night', 'a cathedral', 'lush palm trees']
45 | # locations = [[0, 1, 0, 1], [0.2, 0.8, 0.2, 0.8], [0.4, 0.9, 0, 1]]
46 | # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
47 |
48 | # objects = ['seaside reefs and turbulent waves at dusk', 'long exposure effects']
49 | # locations = [[0, 1, 0, 1], [0, 1, 0, 1]]
50 | # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
51 |
52 | # dual_elements = True # dual elements for this rain & sunlight example
53 | # objects = ['city view with street and cars', 'rain', 'sunlight']
54 | # locations = [[0, 1, 0, 1], [0, 1, 0, 1], [0, 0.4, 0.6, 1]]
55 | # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
56 |
57 | base_alpha = [0.8 for _ in objects]
58 | alpha_list = []
59 | if dual_elements:
60 | for opa1 in opacity:
61 | for opa2 in opacity:
62 | alpha_list.append(base_alpha[:-2] + [opa1, opa2])
63 | else:
64 | for opa in opacity:
65 | alpha_list.append(base_alpha[:-1] + [opa])
66 |
67 | ################ End of experimental settings ################
68 |
69 | assert all(
70 | 0 <= a < 1 for alpha in alpha_list for a in alpha), "alpha values must be in [0, 1)"
71 | # parse the indices of subject tokens
72 | if use_dependency_parsing:
73 | DP = DependencyParsing()
74 | indices, subjects = zip(
75 | *[DP.parse(obj, ignore_dp_failure) for obj in objects])
76 | assert all(len(ind) > 0 for ind in indices), \
77 | f"Dependency Parsing error: subject indices contain empty cases, please rephrase your prompts or set ignore_dp_failure=True, indices: {indices}"
78 | indices = [[ind + 1 for ind in ind_list]
79 |                            for ind_list in indices]  # shift by 1: the tokenizer prepends a start token at index 0
80 | subjects = ['+'.join(sub) for sub in subjects]
81 | else:
82 | indices = [[1], [1], [1], [1]] # manually assigned if not using DP
83 | # manually assigned if not using DP
84 | subjects = ['sky', 'forest', 'lake', 'reflection']
85 |
86 | save_name = f"larender_{pretrain}_{'-'.join(subjects)}"
87 |
88 | # init LaRender
89 | LR = LaRender(pretrain=pretrain,
90 | transmittance_use_bbox=transmittance_use_bbox,
91 | transmittance_use_attn_map=transmittance_use_attn_map,
92 | density_scheduler=density_scheduler)
93 |
94 | # run
95 | start_time = time.time()
96 | for alpha in alpha_list:
97 | images = LR.run(
98 | objects,
99 | locations,
100 | alpha,
101 | indices,
102 | num_images_per_prompt,
103 | height,
104 | width,
105 | negative_prompt,
106 | seed)
107 | print("Elapsed time:", time.time() - start_time, "seconds")
108 | os.makedirs(f"results/{save_name}", exist_ok=True)
109 |
110 | # save results
111 | for i, img in enumerate(images):
112 | if dual_elements:
113 | save_fn = f"results/{save_name}/{save_name}_{i:02d}_{alpha[-2]:.2f}_{alpha[-1]:.2f}.jpg"
114 | else:
115 | save_fn = f"results/{save_name}/{save_name}_{i:02d}_{alpha[-1]:.2f}.jpg"
116 | print(f'Save at: {save_fn}')
117 | img.save(save_fn)
118 |
119 | if save_bbox_image:
120 | draw = ImageDraw.Draw(img)
121 | width, height = img.width, img.height
122 | for obj, loc in zip(objects, locations):
123 | draw.rectangle((loc[2] * width, loc[0] * height, loc[3]
124 | * width, loc[1] * height), outline='red', width=2)
125 | font = ImageFont.load_default(size=32)
126 | text_position = (loc[2] * width + 1, loc[0] * height + 1)
127 | draw.text(text_position, obj,
128 | font=font, fill="red") # Text
129 | if dual_elements:
130 | img.save(
131 | f"results/{save_name}/{save_name}_box_{i:02d}_{alpha[-2]:.2f}_{alpha[-1]:.2f}.jpg")
132 | else:
133 | img.save(
134 | f"results/{save_name}/{save_name}_box_{i:02d}_{alpha[-1]:.2f}.jpg")
135 |
--------------------------------------------------------------------------------
/demos_occlusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PIL import ImageDraw, ImageFont
4 | from utils import DependencyParsing
5 | from larender_sdxl import LaRender
6 |
7 |
8 | if __name__ == "__main__":
9 | # basic params
10 | negative_prompt = ''
11 | pretrain = 'IterComp' # IterComp or GLIGEN
12 | num_images_per_prompt = 10 # set to 1 for speed test
13 | height = 1024
14 | width = 1024
15 | save_bbox_image = True
16 | seed = 5
17 |
18 | # dependency parsing params
19 | use_dependency_parsing = True
20 | ignore_dp_failure = True
21 |
22 | # larender params
23 | transmittance_use_bbox = True
24 | transmittance_use_attn_map = True
25 | density_scheduler = 'inverse_proportional' # other choices: unchanged, opaque
26 |
27 | #### Reproducible experiments, uncomment each block to proceed. ####
28 |
29 |     # Multiple "next to" phrases are used here because the cat easily gets mixed into the giraffe and the piano easily gets lost. Never use "behind" or "in front of", to avoid introducing occlusion hints in the prompt.
30 | objects = ["a large empty living room", "a brown piano", "a giraffe is standing next to a piano and a cat",
31 | "a sofa", "a ginger cat sitting", "a blue teddy bear next to a cat"]
32 | locations = [[0, 1, 0, 1], [0.4, 0.7, 0.2, 0.9], [0, 0.8, 0, 0.5], [0.6, 0.9, 0.05, 0.95],
33 | [0.5, 0.75, 0.5, 0.7], [0.5, 0.9, 0.7, 0.95]] # mode: [y1, y2, x1, x2], range: 0~1
34 |
35 | # objects = ["forest", 'a white cat standing on the ground next to a branch',
36 | # 'a brown dog sitting on the ground next to a branch', 'a long branch']
37 | # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4],
38 | # [0.3, 0.8, 0.5, 0.8], [0.5, 0.6, 0, 1]]
39 |
40 | # objects = ["forest", 'a white cat standing on the ground next to a branch',
41 | # 'a long branch', 'a brown dog sitting next to a branch']
42 | # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4],
43 | # [0.5, 0.6, 0, 1], [0.3, 0.8, 0.5, 0.8]]
44 |
45 | # objects = ["forest", 'a brown dog sitting',
46 | # 'a long branch', 'a white cat standing']
47 | # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.5, 0.8],
48 | # [0.5, 0.6, 0, 1], [0.3, 0.8, 0.1, 0.4]]
49 |
50 | # objects = ["forest", 'a long branch', 'a brown dog sitting',
51 | # 'a white cat standing']
52 | # locations = [[0, 1, 0, 1],[0.5, 0.6, 0, 1], [0.3, 0.8, 0.55, 0.8],
53 | # [0.3, 0.8, 0.2, 0.4]]
54 |
55 | # objects = ['a giraffe', 'an airplane on the ground']
56 | # locations = [[0.2, 0.8, 0.4, 0.6], [0.5, 0.8, 0, 1]]
57 |
58 | # objects = ['a house', 'lawn', 'a giant moon']
59 | # locations = [[0.3, 0.7, 0.1, 0.9], [0.7, 1, 0, 1], [0.55, 0.8, 0.2, 0.45]]
60 |
61 | # objects = ['a cyan vase', 'a yellow clock']
62 | # locations = [[0.1, 0.8, 0.2, 0.5], [0.4, 0.8, 0.3, 0.7]] # try pretrain = 'GLIGEN' if bbox not accurate
63 |
64 | # objects = ['a yellow clock', 'a cyan vase']
65 | # locations = [[0.4, 0.8, 0.3, 0.7], [0.1, 0.8, 0.2, 0.5]]
66 |
67 | # objects = ['a girl', 'a refrigerator']
68 | # locations = [[0.2, 0.9, 0.4, 0.7], [0.3, 0.9, 0.4, 0.7]]
69 |
70 | # objects = ['a refrigerator', 'a girl']
71 | # locations = [[0.2, 0.9, 0.4, 0.7], [0.2, 0.9, 0.3, 0.6]]
72 |
73 | # objects = ['fence', 'a man', 'a cow next to a man']
74 | # locations = [[0.4, 1, 0, 1], [0.1, 1, 0.1, 0.5], [0.2, 1, 0.3, 0.8]]
75 |
76 | # objects = ['a cow next to a man', 'a man wearing white shirt', 'fence']
77 | # locations = [[0.2, 1, 0.3, 0.8], [0.1, 1, 0.1, 0.5], [0.4, 1, 0, 1]]
78 |
79 | # objects = ['park', 'trees', 'a fountain', 'a lion statue', 'bush']
80 | # locations = [[0, 1, 0, 1], [0, 0.7, 0, 1], [0, 0.7, 0.3, 0.7], [0.4, 0.7, 0.5, 0.8], [0.6, 0.9, 0.1, 0.9]]
81 |
82 | # objects = ['a brown teddy bear', 'a computer']
83 | # locations = [[0.4, 0.8, 0.2, 0.5], [0.3, 0.8, 0.3, 0.8]]
84 |
85 | # objects = ['a computer', 'a brown teddy bear']
86 | # locations = [[0.3, 0.8, 0.3, 0.8], [0.4, 0.8, 0.2, 0.5]]
87 |
88 | # objects = ['a piano', 'a giant Yamaha guitar']
89 | # locations = [[0.5, 0.9, 0.2, 0.8], [0.1, 0.7, 0.4, 0.6]]
90 |
91 | # objects = ['a giant Yamaha guitar', 'a piano']
92 | # locations = [[0.1, 0.7, 0.4, 0.6], [0.5, 0.9, 0.2, 0.8]]
93 |
94 | # objects = ['a boat', 'a bear']
95 | # locations = [[0.4, 0.8, 0.1, 0.9], [0.2, 0.8, 0.3, 0.6]]
96 |
97 | # objects = ['a bear', 'a boat']
98 | # locations = [[0.2, 0.8, 0.3, 0.6], [0.4, 0.8, 0.1, 0.9]]
99 |
100 | # objects = ['a girl in yellow dress next to a boy', 'a boy in blue T-shirt']
101 | # locations = [[0.2, 1, 0.2, 0.6], [0.1, 1, 0.4, 0.8]]
102 |
103 | # objects = ['a boy in blue T-shirt next to a girl', 'a girl in yellow dress']
104 | # locations = [[0.1, 1, 0.4, 0.8], [0.2, 1, 0.2, 0.6]]
105 |
106 |     # for non-transparent objects, alpha can always be set to 0.8
107 | alpha = [0.8 for _ in objects]
108 |
109 | ################ End of experimental settings ################
110 |
111 | assert all(0 <= a < 1 for a in alpha), "alpha values must be in [0, 1)"
112 | # parse the indices of subject tokens
113 | if use_dependency_parsing:
114 | DP = DependencyParsing()
115 | indices, subjects = zip(
116 | *[DP.parse(obj, ignore_dp_failure) for obj in objects])
117 | assert all(len(ind) > 0 for ind in indices), \
118 | f"Dependency Parsing error: subject indices contain empty cases, please rephrase your prompts or set ignore_dp_failure=True, indices: {indices}"
119 | indices = [[ind + 1 for ind in ind_list]
120 |                            for ind_list in indices]  # shift by 1: the tokenizer prepends a start token at index 0
121 | subjects = ['+'.join(sub) for sub in subjects]
122 | else:
123 | indices = [[1], [1], [1], [1]] # manually assigned if not using DP
124 | # manually assigned if not using DP
125 | subjects = ['sky', 'forest', 'lake', 'reflection']
126 |
127 | save_name = f"larender_{pretrain}_{'-'.join(subjects)}"
128 |
129 | # init LaRender
130 | LR = LaRender(pretrain=pretrain,
131 | transmittance_use_bbox=transmittance_use_bbox,
132 | transmittance_use_attn_map=transmittance_use_attn_map,
133 | density_scheduler=density_scheduler)
134 |
135 | # run
136 | start_time = time.time()
137 | images = LR.run(
138 | objects,
139 | locations,
140 | alpha,
141 | indices,
142 | num_images_per_prompt,
143 | height,
144 | width,
145 | negative_prompt,
146 | seed)
147 | print("Elapsed time:", time.time() - start_time, "seconds")
148 | os.makedirs(f"results/{save_name}", exist_ok=True)
149 |
150 | # save results
151 | for i, img in enumerate(images):
152 | save_fn = f"results/{save_name}/{save_name}_{i:02d}.jpg"
153 | print(f'Save at: {save_fn}')
154 | img.save(save_fn)
155 |
156 | if save_bbox_image:
157 | draw = ImageDraw.Draw(img)
158 | width, height = img.width, img.height
159 | for obj, loc in zip(objects, locations):
160 | draw.rectangle((loc[2] * width, loc[0] * height, loc[3]
161 | * width, loc[1] * height), outline='red', width=2)
162 | font = ImageFont.load_default(size=32)
163 | text_position = (loc[2] * width + 1, loc[0] * height + 1)
164 | draw.text(text_position, obj, font=font, fill="red") # Text
165 | img.save(f"results/{save_name}/{save_name}_box_{i:02d}.jpg")
166 |
--------------------------------------------------------------------------------
/larender_sdxl.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from diffusers import DiffusionPipeline
5 | from diffusers.schedulers import DPMSolverMultistepScheduler
6 | from utils import torch_fix_seed
7 |
8 |
9 | class LaRender():
10 | def __init__(self,
11 | pretrain='IterComp',
12 | transmittance_use_bbox=True,
13 | transmittance_use_attn_map=True,
14 | density_scheduler='unchanged'):
15 |         '''
16 |         Args:
17 |             pretrain (str, *optional*):
18 |                 Pretrained model type; supports 'IterComp' or 'GLIGEN'. 'IterComp' is a T2I model without location control, while 'GLIGEN' is a pre-trained location-control model.
19 |             transmittance_use_bbox (bool, *optional*):
20 |                 Whether to use bounding boxes when approximating transmittance maps. If set to False, objects in the back usually disappear.
21 |             transmittance_use_attn_map (bool, *optional*):
22 |                 Whether to use attention maps when approximating transmittance maps.
23 |             density_scheduler (str, *optional*):
24 |                 The type of density scheduler; supports 'opaque', 'unchanged', and 'inverse_proportional'.
25 | 
26 |         '''
27 | assert pretrain in ['IterComp', 'GLIGEN']
28 | assert density_scheduler in [
29 | 'opaque', 'unchanged', 'inverse_proportional']
30 | self.pretrain = pretrain
31 | self.transmittance_use_bbox = transmittance_use_bbox
32 | self.density_scheduler = density_scheduler
33 |
34 | if pretrain == 'IterComp':
35 | self.pipe = DiffusionPipeline.from_pretrained(
36 | "comin/IterComp",
37 | torch_dtype=torch.float16,
38 | use_safetensors=True)
39 | else:
40 | self.pipe = DiffusionPipeline.from_pretrained(
41 | "jiuntian/gligen-xl-1024", trust_remote_code=True, torch_dtype=torch.float16)
42 | self.pipe.to("cuda")
43 | self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
44 | self.pipe.scheduler.config, use_karras_sigmas=True)
45 | self.pipe.enable_xformers_memory_efficient_attention()
46 |
47 | self.isvanilla = True
48 | self.transmittance_use_attn_map = transmittance_use_attn_map
49 |
50 | for name, module in self.pipe.unet.named_modules():
51 | if "attn2" in name and module.__class__.__name__ == "Attention":
52 | module.forward = self.create_larender_forward(module)
53 | module.name = name
54 |
55 | def _forward_cross_attention(self, module, hidden_states, encoder_hidden_states):
56 | query = module.to_q(hidden_states)
57 | key = module.to_k(encoder_hidden_states)
58 | value = module.to_v(encoder_hidden_states)
59 | query = module.head_to_batch_dim(query)
60 | key = module.head_to_batch_dim(key)
61 | value = module.head_to_batch_dim(value)
62 | attention_probs = module.get_attention_scores(query, key)
63 | hidden_states = torch.bmm(attention_probs, value)
64 | hidden_states = module.batch_to_head_dim(hidden_states)
65 | hidden_states = module.to_out[0](hidden_states)
66 | hidden_states = module.to_out[1](hidden_states)
67 | return hidden_states, attention_probs
68 |
69 | def _bboxes_to_mask(self, bboxes, h, w, dtype, device):
70 | '''
71 | Args:
72 | bboxes (List[List[y1, y2, x1, x2]]):
73 | The bounding boxes in y1, y2, x1, x2 format.
74 | '''
75 | masks = []
76 | for bbox in bboxes:
77 | M = torch.zeros((h, w), dtype=dtype, device=device)
78 | upper = int(h * bbox[0])
79 | lower = int(h * bbox[1])
80 | left = int(w * bbox[2])
81 | right = int(w * bbox[3])
82 | assert lower > upper and right > left, f"Invalid bounding box: {bbox}"
83 | M[upper:lower, left:right] = 1
84 | masks.append(M)
85 | return masks
86 |
87 | def create_larender_forward(self, module):
88 | def forward(hidden_states, encoder_hidden_states=None, **kwargs):
89 | return self.larender_core_process(module, hidden_states, encoder_hidden_states)
90 | return forward
91 |
92 | def _prompt_embedding(self, objects, negative_prompt):
93 | prompt_embeds_list = []
94 | negative_prompt_embeds_list = []
95 | pooled_prompt_embeds_list = []
96 | negative_pooled_prompt_embeds_list = []
97 | for p in objects:
98 | (
99 | prompt_embeds,
100 | negative_prompt_embeds,
101 | pooled_prompt_embeds,
102 | negative_pooled_prompt_embeds,
103 | ) = self.pipe.encode_prompt(
104 | prompt=p,
105 | do_classifier_free_guidance=True,
106 | negative_prompt=negative_prompt,
107 | )
108 | prompt_embeds_list.append(prompt_embeds)
109 | negative_prompt_embeds_list.append(negative_prompt_embeds)
110 | pooled_prompt_embeds_list.append(pooled_prompt_embeds)
111 | negative_pooled_prompt_embeds_list.append(
112 | negative_pooled_prompt_embeds)
113 |
114 | prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
115 | negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=1)
116 |         pooled_prompt_embeds = sum(
117 |             pooled_prompt_embeds_list) / len(pooled_prompt_embeds_list)
118 |         negative_pooled_prompt_embeds = sum(
119 |             negative_pooled_prompt_embeds_list) / len(negative_pooled_prompt_embeds_list)
120 |
121 | (
122 | _,
123 | _,
124 | pooled_prompt_embeds,
125 | negative_pooled_prompt_embeds,
126 | ) = self.pipe.encode_prompt(
127 | prompt=', '.join(objects),
128 | do_classifier_free_guidance=True,
129 | negative_prompt=negative_prompt,
130 | )
131 | return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
132 |
133 | def larender_core_process(self, module, hidden_states, encoder_hidden_states):
134 | # preparation
135 | locations = self.larender_inputs['locations']
136 | indices = self.larender_inputs['indices']
137 | densities = self.larender_inputs['densities']
138 | height = self.height
139 | width = self.width
140 | latent_h = int(math.sqrt(hidden_states.size()[1] * height / width))
141 | assert hidden_states.size()[1] % latent_h == 0, \
142 | "Cannot infer latent_h and latent_w, please check the code."
143 | latent_w = int(hidden_states.size()[1] / latent_h)
144 | num_inference_steps = len(self.pipe.scheduler.timesteps)
145 | assert hasattr(self.pipe.scheduler, 'step_index')
146 | current_count = self.pipe.scheduler.step_index if self.pipe.scheduler.step_index is not None else 0
147 | self.current_count = current_count # only used in visualization
148 |         if self.density_scheduler == 'inverse_proportional':
149 |             # densities decay from num_inference_steps * base at the first step to 1 * base at the last step
150 |             density_multiplier = num_inference_steps / (current_count + 1)
150 | densities_multiplied = [x * density_multiplier for x in densities]
151 | elif self.density_scheduler == 'opaque':
152 | densities_multiplied = [x * num_inference_steps for x in densities]
153 | else:
154 | densities_multiplied = densities
155 |
156 | token_max_length = self.pipe.tokenizer.model_max_length
157 | assert encoder_hidden_states.shape[1] % token_max_length == 0
158 |
159 | # object-wise cross-attention and transmittance map
160 | R = []
161 | dtype = hidden_states.dtype
162 | device = hidden_states.device
163 | if self.transmittance_use_bbox:
164 | M = self._bboxes_to_mask(
165 | locations, latent_h, latent_w, dtype, device)
166 | else:
167 | M = [torch.ones((latent_h, latent_w), dtype=dtype,
168 | device=device) for _ in locations]
169 | for i in range(self.num_objects):
170 | context = encoder_hidden_states[:, i *
171 | token_max_length: (i+1) * token_max_length, :]
172 | R_i, attn_probs = \
173 | self._forward_cross_attention(module, hidden_states, context)
174 | R.append(R_i.reshape(R_i.shape[0],
175 | latent_h, latent_w, R_i.shape[2]))
176 | if self.transmittance_use_attn_map:
177 | attn_map = attn_probs.reshape(
178 | attn_probs.shape[0], latent_h, latent_w, -1)
179 | attn_map = attn_map[..., indices[i]].mean(dim=-1).mean(dim=0)
180 | min_val = attn_map.min()
181 | max_val = attn_map.max()
182 | normalized_attn_map = (
183 | attn_map - min_val) / (max_val - min_val)
184 | M[i] *= normalized_attn_map
185 |
186 |         # accumulated transmittance maps (visibility of plane i from the virtual camera)
187 | T = [torch.ones((latent_h, latent_w), dtype=dtype, device=device)
188 | for _ in range(self.num_objects)]
189 | for i in range(self.num_objects - 1, -1, -1): # top to bottom
190 | for j in range(i + 1, self.num_objects):
191 | # TODO
192 | T[i] *= torch.exp(-torch.tensor(densities_multiplied[j]) * M[j])
193 |
194 | # rendering
195 | S = 0
196 | R_out = 0
197 | for i in range(self.num_objects):
198 | contrib = T[i] * (1 - math.exp(-densities_multiplied[i])) * M[i]
199 | contrib = contrib[None, :, :, None] # unsqueeze 0, -1
200 | R_out += contrib * R[i]
201 | S += contrib
202 | S = torch.clamp(S, min=1e-6)
203 | R_out /= S
204 | return R_out.reshape(*hidden_states.shape)
205 |
206 | def run(self,
207 | objects,
208 | locations,
209 | alpha,
210 | indices,
211 | num_images_per_prompt=1,
212 | height=1024,
213 | width=1024,
214 | negative_prompt='',
215 | seed=0):
216 | '''
217 | Args:
218 | objects (List[str]):
219 | List of object prompts.
220 | locations (List[List[float]]):
221 | List of bounding boxes in [top, bottom, left, right] format.
222 |             alpha (List[float]):
223 |                 List of alpha values (opacities) for each object; values must be in [0, 1).
224 | indices (List[List[int]]):
225 | List of indices of subjects in each object prompt.
226 | Output:
227 | images (List[PIL.Image]):
228 | List of generated images.
229 | '''
230 | for a in alpha:
231 |             assert 0 <= a < 1, "alpha must be in the range [0, 1)"
232 | densities = [-np.log(1 - a) for a in alpha]
233 | self.num_objects = len(objects)
234 | assert len(locations) == len(objects) and len(
235 | densities) == len(objects)
236 | self.larender_inputs = dict(
237 | locations=locations, indices=indices, densities=densities)
238 |
239 | self.height = height
240 | self.width = width
241 |
242 | (prompt_embeds,
243 | negative_prompt_embeds,
244 | pooled_prompt_embeds,
245 | negative_pooled_prompt_embeds) = self._prompt_embedding(objects, negative_prompt)
246 |
247 | torch_fix_seed(seed)
248 |
249 | common_kwargs = dict(
250 | num_images_per_prompt=num_images_per_prompt,
251 | height=height,
252 | width=width,
253 | guidance_scale=7.0,
254 | prompt_embeds=prompt_embeds,
255 | negative_prompt_embeds=negative_prompt_embeds,
256 | pooled_prompt_embeds=pooled_prompt_embeds,
257 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds)
258 | if self.pretrain == 'IterComp':
259 | images = self.pipe(
260 | num_inference_steps=25, **common_kwargs).images
261 | else:
262 | images = self.pipe(
263 | num_inference_steps=50,
264 | gligen_scheduled_sampling_beta=0.4,
265 | gligen_boxes=[[loc[2], loc[0], loc[3], loc[1]]
266 | for loc in locations], # y1y2x1x2 -> x1y1x2y2
267 | gligen_phrases=objects,
268 | **common_kwargs).images
269 | return images
270 |
--------------------------------------------------------------------------------
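
For quick reference, a minimal end-to-end call that condenses `demos_occlusion.py` (the prompts, boxes, and seed below are illustrative values taken from the demo examples):

```python
from utils import DependencyParsing
from larender_sdxl import LaRender

objects = ['a house', 'lawn', 'a giant moon']          # later entries render in front
locations = [[0.3, 0.7, 0.1, 0.9],                     # [y1, y2, x1, x2], each in 0~1
             [0.7, 1, 0, 1],
             [0.55, 0.8, 0.2, 0.45]]
alpha = [0.8 for _ in objects]                         # opaque objects

# parse subject-token indices and shift by 1 for the tokenizer's start token
DP = DependencyParsing()
indices, _ = zip(*[DP.parse(obj, True) for obj in objects])
indices = [[i + 1 for i in ind] for ind in indices]

LR = LaRender(pretrain='IterComp', density_scheduler='inverse_proportional')
images = LR.run(objects, locations, alpha, indices,
                num_images_per_prompt=1, height=1024, width=1024,
                negative_prompt='', seed=0)
images[0].save('example.jpg')
```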