├── requirements.txt
├── LICENSE
├── utils.py
├── .gitignore
├── readme.md
├── demos_opacity.py
├── demos_occlusion.py
└── larender_sdxl.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | diffusers
4 | spacy
5 | xformers
6 | transformers==4.49
7 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Xiaohang Zhan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import random
4 | import torch
5 | 
6 | 
7 | class DependencyParsing():
8 |     def __init__(self):
9 |         import spacy
10 |         self.nlp = spacy.load("en_core_web_sm")
11 |         # if spacy reports that it cannot find "en_core_web_sm", try loading it from its install path as below:
12 |         # self.nlp = spacy.load("/usr/local/lib/python3.10/site-packages/en_core_web_sm/en_core_web_sm-3.8.0")  # the path is an example
13 | 
14 |     def parse(self, text, never_empty):
15 |         doc = self.nlp(text)
16 |         indices = []
17 |         subjects = []
18 |         for token in doc:
19 |             # subject noun of a sentence, or head noun of a phrase
20 |             if (token.dep_ in ["nsubj", "nsubjpass"]) or (token.dep_ == "ROOT" and token.pos_ in ["NOUN", "PROPN"]):
21 |                 indices.append(token.i)
22 |                 subjects.append(token.text)
23 |         if never_empty and len(indices) == 0:
24 |             indices.append(0)
25 |             subjects.append(doc[0].text)
26 |         return indices, subjects
27 | 
28 | 
29 | def torch_fix_seed(seed=42):
30 |     # Python random
31 |     random.seed(seed)
32 |     # Numpy
33 |     np.random.seed(seed)
34 |     # PyTorch
35 |     torch.manual_seed(seed)
36 |     torch.cuda.manual_seed(seed)
37 |     torch.backends.cudnn.deterministic = True
38 |     torch.use_deterministic_algorithms(True, warn_only=True)  # request deterministic kernels; warn (not error) when an op has no deterministic implementation
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | 
49 | # Translations
50 | *.mo
51 | *.pot
52 | 
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 | 
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 | 
61 | # Scrapy stuff:
62 | .scrapy
63 | 
64 | # Sphinx documentation
65 | docs/_build/
66 | 
67 | # PyBuilder
68 | target/
69 | 
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 | 
73 | # pyenv
74 | .python-version
75 | 
76 | # celery beat schedule file
77 | celerybeat-schedule
78 | 
79 | # SageMath parsed files
80 | *.sage.py
81 | 
82 | # dotenv
83 | .env
84 | 
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 | 
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 | 
94 | # Rope project settings
95 | .ropeproject
96 | 
97 | # mkdocs documentation
98 | /site
99 | 
100 | # mypy
101 | .mypy_cache/
102 | 
103 | # input data, saved log, checkpoints
104 | input/
105 | saved/
106 | datasets/
107 | 
108 | # editor, os cache directory
109 | .vscode/
110 | .idea/
111 | __MACOSX/
112 | 
113 | .DS_Store
114 | results/
115 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## LaRender: Training-Free Occlusion Control in Image Generation via Latent Rendering
2 | 
3 | This repository is the official implementation of our paper [LaRender](https://xiaohangzhan.github.io/projects/larender/), accepted to ICCV 2025 as an oral presentation.
4 | 
5 | > [**LaRender: Training-Free Occlusion Control in Image Generation via Latent Rendering**](https://xiaohangzhan.github.io/projects/larender/)
6 | 
7 | > [Xiaohang Zhan](https://xiaohangzhan.github.io/),
8 | > [Dingming Liu](https://scholar.google.com/citations?user=C0ubzjQAAAAJ)
9 | > **Tencent**
10 | 
11 | > Project page: https://xiaohangzhan.github.io/projects/larender/
12 | 
13 | > Paper: https://arxiv.org/pdf/2508.07647
14 | 
15 | ## Introduction
16 | 
17 | This paper proposes a plug-and-play, training-free method to control object occlusion relationships and the strength of visual effects in pre-trained text-to-image models, by simply replacing the cross-attention layers with Latent Rendering layers.
18 | 
19 | 
20 | ## Get Started
21 | 
22 | ### Environments
23 | 
24 | ```bash
25 | git clone git@github.com:XiaohangZhan/larender.git
26 | cd larender
27 | conda create -n larender python==3.10
28 | conda activate larender
29 | pip install -r requirements.txt
30 | python -m spacy download en_core_web_sm
31 | ```
32 | 
33 | ### Run
34 | 
35 | 1. Object occlusion control demos.
36 | ```shell
37 | python demos_occlusion.py
38 | ```
39 | 
40 | 2. Visual effect strength control demos.
41 | ```shell
42 | python demos_opacity.py
43 | ```
44 | 
45 | Results are saved in ``results/``.
46 | 
47 | ### Tips
48 | 
49 | 1. If multiple objects have similar semantics or appearance (like girl and boy, cat and dog), or are positioned close together, the results might suffer from subject neglect or mixing, as discussed in this paper: https://arxiv.org/pdf/2411.18301 .
50 | Simple workarounds include:
51 |     - **Modify the prompt** to describe the missing object as precisely as possible.
52 |       For example:
53 |       ❌ *"a cat"* → ✅ *"a white cat standing"*
54 | 
55 |     - **Mention the missing object** in the prompt of the other object.
56 |       For example:
57 |       ❌ `["a girl in yellow dress", "a boy in blue T-shirt"]`
58 |       ✅ `["a girl in yellow dress next to a boy", "a boy in blue T-shirt"]`
59 | 
60 | 2. Unreasonable bounding boxes will affect occlusion accuracy; if the occlusion is wrong, please adjust the bounding boxes. Note that the bounding boxes are rough positioning hints; we do not pursue accurate bounding-box control in this paper.
--------------------------------------------------------------------------------
/demos_opacity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PIL import ImageDraw, ImageFont
4 | from utils import DependencyParsing
5 | from larender_sdxl import LaRender
6 | 
7 | 
8 | if __name__ == "__main__":
9 |     # basic params
10 |     negative_prompt = ''
11 |     pretrain = 'IterComp'  # IterComp or GLIGEN
12 |     num_images_per_prompt = 10
13 |     height = 1024
14 |     width = 1024
15 |     save_bbox_image = False
16 |     seed = 0
17 | 
18 |     # dependency parsing params
19 |     use_dependency_parsing = True
20 |     ignore_dp_failure = True
21 | 
22 |     # larender params
23 |     transmittance_use_bbox = True
24 |     transmittance_use_attn_map = True
25 |     density_scheduler = 'inverse_proportional'  # other choices: unchanged, opaque
26 | 
27 |     #### Reproducible experiments, uncomment each block to proceed. ####
28 | 
29 |     dual_elements = False
30 | 
31 |     objects = ['the entrance of a convenience store', 'a clear glass door']
32 |     locations = [[0.1, 0.9, 0.1, 0.9], [0.3, 0.8, 0.25, 0.75]]
33 |     opacity = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
34 | 
35 |     # objects = ['sky', 'forest', 'lake with crystal clear water', 'reflection']
36 |     # locations = [[0, 0.5, 0, 1], [0.25, 0.5, 0, 1],
37 |     #              [0.5, 1, 0, 1], [0.5, 1, 0, 1]]
38 |     # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
39 | 
40 |     # objects = ['mountains with forest', 'fog']
41 |     # locations = [[0.2, 1, 0, 1], [0, 0.8, 0, 1]]
42 |     # opacity = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
43 | 
44 |     # objects = ['a night', 'a cathedral', 'lush palm trees']
45 |     # locations = [[0, 1, 0, 1], [0.2, 0.8, 0.2, 0.8], [0.4, 0.9, 0, 1]]
46 |     # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
47 | 
48 |     # objects = ['seaside reefs and turbulent waves at dusk', 'long exposure effects']
49 |     # locations = [[0, 1, 0, 1], [0, 1, 0, 1]]
50 |     # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
51 | 
52 |     # dual_elements = True  # dual elements for this rain & sunlight example
53 |     # objects = ['city view with street and cars', 'rain', 'sunlight']
54 |     # locations = [[0, 1, 0, 1], [0, 1, 0, 1], [0, 0.4, 0.6, 1]]
55 |     # opacity = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
56 | 
57 |     base_alpha = [0.8 for _ in objects]
58 |     alpha_list = []
59 |     if dual_elements:
60 |         for opa1 in opacity:
61 |             for opa2 in opacity:
62 |                 alpha_list.append(base_alpha[:-2] + [opa1, opa2])
63 |     else:
64 |         for opa in opacity:
65 |             alpha_list.append(base_alpha[:-1] + [opa])
66 | 
67 |     ################ End of experimental settings ################
68 | 
69 |     assert all(
70 |         0 <= a < 1 for alpha in alpha_list for a in alpha), "alpha values must be in [0, 1)"
71 |     # parse the indices of subject tokens
72 |     if use_dependency_parsing:
73 |         DP = DependencyParsing()
74 |         indices, subjects = zip(
75 |             *[DP.parse(obj, ignore_dp_failure) for obj in objects])
76 |         assert all(len(ind) > 0 for ind in indices), \
77 |             f"Dependency Parsing error: subject indices contain empty cases, please rephrase your prompts or set ignore_dp_failure=True, indices: {indices}"
78 |         indices = [[ind + 1 for ind in ind_list]
79 |                    for ind_list in indices]  # shift by 1: the encoded prompt includes a start token at position 0
80 |         subjects = ['+'.join(sub) for sub in subjects]
81 |     else:
82 |         indices = [[1], [1], [1], [1]]  # manually assigned if not using DP
83 |         # manually assigned if not using DP
84 |         subjects = ['sky', 'forest', 'lake', 'reflection']
85 | 
86 |     save_name = f"larender_{pretrain}_{'-'.join(subjects)}"
87 | 
88 |     # init LaRender
89 |     LR = LaRender(pretrain=pretrain,
90 |                   transmittance_use_bbox=transmittance_use_bbox,
91 |                   transmittance_use_attn_map=transmittance_use_attn_map,
92 |                   density_scheduler=density_scheduler)
93 | 
94 |     # run
95 |     start_time = time.time()
96 |     for alpha in alpha_list:
97 |         images = LR.run(
98 |             objects,
99 |             locations,
100 |             alpha,
101 |             indices,
102 |             num_images_per_prompt,
103 |             height,
104 |             width,
105 |             negative_prompt,
106 |             seed)
107 |         print("Elapsed time:", time.time() - start_time, "seconds")
108 |         os.makedirs(f"results/{save_name}", exist_ok=True)
109 | 
110 |         # save results
111 |         for i, img in enumerate(images):
112 |             if dual_elements:
113 |                 save_fn = f"results/{save_name}/{save_name}_{i:02d}_{alpha[-2]:.2f}_{alpha[-1]:.2f}.jpg"
114 |             else:
115 |                 save_fn = f"results/{save_name}/{save_name}_{i:02d}_{alpha[-1]:.2f}.jpg"
116 |             print(f'Save at: {save_fn}')
117 |             img.save(save_fn)
118 | 
119 |             if save_bbox_image:
120 |                 draw = ImageDraw.Draw(img)
121 |                 width, height = img.width, img.height
122 |                 for obj, loc in zip(objects, locations):
123 |                     draw.rectangle((loc[2] * width, loc[0] * height, loc[3]
124 |                                     * width, loc[1] * height), outline='red', width=2)
125 |                     font = ImageFont.load_default(size=32)
126 |                     text_position = (loc[2] * width + 1, loc[0] * height + 1)
127 |                     draw.text(text_position, obj,
128 |                               font=font, fill="red")  # Text
129 |                 if dual_elements:
130 |                     img.save(
131 |                         f"results/{save_name}/{save_name}_box_{i:02d}_{alpha[-2]:.2f}_{alpha[-1]:.2f}.jpg")
132 |                 else:
133 |                     img.save(
134 |                         f"results/{save_name}/{save_name}_box_{i:02d}_{alpha[-1]:.2f}.jpg")
135 | 
--------------------------------------------------------------------------------
/demos_occlusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PIL import ImageDraw, ImageFont
4 | from utils import DependencyParsing
5 | from larender_sdxl import LaRender
6 | 
7 | 
8 | if __name__ == "__main__":
9 |     # basic params
10 |     negative_prompt = ''
11 |     pretrain = 'IterComp'  # IterComp or GLIGEN
12 |     num_images_per_prompt = 10  # set to 1 for a speed test
13 |     height = 1024
14 |     width = 1024
15 |     save_bbox_image = True
16 |     seed = 5
17 | 
18 |     # dependency parsing params
19 |     use_dependency_parsing = True
20 |     ignore_dp_failure = True
21 | 
22 |     # larender params
23 |     transmittance_use_bbox = True
24 |     transmittance_use_attn_map = True
25 |     density_scheduler = 'inverse_proportional'  # other choices: unchanged, opaque
26 | 
27 |     #### Reproducible experiments, uncomment each block to proceed. ####
28 | 
29 |     # Multiple "next to" are used here because the cat is easily mixed into the giraffe, and the piano easily gets lost. Never use "behind" or "in front of", to avoid introducing occlusion hints.
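    # Note: the order of `objects` encodes depth. Per the transmittance accumulation in
    # larender_sdxl.py, later entries occlude earlier ones, so the background comes first
    # and the front-most object comes last.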
30 |     objects = ["a large empty living room", "a brown piano", "a giraffe is standing next to a piano and a cat",
31 |                "a sofa", "a ginger cat sitting", "a blue teddy bear next to a cat"]
32 |     locations = [[0, 1, 0, 1], [0.4, 0.7, 0.2, 0.9], [0, 0.8, 0, 0.5], [0.6, 0.9, 0.05, 0.95],
33 |                  [0.5, 0.75, 0.5, 0.7], [0.5, 0.9, 0.7, 0.95]]  # mode: [y1, y2, x1, x2], range: 0~1
34 | 
35 |     # objects = ["forest", 'a white cat standing on the ground next to a branch',
36 |     #            'a brown dog sitting on the ground next to a branch', 'a long branch']
37 |     # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4],
38 |     #              [0.3, 0.8, 0.5, 0.8], [0.5, 0.6, 0, 1]]
39 | 
40 |     # objects = ["forest", 'a white cat standing on the ground next to a branch',
41 |     #            'a long branch', 'a brown dog sitting next to a branch']
42 |     # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4],
43 |     #              [0.5, 0.6, 0, 1], [0.3, 0.8, 0.5, 0.8]]
44 | 
45 |     # objects = ["forest", 'a brown dog sitting',
46 |     #            'a long branch', 'a white cat standing']
47 |     # locations = [[0, 1, 0, 1], [0.3, 0.8, 0.5, 0.8],
48 |     #              [0.5, 0.6, 0, 1], [0.3, 0.8, 0.1, 0.4]]
49 | 
50 |     # objects = ["forest", 'a long branch', 'a brown dog sitting',
51 |     #            'a white cat standing']
52 |     # locations = [[0, 1, 0, 1], [0.5, 0.6, 0, 1], [0.3, 0.8, 0.55, 0.8],
53 |     #              [0.3, 0.8, 0.2, 0.4]]
54 | 
55 |     # objects = ['a giraffe', 'an airplane on the ground']
56 |     # locations = [[0.2, 0.8, 0.4, 0.6], [0.5, 0.8, 0, 1]]
57 | 
58 |     # objects = ['a house', 'lawn', 'a giant moon']
59 |     # locations = [[0.3, 0.7, 0.1, 0.9], [0.7, 1, 0, 1], [0.55, 0.8, 0.2, 0.45]]
60 | 
61 |     # objects = ['a cyan vase', 'a yellow clock']
62 |     # locations = [[0.1, 0.8, 0.2, 0.5], [0.4, 0.8, 0.3, 0.7]]  # try pretrain = 'GLIGEN' if bbox not accurate
63 | 
64 |     # objects = ['a yellow clock', 'a cyan vase']
65 |     # locations = [[0.4, 0.8, 0.3, 0.7], [0.1, 0.8, 0.2, 0.5]]
66 | 
67 |     # objects = ['a girl', 'a refrigerator']
68 |     # locations = [[0.2, 0.9, 0.4, 0.7], [0.3, 0.9, 0.4, 0.7]]
69 | 
70 |     # objects = ['a refrigerator', 'a girl']
71 |     # locations = [[0.2, 0.9, 0.4, 0.7], [0.2, 0.9, 0.3, 0.6]]
72 | 
73 |     # objects = ['fence', 'a man', 'a cow next to a man']
74 |     # locations = [[0.4, 1, 0, 1], [0.1, 1, 0.1, 0.5], [0.2, 1, 0.3, 0.8]]
75 | 
76 |     # objects = ['a cow next to a man', 'a man wearing white shirt', 'fence']
77 |     # locations = [[0.2, 1, 0.3, 0.8], [0.1, 1, 0.1, 0.5], [0.4, 1, 0, 1]]
78 | 
79 |     # objects = ['park', 'trees', 'a fountain', 'a lion statue', 'bush']
80 |     # locations = [[0, 1, 0, 1], [0, 0.7, 0, 1], [0, 0.7, 0.3, 0.7], [0.4, 0.7, 0.5, 0.8], [0.6, 0.9, 0.1, 0.9]]
81 | 
82 |     # objects = ['a brown teddy bear', 'a computer']
83 |     # locations = [[0.4, 0.8, 0.2, 0.5], [0.3, 0.8, 0.3, 0.8]]
84 | 
85 |     # objects = ['a computer', 'a brown teddy bear']
86 |     # locations = [[0.3, 0.8, 0.3, 0.8], [0.4, 0.8, 0.2, 0.5]]
87 | 
88 |     # objects = ['a piano', 'a giant Yamaha guitar']
89 |     # locations = [[0.5, 0.9, 0.2, 0.8], [0.1, 0.7, 0.4, 0.6]]
90 | 
91 |     # objects = ['a giant Yamaha guitar', 'a piano']
92 |     # locations = [[0.1, 0.7, 0.4, 0.6], [0.5, 0.9, 0.2, 0.8]]
93 | 
94 |     # objects = ['a boat', 'a bear']
95 |     # locations = [[0.4, 0.8, 0.1, 0.9], [0.2, 0.8, 0.3, 0.6]]
96 | 
97 |     # objects = ['a bear', 'a boat']
98 |     # locations = [[0.2, 0.8, 0.3, 0.6], [0.4, 0.8, 0.1, 0.9]]
99 | 
100 |     # objects = ['a girl in yellow dress next to a boy', 'a boy in blue T-shirt']
101 |     # locations = [[0.2, 1, 0.2, 0.6], [0.1, 1, 0.4, 0.8]]
102 | 
103 |     # objects = ['a boy in blue T-shirt next to a girl', 'a girl in yellow dress']
104 |     # locations = [[0.1, 1, 0.4, 0.8], [0.2, 1, 0.2, 0.6]]
105 | 
106 |     # for non-transparent objects, alpha can always be set to 0.8
107 |     alpha = [0.8 for _ in objects]
108 | 
109 |     ################ End of experimental settings ################
110 | 
111 |     assert all(0 <= a < 1 for a in alpha), "alpha values must be in [0, 1)"
112 |     # parse the indices of subject tokens
113 |     if use_dependency_parsing:
114 |         DP = DependencyParsing()
115 |         indices, subjects = zip(
116 |             *[DP.parse(obj, ignore_dp_failure) for obj in objects])
117 |         assert all(len(ind) > 0 for ind in indices), \
118 |             f"Dependency Parsing error: subject indices contain empty cases, please rephrase your prompts or set ignore_dp_failure=True, indices: {indices}"
119 |         indices = [[ind + 1 for ind in ind_list]
120 |                    for ind_list in indices]  # shift by 1: the encoded prompt includes a start token at position 0
121 |         subjects = ['+'.join(sub) for sub in subjects]
122 |     else:
123 |         indices = [[1], [1], [1], [1]]  # manually assigned if not using DP
124 |         # manually assigned if not using DP
125 |         subjects = ['sky', 'forest', 'lake', 'reflection']
126 | 
127 |     save_name = f"larender_{pretrain}_{'-'.join(subjects)}"
128 | 
129 |     # init LaRender
130 |     LR = LaRender(pretrain=pretrain,
131 |                   transmittance_use_bbox=transmittance_use_bbox,
132 |                   transmittance_use_attn_map=transmittance_use_attn_map,
133 |                   density_scheduler=density_scheduler)
134 | 
135 |     # run
136 |     start_time = time.time()
137 |     images = LR.run(
138 |         objects,
139 |         locations,
140 |         alpha,
141 |         indices,
142 |         num_images_per_prompt,
143 |         height,
144 |         width,
145 |         negative_prompt,
146 |         seed)
147 |     print("Elapsed time:", time.time() - start_time, "seconds")
148 |     os.makedirs(f"results/{save_name}", exist_ok=True)
149 | 
150 |     # save results
151 |     for i, img in enumerate(images):
152 |         save_fn = f"results/{save_name}/{save_name}_{i:02d}.jpg"
153 |         print(f'Save at: {save_fn}')
154 |         img.save(save_fn)
155 | 
156 |         if save_bbox_image:
157 |             draw = ImageDraw.Draw(img)
158 |             width, height = img.width, img.height
159 |             for obj, loc in zip(objects, locations):
160 |                 draw.rectangle((loc[2] * width, loc[0] * height, loc[3]
161 |                                 * width, loc[1] * height), outline='red', width=2)
162 |                 font = ImageFont.load_default(size=32)
163 |                 text_position = (loc[2] * width + 1, loc[0] * height + 1)
164 |                 draw.text(text_position, obj, font=font, fill="red")  # Text
165 |             img.save(f"results/{save_name}/{save_name}_box_{i:02d}.jpg")
166 | 
--------------------------------------------------------------------------------
/larender_sdxl.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from diffusers import DiffusionPipeline
5 | from diffusers.schedulers import DPMSolverMultistepScheduler
6 | from utils import torch_fix_seed
7 | 
8 | 
9 | class LaRender():
10 |     def __init__(self,
11 |                  pretrain='IterComp',
12 |                  transmittance_use_bbox=True,
13 |                  transmittance_use_attn_map=True,
14 |                  density_scheduler='unchanged'):
15 |         '''
16 |         Args:
17 |             pretrain (str, *optional*):
18 |                 Pretrained model type; currently supports 'IterComp' or 'GLIGEN'. 'IterComp' is a T2I model without location control, while 'GLIGEN' is a pre-trained location-control model.
19 |             transmittance_use_bbox (bool, *optional*):
20 |                 Whether to use bounding boxes when approximating transmittance maps. If set to False, objects in the back usually disappear.
21 |             transmittance_use_attn_map (bool, *optional*):
22 |                 Whether to use attention maps when approximating transmittance maps.
23 |             density_scheduler (str, *optional*):
24 |                 The type of density scheduler; supported: 'opaque', 'unchanged', 'inverse_proportional'.
25 | 
26 |         '''
27 |         assert pretrain in ['IterComp', 'GLIGEN']
28 |         assert density_scheduler in [
29 |             'opaque', 'unchanged', 'inverse_proportional']
30 |         self.pretrain = pretrain
31 |         self.transmittance_use_bbox = transmittance_use_bbox
32 |         self.density_scheduler = density_scheduler
33 | 
34 |         if pretrain == 'IterComp':
35 |             self.pipe = DiffusionPipeline.from_pretrained(
36 |                 "comin/IterComp",
37 |                 torch_dtype=torch.float16,
38 |                 use_safetensors=True)
39 |         else:
40 |             self.pipe = DiffusionPipeline.from_pretrained(
41 |                 "jiuntian/gligen-xl-1024", trust_remote_code=True, torch_dtype=torch.float16)
42 |         self.pipe.to("cuda")
43 |         self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
44 |             self.pipe.scheduler.config, use_karras_sigmas=True)
45 |         self.pipe.enable_xformers_memory_efficient_attention()
46 | 
47 |         self.isvanilla = True
48 |         self.transmittance_use_attn_map = transmittance_use_attn_map
49 | 
50 |         for name, module in self.pipe.unet.named_modules():
51 |             if "attn2" in name and module.__class__.__name__ == "Attention":
52 |                 module.forward = self.create_larender_forward(module)  # swap the cross-attention forward for the Latent Rendering forward
53 |                 module.name = name
54 | 
55 |     def _forward_cross_attention(self, module, hidden_states, encoder_hidden_states):
56 |         query = module.to_q(hidden_states)
57 |         key = module.to_k(encoder_hidden_states)
58 |         value = module.to_v(encoder_hidden_states)
59 |         query = module.head_to_batch_dim(query)
60 |         key = module.head_to_batch_dim(key)
61 |         value = module.head_to_batch_dim(value)
62 |         attention_probs = module.get_attention_scores(query, key)
63 |         hidden_states = torch.bmm(attention_probs, value)
64 |         hidden_states = module.batch_to_head_dim(hidden_states)
65 |         hidden_states = module.to_out[0](hidden_states)
66 |         hidden_states = module.to_out[1](hidden_states)
67 |         return hidden_states, attention_probs
68 | 
69 |     def _bboxes_to_mask(self, bboxes, h, w, dtype, device):
70 |         '''
71 |         Args:
72 |             bboxes (List[List[y1, y2, x1, x2]]):
73 |                 The bounding boxes in y1, y2, x1, x2 format.
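                Coordinates are fractions of the mask height/width in [0, 1], i.e. [top, bottom, left, right].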
74 |         '''
75 |         masks = []
76 |         for bbox in bboxes:
77 |             M = torch.zeros((h, w), dtype=dtype, device=device)
78 |             upper = int(h * bbox[0])
79 |             lower = int(h * bbox[1])
80 |             left = int(w * bbox[2])
81 |             right = int(w * bbox[3])
82 |             assert lower > upper and right > left, f"Invalid bounding box: {bbox}"
83 |             M[upper:lower, left:right] = 1
84 |             masks.append(M)
85 |         return masks
86 | 
87 |     def create_larender_forward(self, module):
88 |         def forward(hidden_states, encoder_hidden_states=None, **kwargs):
89 |             return self.larender_core_process(module, hidden_states, encoder_hidden_states)
90 |         return forward
91 | 
92 |     def _prompt_embedding(self, objects, negative_prompt):
93 |         prompt_embeds_list = []
94 |         negative_prompt_embeds_list = []
95 |         pooled_prompt_embeds_list = []
96 |         negative_pooled_prompt_embeds_list = []
97 |         for p in objects:
98 |             (
99 |                 prompt_embeds,
100 |                 negative_prompt_embeds,
101 |                 pooled_prompt_embeds,
102 |                 negative_pooled_prompt_embeds,
103 |             ) = self.pipe.encode_prompt(
104 |                 prompt=p,
105 |                 do_classifier_free_guidance=True,
106 |                 negative_prompt=negative_prompt,
107 |             )
108 |             prompt_embeds_list.append(prompt_embeds)
109 |             negative_prompt_embeds_list.append(negative_prompt_embeds)
110 |             pooled_prompt_embeds_list.append(pooled_prompt_embeds)
111 |             negative_pooled_prompt_embeds_list.append(
112 |                 negative_pooled_prompt_embeds)
113 | 
114 |         prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
115 |         negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=1)
116 |         pooled_prompt_embeds = sum(
117 |             pooled_prompt_embeds_list) / len(pooled_prompt_embeds_list)
118 |         negative_pooled_prompt_embeds = sum(
119 |             negative_pooled_prompt_embeds_list) / len(negative_pooled_prompt_embeds_list)
120 | 
        # the averaged pooled embeddings above are replaced by re-encoding the joint prompt below
121 |         (
122 |             _,
123 |             _,
124 |             pooled_prompt_embeds,
125 |             negative_pooled_prompt_embeds,
126 |         ) = self.pipe.encode_prompt(
127 |             prompt=', '.join(objects),
128 |             do_classifier_free_guidance=True,
129 |             negative_prompt=negative_prompt,
130 |         )
131 |         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
132 | 
133 |     def larender_core_process(self, module, hidden_states, encoder_hidden_states):
134 |         # preparation
135 |         locations = self.larender_inputs['locations']
136 |         indices = self.larender_inputs['indices']
137 |         densities = self.larender_inputs['densities']
138 |         height = self.height
139 |         width = self.width
140 |         latent_h = int(math.sqrt(hidden_states.size()[1] * height / width))
141 |         assert hidden_states.size()[1] % latent_h == 0, \
142 |             "Cannot infer latent_h and latent_w, please check the code."
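        # hidden_states has latent_h * latent_w spatial tokens and the latent grid keeps the
        # image aspect ratio, so latent_h ~ sqrt(num_tokens * height / width), and latent_w
        # follows as num_tokens / latent_h.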
143 |         latent_w = int(hidden_states.size()[1] / latent_h)
144 |         num_inference_steps = len(self.pipe.scheduler.timesteps)
145 |         assert hasattr(self.pipe.scheduler, 'step_index')
146 |         current_count = self.pipe.scheduler.step_index if self.pipe.scheduler.step_index is not None else 0
147 |         self.current_count = current_count  # only used in visualization
148 |         if self.density_scheduler == 'inverse_proportional':
149 |             density_multiplier = num_inference_steps / (current_count + 1)
150 |             densities_multiplied = [x * density_multiplier for x in densities]
151 |         elif self.density_scheduler == 'opaque':
152 |             densities_multiplied = [x * num_inference_steps for x in densities]
153 |         else:
154 |             densities_multiplied = densities
155 | 
156 |         token_max_length = self.pipe.tokenizer.model_max_length
157 |         assert encoder_hidden_states.shape[1] % token_max_length == 0
158 | 
159 |         # object-wise cross-attention and transmittance map
160 |         R = []
161 |         dtype = hidden_states.dtype
162 |         device = hidden_states.device
163 |         if self.transmittance_use_bbox:
164 |             M = self._bboxes_to_mask(
165 |                 locations, latent_h, latent_w, dtype, device)
166 |         else:
167 |             M = [torch.ones((latent_h, latent_w), dtype=dtype,
168 |                             device=device) for _ in locations]
169 |         for i in range(self.num_objects):
170 |             context = encoder_hidden_states[:, i *
171 |                                             token_max_length: (i+1) * token_max_length, :]
172 |             R_i, attn_probs = \
173 |                 self._forward_cross_attention(module, hidden_states, context)
174 |             R.append(R_i.reshape(R_i.shape[0],
175 |                                  latent_h, latent_w, R_i.shape[2]))
176 |             if self.transmittance_use_attn_map:
177 |                 attn_map = attn_probs.reshape(
178 |                     attn_probs.shape[0], latent_h, latent_w, -1)
179 |                 attn_map = attn_map[..., indices[i]].mean(dim=-1).mean(dim=0)
180 |                 min_val = attn_map.min()
181 |                 max_val = attn_map.max()
182 |                 normalized_attn_map = (
183 |                     attn_map - min_val) / (max_val - min_val)
184 |                 M[i] *= normalized_attn_map
185 | 
186 |         # accumulated transmittance maps (visibility of plane i from the virtual camera)
187 |         T = [torch.ones((latent_h, latent_w), dtype=dtype, device=device)
188 |              for _ in range(self.num_objects)]
189 |         for i in range(self.num_objects - 1, -1, -1):  # from the front-most object (last in the list) to the back
190 |             for j in range(i + 1, self.num_objects):
191 |                 # object j lies in front of object i and attenuates its visibility
192 |                 T[i] *= torch.exp(-torch.tensor(densities_multiplied[j]) * M[j])
193 | 
194 |         # rendering
195 |         S = 0
196 |         R_out = 0
197 |         for i in range(self.num_objects):
198 |             contrib = T[i] * (1 - math.exp(-densities_multiplied[i])) * M[i]
199 |             contrib = contrib[None, :, :, None]  # unsqueeze 0, -1
200 |             R_out += contrib * R[i]
201 |             S += contrib
202 |         S = torch.clamp(S, min=1e-6)
203 |         R_out /= S
204 |         return R_out.reshape(*hidden_states.shape)
205 | 
206 |     def run(self,
207 |             objects,
208 |             locations,
209 |             alpha,
210 |             indices,
211 |             num_images_per_prompt=1,
212 |             height=1024,
213 |             width=1024,
214 |             negative_prompt='',
215 |             seed=0):
216 |         '''
217 |         Args:
218 |             objects (List[str]):
219 |                 List of object prompts.
220 |             locations (List[List[float]]):
221 |                 List of bounding boxes in [top, bottom, left, right] format.
222 |             alpha (List[float]):
223 |                 List of alpha values, one per object; each must be in [0, 1).
224 |             indices (List[List[int]]):
225 |                 List of indices of subjects in each object prompt.
226 |         Output:
227 |             images (List[PIL.Image]):
228 |                 List of generated images.
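        Example:
            A minimal sketch; the prompts, boxes, alpha and indices below are illustrative:
                LR = LaRender(pretrain='IterComp')
                images = LR.run(objects=['forest', 'a white cat standing'],
                                locations=[[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4]],
                                alpha=[0.8, 0.8],
                                indices=[[1], [1]])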
229 |         '''
230 |         for a in alpha:
231 |             assert a < 1 and a >= 0, "alpha must be in [0, 1)"
232 |         densities = [-np.log(1 - a) for a in alpha]  # alpha = 1 - exp(-density)
233 |         self.num_objects = len(objects)
234 |         assert len(locations) == len(objects) and len(
235 |             densities) == len(objects)
236 |         self.larender_inputs = dict(
237 |             locations=locations, indices=indices, densities=densities)
238 | 
239 |         self.height = height
240 |         self.width = width
241 | 
242 |         (prompt_embeds,
243 |          negative_prompt_embeds,
244 |          pooled_prompt_embeds,
245 |          negative_pooled_prompt_embeds) = self._prompt_embedding(objects, negative_prompt)
246 | 
247 |         torch_fix_seed(seed)
248 | 
249 |         common_kwargs = dict(
250 |             num_images_per_prompt=num_images_per_prompt,
251 |             height=height,
252 |             width=width,
253 |             guidance_scale=7.0,
254 |             prompt_embeds=prompt_embeds,
255 |             negative_prompt_embeds=negative_prompt_embeds,
256 |             pooled_prompt_embeds=pooled_prompt_embeds,
257 |             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds)
258 |         if self.pretrain == 'IterComp':
259 |             images = self.pipe(
260 |                 num_inference_steps=25, **common_kwargs).images
261 |         else:
262 |             images = self.pipe(
263 |                 num_inference_steps=50,
264 |                 gligen_scheduled_sampling_beta=0.4,
265 |                 gligen_boxes=[[loc[2], loc[0], loc[3], loc[1]]
266 |                               for loc in locations],  # y1y2x1x2 -> x1y1x2y2
267 |                 gligen_phrases=objects,
268 |                 **common_kwargs).images
269 |         return images
270 | 
--------------------------------------------------------------------------------
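For convenience, here is a minimal end-to-end usage sketch distilled from the demo scripts. It only uses names defined in this repository; the prompts, bounding boxes, alpha values and output path are illustrative placeholders.

```python
from utils import DependencyParsing
from larender_sdxl import LaRender

# objects are listed back to front; locations are [y1, y2, x1, x2] fractions in 0~1
objects = ['forest', 'a white cat standing', 'a long branch']
locations = [[0, 1, 0, 1], [0.3, 0.8, 0.1, 0.4], [0.5, 0.6, 0, 1]]
alpha = [0.8 for _ in objects]  # per-object opacity in [0, 1)

# parse subject-token indices and shift by 1 for the start token of the encoded prompt
DP = DependencyParsing()
indices, _ = zip(*[DP.parse(obj, True) for obj in objects])
indices = [[i + 1 for i in ind] for ind in indices]

LR = LaRender(pretrain='IterComp',
              transmittance_use_bbox=True,
              transmittance_use_attn_map=True,
              density_scheduler='inverse_proportional')
images = LR.run(objects, locations, alpha, indices,
                num_images_per_prompt=1, height=1024, width=1024, seed=0)
images[0].save('example.jpg')
```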