├── .gitignore ├── LICENSE ├── README.md ├── SLD_benchmark.py ├── SLD_demo.py ├── benchmark_config.ini ├── demo ├── image_editing │ ├── data.json │ ├── results │ │ ├── dalle3_banana │ │ │ ├── det_result_obj.png │ │ │ ├── final_dalle3_banana.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_dalle3_banana.png │ │ ├── dalle3_dog │ │ │ ├── det_result_obj.png │ │ │ ├── final_dalle3_dog.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_dalle3_dog.png │ │ ├── indoor_scene_attr_mod │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_scene.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_scene.png │ │ ├── indoor_scene_move │ │ │ ├── corrected_result.png │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_scene.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_scene.png │ │ ├── indoor_scene_replace │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_scene.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_scene.png │ │ ├── indoor_scene_resize │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_scene.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_scene.png │ │ ├── indoor_scene_swap │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_scene.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_scene.png │ │ ├── indoor_table_attr_mod │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_table.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_table.png │ │ ├── indoor_table_move │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_table.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_table.png │ │ ├── indoor_table_replace │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_table.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_table.png │ │ ├── indoor_table_resize │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_table.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_table.png │ │ └── indoor_table_swap │ │ │ ├── det_result_obj.png │ │ │ ├── final_indoor_table.png │ │ │ ├── initial_image.png │ │ │ └── intermediate_indoor_table.png │ └── src_image │ │ ├── dalle3_banana.png │ │ ├── dalle3_dog.png │ │ ├── indoor_scene.png │ │ └── indoor_table.png └── self_correction │ ├── data.json │ ├── results │ ├── dalle3_beach │ │ ├── det_result_obj.png │ │ ├── final_dalle3_beach.png │ │ ├── initial_image.png │ │ └── intermediate_dalle3_beach.png │ ├── dalle3_clown │ │ ├── det_result_obj.png │ │ ├── final_dalle3_clown.png │ │ ├── initial_image.png │ │ └── intermediate_dalle3_clown.png │ ├── dalle3_motor │ │ ├── det_result_obj.png │ │ ├── final_dalle3_motor.png │ │ ├── initial_image.png │ │ └── intermediate_dalle3_motor.png │ ├── dalle3_snowwhite │ │ ├── det_result_obj.png │ │ ├── final_dalle3_snowwhite.png │ │ ├── initial_image.png │ │ └── intermediate_dalle3_snowwhite.png │ ├── lmdplus_beach │ │ ├── det_result_obj.png │ │ ├── final_lmdplus_beach.png │ │ ├── initial_image.png │ │ └── intermediate_lmdplus_beach.png │ ├── lmdplus_motor │ │ ├── det_result_obj.png │ │ ├── final_lmdplus_motor.png │ │ ├── initial_image.png │ │ └── intermediate_lmdplus_motor.png │ ├── sdxl_beach │ │ ├── det_result_obj.png │ │ ├── final_sdxl_beach.png │ │ ├── initial_image.png │ │ └── intermediate_sdxl_beach.png │ └── sdxl_motor │ │ ├── det_result_obj.png │ │ ├── final_sdxl_motor.png │ │ ├── initial_image.png │ │ └── intermediate_sdxl_motor.png │ └── src_image │ ├── dalle3_beach.png │ ├── dalle3_car.png │ ├── dalle3_clown.png │ ├── dalle3_motor.png │ ├── dalle3_snowwhite.png │ ├── lmdplus_beach.png │ ├── lmdplus_motor.png │ ├── 
sdxl_beach.png │ └── sdxl_motor.png ├── demo_config.ini ├── eval ├── __init__.py ├── eval.py ├── lmd.py └── utils.py ├── lmd_benchmark_eval.py ├── models ├── __init__.py ├── attention.py ├── attention_processor.py ├── models.py ├── pipelines.py ├── sam.py ├── transformer_2d.py ├── unet_2d_blocks.py └── unet_2d_condition.py ├── requirements.txt ├── sld ├── detector.py ├── image_generator.py ├── llm_chat.py ├── llm_template.py ├── sdxl_refine.py └── utils.py └── utils ├── __init__.py ├── attn.py ├── boxdiff.py ├── guidance.py ├── latents.py ├── parse.py ├── schedule.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .vscode 163 | *tmp_imgs 164 | # Never upload your secret token to GitHub!!! 165 | config.ini 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tsung-Han Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-correcting LLM-controlled Diffusion Models 2 | 3 | This repo provides the PyTorch source code of our paper: [Self-correcting LLM-controlled Diffusion Models (CVPR 2024)](https://arxiv.org/abs/2311.16090). Check out the project page [here](https://self-correcting-llm-diffusion.github.io/)! 4 | 5 | [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/) [![arXiv](https://img.shields.io/badge/arXiv-2311.16090-red)](https://arxiv.org/abs/2311.16090) 6 | 7 | 8 | **Authors**: [Tsung-Han Wu\*](https://tsunghan-wu.github.io/), [Long Lian\*](https://tonylian.com/), [Joseph E. 
Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Boyi Li†](https://sites.google.com/site/boyilics/home), [Trevor Darrell†](https://people.eecs.berkeley.edu/~trevor/) at UC Berkeley. 9 | 10 | 11 | ## :rocket: The Self-correcting LLM-controlled Diffusion (SLD) Framework Highlights: 12 | 1. **Self-correction**: Enhances generative models with LLM-integrated detectors for precise text-to-image alignment. 13 | 2. **Unified Generation and Editing**: Excels at both image generation and fine-grained editing. 14 | 3. **Universal Compatibility**: Works with ANY image generator, such as DALL-E 3, requiring no extra training or data. 15 | 16 | ![](https://self-correcting-llm-diffusion.github.io/main_figure.jpg) 17 | 18 | ## :rotating_light: Update 19 | - 03/10/2024 - Add all SLD scripts and results on the LMD T2I benchmark (all done!) 20 | - 02/13/2024 - Add self-correction and image editing scripts with a few demo examples 21 | 22 | ## :wrench: Installation Guide 23 | 24 | ### System Requirements 25 | 26 | - System Setup: Linux with a single A100 GPU (GPUs with more than 24 GB of memory are also compatible). For Mac or Windows, minor adjustments may be necessary. 27 | 28 | - Dependency Installation: Create a Python environment named "SLD" and install the necessary dependencies: 29 | 30 | 31 | ```bash 32 | conda create -n SLD python=3.9 33 | pip3 install -r requirements.txt 34 | ``` 35 | 36 | Note: Ensure the versions of `transformers` and `diffusers` match the requirements. Versions of `transformers` before 4.35 do not include `owlv2`, and our code is incompatible with some newer versions of `diffusers` that have a different API. 37 | 38 | ## :gear: Usage 39 | 40 | Execute the following command to process images from an input directory according to the instructions in the JSON file and save the transformed images to an output directory. 41 | 42 | ``` 43 | CUDA_VISIBLE_DEVICES=X python3 SLD_demo.py \ 44 | --json-file demo/self_correction/data.json \ # demo/image_editing/data.json 45 | --input-dir demo/self_correction/src_image \ # demo/image_editing/src_image 46 | --output-dir demo/self_correction/results \ # demo/image_editing/results 47 | --mode self_correction \ # image_editing 48 | --config demo_config.ini 49 | ``` 50 | 51 | 1. This script supports both self-correction and image editing modes. Adjust the paths and the `--mode` flag as needed. 52 | 2. We use `gligen/diffusers-generation-text-box` (SDv1.4) as the base diffusion model for image manipulation. For enhanced image quality, we incorporate SDXL refinement techniques similar to [LMD](https://github.com/TonyLianLong/LLM-groundedDiffusion). 53 | 54 | 55 | ## :briefcase: Applying to Your Own Images 56 | 57 | 1. Prepare a JSON File: Structure the file as follows, providing the necessary information for each image you wish to process: 58 | 59 | ``` 60 | [ 61 | { 62 | "input_fname": "", 63 | "output_dir": "", 64 | "prompt": "", 65 | "generator": "", 66 | "llm_parsed_prompt": null, // Leave blank for automatic generation 67 | "llm_layout_suggestions": null // Leave blank for automatic suggestions 68 | } 69 | ] 70 | 71 | ``` 72 | 73 | Ensure you replace the placeholder text with actual values for each parameter. The `llm_parsed_prompt` and `llm_layout_suggestions` fields are optional and can be left as `null` for automatic LLM generation. 74 | 75 | 2. Setting the config 76 | 77 | - Duplicate the demo_config.ini file to a preferred location. 
78 | - Update this copied config file with your OpenAI API key and organization details, along with any other necessary hyper-parameter adjustments. 79 | - **For security reasons, avoid uploading your secret key to public repositories or online platforms.** 80 | 81 | 3. Execute the Script: Run the script similarly to the provided demo, adjusting the command line arguments as needed for your specific configuration and the JSON file you've prepared. 82 | 83 | ## :chart_with_upwards_trend: Quantitative Evaluation on Text-to-Image (T2I) Generation 84 | 85 | In our research, we've shown the superior performance of SLD across four key tasks: negation, numeracy, attribute binding, and spatial relationships. Utilizing the 400-prompt LMD T2I generation [benchmark](https://github.com/TonyLianLong/LLM-groundedDiffusion?tab=readme-ov-file#run-our-benchmark-on-text-to-layout-generation-evaluation) and employing the state-of-the-art [OWLv2](https://huggingface.co/docs/transformers/main/en/model_doc/owlv2) detector with a fixed detection threshold, we've ensured a fair comparison between different methods. Below, we provide both the code and the necessary data for full reproducibility. 86 | 87 | ### Image Correction Logs and Results 88 | 89 | The image generation process, including both the initial and resulting images, has been documented to ensure transparency and ease of further research: 90 | 91 | | Method | Negation | Numeracy | Attribution | Spatial | Overall | 92 | | --------------------- | -------- | -------- | ----------- | ------- | --------- | 93 | | DALL-E 3 | 27 | 37 | 74 | 71 | 52.3% | 94 | | [DALL-E 3 w/ SLD](https://drive.google.com/file/d/1rHqah-TEPsE2vXDS_CQTBhlSVSQ8fGh5/view?usp=sharing) | 84 | 58 | 80 | 85 | 76.8% (+24.5) | 95 | | LMD+ | 100 | 80 | 49 | 88 | 79.3% | 96 | | [LMD+ w/ SLD](https://drive.google.com/file/d/1-yw9_erL6DsQhVVM3LJAeiNRA2dm5VRl/view?usp=sharing) | 100 | 94 | 65 | 97 | 89.0% (+9.7) | 97 | 98 | To access the data, regenerate these performance metrics, or reproduce the correction process yourself, use the download links in the table above. The structure of the dataset is as follows: 99 | 100 | ``` 101 | dalle3_sld 102 | ├── 000 # LMD benchmark prompt ID 103 | │ ├── chatgpt_data.json # raw GPT-4 response 104 | │ ├── det_result1.jpg # visualization of bboxes 105 | │ ├── initial_image.jpg # initial generation results 106 | │ ├── log.txt # logging 107 | │ └── round1.jpg # round[X] SLD correction results 108 | ├── 001 109 | │ ├── chatgpt_data.json 110 | │ ├── det_result1.jpg 111 | ... 112 | ``` 113 | 114 | 115 | To generate these performance metrics on your own, execute the following command: 116 | 117 | ``` 118 | python3 lmd_benchmark_eval.py --data_dir [GENERATION_DIR] [--optional-args] 119 | ``` 120 | 121 | ### Reproducing Results 122 | 123 | To replicate our image correction process, follow these steps: 124 | 125 | 1. Setting the config 126 | 127 | - Duplicate the benchmark_config.ini file to a preferred location. 128 | - Update this copied config file with your OpenAI API key and organization details, along with any other necessary hyper-parameter adjustments. 129 | - **For security reasons, avoid uploading your secret key to public repositories or online platforms.** 130 | 131 | 2. 
Execute the SLD Correction Script 132 | 133 | To apply the SLD correction and perform the evaluation, run the following command: 134 | 135 | ``` 136 | python3 SLD_benchmark.py --data_dir [OUTPUT_DIR] 137 | ``` 138 | 139 | Executing this command will overwrite all existing log files and generated images within the specified directory. Ensure you have backups or are working on copies of data that you can afford to lose. 140 | 141 | To correct the outputs of other diffusion models, simply arrange your data in the same structure and run our code. 142 | 143 | ## :question: Frequently Asked Questions (FAQ) 144 | 145 | 1. **Why are the results for my own image not optimal?** 146 | 147 | *The SLD framework is training-free and effective at achieving text-to-image alignment, particularly for numeracy, spatial relationships, and attribute binding, but it may not consistently deliver optimal visual quality. Tailoring hyper-parameters to your specific image can enhance outcomes.* 148 | 149 | 2. **Why do the images generated differ from those in the paper?** 150 | 151 | *In our demonstrations, we use consistent random seeds and hyper-parameters for simplicity, differing from the iterative optimization process in our paper figure. For optimal results, we recommend fine-tuning critical hyper-parameters, such as the dilation parameter in the SAM refinement process or parameters in DiffEdit, tailored to your specific use case.* 152 | 153 | 3. **Isn't using SDXL for improved visualization results unfair?** 154 | 155 | *For quantitative comparisons with baselines in our paper (Table 1), we explicitly exclude the SDXL refinement step to maintain fairness. Also, we set the same hyper-parameters across all models in the [quantitative evaluation](#chart_with_upwards_trend-quantitative-evaluation-on-text-to-image-t2i-generation).* 156 | 157 | 4. **Can other LLMs replace GPT-4 in your process?** 158 | 159 | *Yes, other LLMs may be used as alternatives. Our tests with GPT-3.5-turbo indicate only minor performance drops. We encourage exploration with other robust open-source tools like [FastChat](https://github.com/lm-sys/FastChat).* 160 | 161 | 5. **Have more questions or encountered any bugs?** 162 | 163 | *Please use the GitHub issues section for bug reports. For further inquiries, contact Tsung-Han (Patrick) Wu at tsunghan_wu@berkeley.edu.* 164 | 165 | ## :pray: Acknowledgements 166 | 167 | We are grateful for the foundational code provided by [Diffusers](https://huggingface.co/docs/diffusers/index) and [LMD](https://github.com/TonyLianLong/LLM-groundedDiffusion). Utilizing their resources implies agreement to their respective licenses. Our project benefits greatly from these contributions, and we acknowledge their significant impact on our work. 168 | 169 | ## :dart: Citation 170 | If you use our work or our implementation in this repo, or find them helpful, please consider citing our paper: 
171 | ``` 172 | @article{wu2023self, 173 | title={Self-correcting LLM-controlled Diffusion Models}, 174 | author={Wu, Tsung-Han and Lian, Long and Gonzalez, Joseph E and Li, Boyi and Darrell, Trevor}, 175 | journal={arXiv preprint arXiv:2311.16090}, 176 | year={2023} 177 | } 178 | ``` 179 | -------------------------------------------------------------------------------- /SLD_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import shutil 5 | import random 6 | import numpy as np 7 | import argparse 8 | import configparser 9 | from PIL import Image 10 | 11 | 12 | import torch 13 | import diffusers 14 | 15 | # Libraries heavily borrowed from LMD 16 | import models 17 | from models import sam 18 | from utils import parse, utils 19 | 20 | # SLD specific imports 21 | from sld.detector import OWLVITV2Detector 22 | from sld.sdxl_refine import sdxl_refine 23 | from sld.utils import get_all_latents, run_sam, run_sam_postprocess, resize_image 24 | from sld.llm_template import spot_object_template, spot_difference_template, image_edit_template 25 | from sld.llm_chat import get_key_objects, get_updated_layout 26 | 27 | 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | 30 | 31 | # Operation #1: Addition (The code is in sld/image_generator.py) 32 | 33 | # Operation #2: Deletion (Preprocessing region mask for removal) 34 | def get_remove_region(entry, remove_objects, move_objects, preserve_objs, models, config): 35 | """Generate a region mask for removal given bounding box info.""" 36 | 37 | image_source = np.array(Image.open(entry["output"][-1])) 38 | H, W, _ = image_source.shape 39 | 40 | # if no remove objects, set zero to the whole mask 41 | if (len(remove_objects) + len(move_objects)) == 0: 42 | remove_region = np.zeros((W // 8, H // 8), dtype=np.int64) 43 | return remove_region 44 | 45 | # Otherwise, run the SAM segmentation to locate target regions 46 | remove_items = remove_objects + [x[0] for x in move_objects] 47 | remove_mask = np.zeros((H, W, 3), dtype=bool) 48 | for obj in remove_items: 49 | masks = run_sam(bbox=obj[1], image_source=image_source, models=models) 50 | remove_mask = remove_mask | masks 51 | 52 | # Preserve the regions that should not be removed 53 | preserve_mask = np.zeros((H, W, 3), dtype=bool) 54 | for obj in preserve_objs: 55 | masks = run_sam(bbox=obj[1], image_source=image_source, models=models) 56 | preserve_mask = preserve_mask | masks 57 | # Process the SAM mask by averaging, thresholding, and dilating. 58 | preserve_region = run_sam_postprocess(preserve_mask, H, W, config) 59 | remove_region = run_sam_postprocess(remove_mask, H, W, config) 60 | remove_region = np.logical_and(remove_region, np.logical_not(preserve_region)) 61 | return remove_region 62 | 63 | 64 | # Operation #3: Repositioning (Preprocessing latent) 65 | def get_repos_info(entry, move_objects, models, config): 66 | """ 67 | Updates a list of objects to be moved / reshaped, including resizing images and generating masks. 68 | * Important: Perform image reshaping at the image-level rather than the latent-level. 69 | * Warning: For simplicity, the object is not positioned to the center of the new region... 
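    * Each returned entry bundles the object name, its bounding boxes, a SAM-refined region mask, and the inverted latents consumed later by the correction step.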
70 | """ 71 | 72 | # If there are no objects to move, return immediately 73 | if not move_objects: 74 | return move_objects 75 | image_source = np.array(Image.open(entry["output"][-1])) 76 | H, W, _ = image_source.shape 77 | inv_seed = int(config.get("SLD", "inv_seed")) 78 | 79 | new_move_objects = [] 80 | for item in move_objects: 81 | new_img, obj = resize_image(image_source, item[0][1], item[1][1]) 82 | old_object_region = run_sam_postprocess(run_sam(obj, new_img, models), H, W, config).astype(np.bool_) 83 | all_latents, _ = get_all_latents(new_img, models, inv_seed) 84 | new_move_objects.append( 85 | [item[0][0], obj, item[1][1], old_object_region, all_latents] 86 | ) 87 | 88 | return new_move_objects 89 | 90 | 91 | # Operation #4: Attribute Modification (Preprocessing latent) 92 | def get_attrmod_latent(entry, change_attr_objects, models, config): 93 | """ 94 | Processes objects with changed attributes to generate new latents and the names of the modified objects. 95 | 96 | Parameters: 97 | entry (dict): A dictionary containing output data. 98 | change_attr_objects (list): A list of objects with changed attributes. 99 | models (Model): The models used for processing. 100 | config (ConfigParser): Configuration providing SLD hyper-parameters (e.g., inv_seed). 101 | 102 | Returns: 103 | list: A list containing new latents and names of the modified objects. 104 | """ 105 | if len(change_attr_objects) == 0: 106 | return [] 107 | from diffusers import StableDiffusionDiffEditPipeline 108 | from diffusers import DDIMScheduler, DDIMInverseScheduler 109 | 110 | img = Image.open(entry["output"][-1]) 111 | image_source = np.array(img) 112 | H, W, _ = image_source.shape 113 | inv_seed = int(config.get("SLD", "inv_seed")) 114 | 115 | # Initialize the Stable Diffusion pipeline 116 | pipe = StableDiffusionDiffEditPipeline.from_pretrained( 117 | "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16 118 | ).to("cuda") 119 | 120 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 121 | pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) 122 | pipe.enable_model_cpu_offload() 123 | new_change_objects = [] 124 | for obj in change_attr_objects: 125 | # Run diffedit 126 | old_object_region = run_sam_postprocess(run_sam(obj[1], image_source, models), H, W, config) 127 | old_object_region = old_object_region.astype(np.bool_)[np.newaxis, ...] 
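# The SAM-derived mask above limits where DiffEdit may repaint. The lines below turn the
# object label (e.g., "green cup #1") into a source prompt ("a cup") and a target prompt
# ("a green cup"), then run DDIM inversion followed by masked re-generation to produce
# the attribute-modified object.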
128 | 129 | new_object = obj[0].split(" #")[0] 130 | base_object = new_object.split(" ")[-1] 131 | mask_prompt = f"a {base_object}" 132 | new_prompt = f"a {new_object}" 133 | 134 | image_latents = pipe.invert( 135 | image=img, 136 | prompt=mask_prompt, 137 | inpaint_strength=float(config.get("SLD", "diffedit_inpaint_strength")), 138 | generator=torch.Generator(device="cuda").manual_seed(inv_seed), 139 | ).latents 140 | image = pipe( 141 | prompt=new_prompt, 142 | mask_image=old_object_region, 143 | image_latents=image_latents, 144 | guidance_scale=float(config.get("SLD", "diffedit_guidance_scale")), 145 | inpaint_strength=float(config.get("SLD", "diffedit_inpaint_strength")), 146 | generator=torch.Generator(device="cuda").manual_seed(inv_seed), 147 | negative_prompt="", 148 | ).images[0] 149 | 150 | all_latents, _ = get_all_latents(np.array(image), models, inv_seed) 151 | new_change_objects.append( 152 | [ 153 | old_object_region[0], 154 | all_latents, 155 | ] 156 | ) 157 | return new_change_objects 158 | 159 | 160 | def correction( 161 | entry, add_objects, move_objects, 162 | remove_region, change_attr_objects, 163 | models, config 164 | ): 165 | spec = { 166 | "add_objects": add_objects, 167 | "move_objects": move_objects, 168 | "prompt": entry["instructions"], 169 | "remove_region": remove_region, 170 | "change_objects": change_attr_objects, 171 | "all_objects": entry["llm_suggestion"], 172 | "bg_prompt": entry["bg_prompt"], 173 | "extra_neg_prompt": entry["neg_prompt"], 174 | } 175 | image_source = np.array(Image.open(entry["output"][-1])) 176 | # Background latent preprocessing 177 | all_latents, _ = get_all_latents(image_source, models, int(config.get("SLD", "inv_seed"))) 178 | ret_dict = image_generator.run( 179 | spec, 180 | fg_seed_start=int(config.get("SLD", "fg_seed")), 181 | bg_seed=int(config.get("SLD", "bg_seed")), 182 | bg_all_latents=all_latents, 183 | frozen_step_ratio=float(config.get("SLD", "frozen_step_ratio")), 184 | ) 185 | return ret_dict 186 | 187 | 188 | def spot_objects(prompt, data, config): 189 | # If the object list is not available, run the LLM to spot objects 190 | if data.get("llm_parsed_prompt") is None: 191 | questions = f"User Prompt: {prompt}\nReasoning:\n" 192 | message = spot_object_template + questions 193 | results = get_key_objects(message, config) 194 | return results[0] # Extracting the object list 195 | else: 196 | return data["llm_parsed_prompt"] 197 | 198 | 199 | def spot_differences(prompt, det_results, data, config, mode="self_correction"): 200 | if data.get("llm_layout_suggestions") is None: 201 | questions = ( 202 | f"User Prompt: {prompt}\nCurrent Objects: {det_results}\nReasoning:\n" 203 | ) 204 | if mode == "self_correction": 205 | message = spot_difference_template + questions 206 | else: 207 | message = image_edit_template + questions 208 | llm_suggestions = get_updated_layout(message, config) 209 | return llm_suggestions[0] 210 | else: 211 | return data["llm_layout_suggestions"] 212 | 213 | 214 | if __name__ == "__main__": 215 | # create argument parser 216 | parser = argparse.ArgumentParser(description="Demo for the SLD pipeline") 217 | parser.add_argument("--json-file", type=str, default="demo/self_correction/data.json", help="Path to the json file") 218 | parser.add_argument("--input-dir", type=str, default="demo/self_correction/src_image", help="Path to the input directory") 219 | parser.add_argument("--output-dir", type=str, default="demo/self_correction/results", help="Path to the output directory") 220 | 
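# --mode selects the pipeline: "self_correction" fixes mismatches between the generated
# image and its prompt, while "image_editing" applies user-requested edits to an existing image.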
parser.add_argument("--mode", type=str, default="self_correction", help="Mode of the demo", choices=["self_correction", "image_editing"]) 221 | parser.add_argument("--config", type=str, default="demo_config.ini", help="Path to the config file") 222 | args = parser.parse_args() 223 | 224 | # Open the json file configured for self-correction (a list of filenames with prompts and other info...) 225 | # Create the output directory 226 | with open(args.json_file) as f: 227 | data = json.load(f) 228 | save_dir = args.output_dir 229 | parse.img_dir = os.path.join(save_dir, "tmp_imgs") 230 | os.makedirs(save_dir, exist_ok=True) 231 | os.makedirs(parse.img_dir, exist_ok=True) 232 | 233 | # Read config 234 | config = configparser.ConfigParser() 235 | config.read(args.config) 236 | 237 | # Load models 238 | models.sd_key = "gligen/diffusers-generation-text-box" 239 | models.sd_version = "sdv1.4" 240 | diffusion_scheduler = None 241 | 242 | models.model_dict = models.load_sd( 243 | key=models.sd_key, 244 | use_fp16=False, 245 | load_inverse_scheduler=True, 246 | scheduler_cls=diffusers.schedulers.__dict__[diffusion_scheduler] 247 | if diffusion_scheduler is not None 248 | else None, 249 | ) 250 | sam_model_dict = sam.load_sam() 251 | models.model_dict.update(sam_model_dict) 252 | from sld import image_generator 253 | 254 | det = OWLVITV2Detector() 255 | # Iterate through the json file 256 | for idx in range(len(data)): 257 | 258 | # Reset random seeds 259 | default_seed = int(config.get("SLD", "default_seed")) 260 | torch.manual_seed(default_seed) 261 | np.random.seed(default_seed) 262 | random.seed(default_seed) 263 | 264 | # Load the image and prompt 265 | rel_fname = data[idx]["input_fname"] 266 | fname = os.path.join(args.input_dir, f"{rel_fname}.png") 267 | 268 | prompt = data[idx]["prompt"] 269 | dirname = os.path.join(save_dir, data[idx]["output_dir"]) 270 | os.makedirs(dirname, exist_ok=True) 271 | 272 | output_fname = os.path.join(dirname, f"initial_image.png") 273 | shutil.copy(fname, output_fname) 274 | 275 | print("-" * 5 + f" [Self-Correcting {fname}] " + "-" * 5) 276 | print(f"Target Textual Prompt: {prompt}") 277 | 278 | # Step 1: Spot Objects with LLM 279 | llm_parsed_prompt = spot_objects(prompt, data[idx], config) 280 | entry = {"instructions": prompt, "output": [fname], "generator": data[idx]["generator"], 281 | "objects": llm_parsed_prompt["objects"], 282 | "bg_prompt": llm_parsed_prompt["bg_prompt"], 283 | "neg_prompt": llm_parsed_prompt["neg_prompt"] 284 | } 285 | print("-" * 5 + f" Parsing Prompts " + "-" * 5) 286 | print(f"* Objects: {entry['objects']}") 287 | print(f"* Background: {entry['bg_prompt']}") 288 | print(f"* Negation: {entry['neg_prompt']}") 289 | 290 | # Step 2: Run open vocabulary detector 291 | print("-" * 5 + f" Running Detector " + "-" * 5) 292 | default_attr_threshold = float(config.get("SLD", "attr_detection_threshold")) 293 | default_prim_threshold = float(config.get("SLD", "prim_detection_threshold")) 294 | default_nms_threshold = float(config.get("SLD", "nms_threshold")) 295 | 296 | attr_threshold = float(config.get(entry["generator"], "attr_detection_threshold", fallback=default_attr_threshold)) 297 | prim_threshold = float(config.get(entry["generator"], "prim_detection_threshold", fallback=default_prim_threshold)) 298 | nms_threshold = float(config.get(entry["generator"], "nms_threshold", fallback=default_nms_threshold)) 299 | det_results = det.run(prompt, entry["objects"], entry["output"][-1], 300 | attr_detection_threshold=attr_threshold, 301 | 
prim_detection_threshold=prim_threshold, 302 | nms_threshold=nms_threshold) 303 | 304 | print("-" * 5 + f" Getting Modification Suggestions " + "-" * 5) 305 | 306 | # Step 3: Spot differences between the detected results and the initial prompt 307 | llm_suggestions = spot_differences(prompt, det_results, data[idx], config, mode=args.mode) 308 | 309 | print(f"* Detection Results: {det_results}") 310 | print(f"* LLM Suggestions: {llm_suggestions}") 311 | entry["det_results"] = copy.deepcopy(det_results) 312 | entry["llm_suggestion"] = copy.deepcopy(llm_suggestions) 313 | # Compare the two layouts to know where to update 314 | ( 315 | preserve_objs, 316 | deletion_objs, 317 | addition_objs, 318 | repositioning_objs, 319 | attr_modification_objs, 320 | ) = det.parse_list(det_results, llm_suggestions) 321 | 322 | print("-" * 5 + f" Editing Operations " + "-" * 5) 323 | print(f"* Preservation: {preserve_objs}") 324 | print(f"* Addition: {addition_objs}") 325 | print(f"* Deletion: {deletion_objs}") 326 | print(f"* Repositioning: {repositioning_objs}") 327 | print(f"* Attribute Modification: {attr_modification_objs}") 328 | total_ops = len(deletion_objs) + len(addition_objs) + len(repositioning_objs) + len(attr_modification_objs) 329 | # Visualization 330 | parse.show_boxes( 331 | gen_boxes=entry["det_results"], 332 | additional_boxes=entry["llm_suggestion"], 333 | img=np.array(Image.open(entry["output"][-1])).astype(np.uint8), 334 | fname=os.path.join(dirname, "det_result_obj.png"), 335 | ) 336 | # Check if there are any changes to apply 337 | if (total_ops == 0): 338 | print("-" * 5 + f" Results " + "-" * 5) 339 | output_fname = os.path.join(dirname, f"final_{rel_fname}.png") 340 | shutil.copy(entry["output"][-1], output_fname) 341 | print("* No changes to apply!") 342 | print(f"* Output File: {output_fname}") 343 | # Shortcut to proceed to the next round! 344 | continue 345 | 346 | # Step 4: T2I Ops: Addition / Deletion / Repositioning / Attr. 
Modification 347 | print("-" * 5 + f" Image Manipulation " + "-" * 5) 348 | 349 | deletion_region = get_remove_region( 350 | entry, deletion_objs, repositioning_objs, [], models, config 351 | ) 352 | repositioning_objs = get_repos_info( 353 | entry, repositioning_objs, models, config 354 | ) 355 | attr_modification_objs = get_attrmod_latent( 356 | entry, attr_modification_objs, models, config 357 | ) 358 | 359 | ret_dict = correction( 360 | entry, addition_objs, repositioning_objs, 361 | deletion_region, attr_modification_objs, 362 | models, config 363 | ) 364 | # Save an intermediate file without the SDXL refinement 365 | curr_output_fname = os.path.join(dirname, f"intermediate_{rel_fname}.png") 366 | Image.fromarray(ret_dict.image).save(curr_output_fname) 367 | print("-" * 5 + f" Results " + "-" * 5) 368 | print("* Output File (Before SDXL): ", curr_output_fname) 369 | utils.free_memory() 370 | 371 | # Can run this if applying SDXL as the refine process 372 | sdxl_output_fname = os.path.join(dirname, f"final_{rel_fname}.png") 373 | if args.mode == "self_correction": 374 | sdxl_refine(prompt, curr_output_fname, sdxl_output_fname) 375 | else: 376 | # For image editing, the prompt should be updated 377 | sdxl_refine(ret_dict.final_prompt, curr_output_fname, sdxl_output_fname) 378 | print("* Output File (After SDXL): ", sdxl_output_fname) -------------------------------------------------------------------------------- /benchmark_config.ini: -------------------------------------------------------------------------------- 1 | [openai] 2 | organization = YOUR_OPENAI_ORGANIZATION_ID 3 | api_key = YOUR_OPENAI_API_KEY 4 | model = gpt-4 5 | 6 | [SLD] 7 | default_seed = 78 8 | inv_seed = 37 9 | bg_seed = 42 10 | fg_seed = 9487 11 | attr_detection_threshold = 0.45 12 | prim_detection_threshold = 0.2 13 | nms_threshold = 0.15 14 | SAM_refine_dilate = 3 15 | diffedit_guidance_scale = 15.5 16 | diffedit_inpaint_strength = 0.87 17 | frozen_step_ratio = 0.5 18 | num_rounds = 1 19 | 20 | [eval] 21 | attr_detection_threshold = 0.45 22 | prim_detection_threshold = 0.2 23 | nms_threshold = 0.15 -------------------------------------------------------------------------------- /demo/image_editing/data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input_fname": "dalle3_banana", 4 | "output_dir": "dalle3_banana", 5 | "prompt": "Replace the left apple with a pumpkin and keep the right apple and all bananas in the scene unchanged", 6 | "generator": "dalle", 7 | "llm_parsed_prompt":{ 8 | "objects": [ 9 | ["apple", [null, null]], 10 | ["banana", [null]], 11 | ["pumpkin", [null]] 12 | ], 13 | "bg_prompt": "A realistic image", 14 | "neg_prompt": null 15 | }, 16 | "llm_layout_suggestions": [ 17 | ["pumpkin #1", [0.177, 0.567, 0.243, 0.246]], 18 | ["apple #2", [0.416, 0.579, 0.239, 0.256]], 19 | ["banana #1", [0.473, 0.235, 0.485, 0.504]], 20 | ["banana #2", [0.293, 0.199, 0.218, 0.464]], 21 | ["banana #3", [0.255, 0.202, 0.185, 0.379]], 22 | ["banana #4", [0.64, 0.701, 0.172, 0.099]], 23 | ["banana #5", [0.303, 0.213, 0.527, 0.583]], 24 | ["banana #6", [0.419, 0.283, 0.126, 0.321]] 25 | ] 26 | }, 27 | { 28 | "input_fname": "dalle3_dog", 29 | "output_dir": "dalle3_dog", 30 | "prompt": "Make the dog a sleeping dog and remove all shadows in an image of a grassland", 31 | "generator": "dalle", 32 | "llm_parsed_prompt":{ 33 | "objects": [["dog", ["sleeping"]], ["shadows", [null]]], 34 | "bg_prompt": "A realistic image of a grassland", 35 | "neg_prompt": "shadows" 36 | 
}, 37 | "llm_layout_suggestions": [["sleeping dog #1", [0.172, 0.15, 0.761, 0.577]]] 38 | }, 39 | { 40 | "input_fname": "indoor_scene", 41 | "output_dir": "indoor_scene_move", 42 | "prompt": "Shift the dog slightly upward and move the cat slightly towards the dog in a window-view image of an urban landscape.", 43 | "generator": "dalle", 44 | "llm_parsed_prompt":{ 45 | "objects": [["dog", [null]], ["cat", [null]]], 46 | "bg_prompt": "A window-view image of an urban landscape", 47 | "neg_prompt": null 48 | }, 49 | "llm_layout_suggestions": [["dog #1", [0.003, 0.209, 0.447, 0.62]], ["cat #1", [0.457, 0.342, 0.441, 0.595]]] 50 | }, 51 | { 52 | "input_fname": "indoor_scene", 53 | "output_dir": "indoor_scene_swap", 54 | "prompt": "Exchange the position of the black dog and yellow cat in a window-view image of an urban landscape. Keep their original shape.", 55 | "generator": "dalle", 56 | "llm_parsed_prompt":{ 57 | "objects": [["dog", ["black"]], ["cat", ["yellow"]]], 58 | "bg_prompt": "A window-view image of an urban landscape", 59 | "neg_prompt": null 60 | }, 61 | "llm_layout_suggestions": [["yellow cat #1", [0.003, 0.309, 0.441, 0.595]], ["black dog #1", [0.557, 0.342, 0.447, 0.62]]] 62 | }, 63 | { 64 | "input_fname": "indoor_scene", 65 | "output_dir": "indoor_scene_attr_mod", 66 | "prompt": "Transform the dog into a robotic dog and the cat into a sculptural cat in a window-view image of an urban landscape", 67 | "generator": "dalle", 68 | "llm_parsed_prompt":{ 69 | "objects": [["dog", [null]], ["cat", [null]]], 70 | "bg_prompt": "A window-view image of an urban landscape", 71 | "neg_prompt": null 72 | }, 73 | "llm_layout_suggestions": [["robotic dog #1", [0.003, 0.309, 0.447, 0.62]], ["sculptural cat #1", [0.557, 0.342, 0.441, 0.595]]] 74 | }, 75 | { 76 | "input_fname": "indoor_scene", 77 | "output_dir": "indoor_scene_resize", 78 | "prompt": "Slightly increase the cat's size and reduce the dog's size in a window-view image of an urban landscape. 
Keep their bottom unchanged.", 79 | "generator": "dalle", 80 | "llm_parsed_prompt":{ 81 | "objects": [["dog", [null]], ["cat", [null]]], 82 | "bg_prompt": "A window-view image of an urban landscape", 83 | "neg_prompt": null 84 | }, 85 | "llm_layout_suggestions": [["dog #1", [0.103, 0.409, 0.347, 0.52]], ["cat #1", [0.457, 0.242, 0.541, 0.695]]] 86 | }, 87 | { 88 | "input_fname": "indoor_scene", 89 | "output_dir": "indoor_scene_replace", 90 | "prompt": "Replace the dog with a houseplant and keep the cat unchanged in a window-view image of an urban landscape.", 91 | "generator": "dalle", 92 | "llm_parsed_prompt":{ 93 | "objects": [["dog", [null]], ["cat", [null]], ["houseplant", [null]]], 94 | "bg_prompt": "A window-view image of an urban landscape", 95 | "neg_prompt": null 96 | }, 97 | "llm_layout_suggestions": [["houseplant #1", [0.003, 0.309, 0.447, 0.62]], ["cat #1", [0.557, 0.342, 0.441, 0.595]]] 98 | }, 99 | { 100 | "input_fname": "indoor_table", 101 | "output_dir": "indoor_table_swap", 102 | "prompt": "Swap the position of bowl and the cup, maintaining their original shape", 103 | "generator": "dalle", 104 | "llm_parsed_prompt":{ 105 | "objects": [["bowl", [null]], ["cup", [null]]], 106 | "bg_prompt": "A realistic image", 107 | "neg_prompt": null 108 | }, 109 | "llm_layout_suggestions": [["bowl #1", [0.483, 0.477, 0.361, 0.18]], ["cup #1", [0.08, 0.365, 0.329, 0.255]]] 110 | }, 111 | { 112 | "input_fname": "indoor_table", 113 | "output_dir": "indoor_table_resize", 114 | "prompt": "Make the cup 1.25 times larger but keep the bowl's size unchanged", 115 | "generator": "dalle", 116 | "llm_parsed_prompt":{ 117 | "objects": [["bowl", [null]], ["cup", [null]]], 118 | "bg_prompt": "A realistic image", 119 | "neg_prompt": null 120 | }, 121 | "llm_layout_suggestions": [["bowl #1", [0.08, 0.415, 0.361, 0.18]], ["cup #1", [0.483, 0.447, 0.41125, 0.31875]]] 122 | }, 123 | { 124 | "input_fname": "indoor_table", 125 | "output_dir": "indoor_table_move", 126 | "prompt": "Move the cup towards the bowl until they slightly overlap", 127 | "generator": "dalle", 128 | "llm_parsed_prompt":{ 129 | "objects": [["bowl", [null]], ["cup", [null]]], 130 | "bg_prompt": "A realistic image", 131 | "neg_prompt": null 132 | }, 133 | "llm_layout_suggestions": [["bowl #1", [0.08, 0.415, 0.361, 0.18]], ["cup #1", [0.413, 0.447, 0.329, 0.255]]] 134 | }, 135 | { 136 | "input_fname": "indoor_table", 137 | "output_dir": "indoor_table_attr_mod", 138 | "prompt": "Turn the cup into green and the bowl into golden", 139 | "generator": "dalle", 140 | "llm_parsed_prompt":{ 141 | "objects": [["bowl", [null]], ["cup", [null]]], 142 | "bg_prompt": "A realistic image", 143 | "neg_prompt": null 144 | }, 145 | "llm_layout_suggestions": [["golden bowl #1", [0.08, 0.415, 0.361, 0.18]], ["green cup #1", [0.483, 0.447, 0.329, 0.255]]] 146 | }, 147 | { 148 | "input_fname": "indoor_table", 149 | "output_dir": "indoor_table_replace", 150 | "prompt": "Replace the bowl with two round apples and keep the cup unchanged", 151 | "generator": "dalle", 152 | "llm_parsed_prompt":{ 153 | "objects": [["bowl", [null]], ["cup", [null]], ["apple", ["round", "round"]]], 154 | "bg_prompt": "A realistic image", 155 | "neg_prompt": null 156 | }, 157 | "llm_layout_suggestions": [["cup #1", [0.483, 0.447, 0.329, 0.255]], ["apple #1", [0.08, 0.415, 0.18, 0.18]], ["apple #2", [0.26, 0.415, 0.18, 0.18]]] 158 | } 159 | ] -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_banana/det_result_obj.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_banana/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_banana/final_dalle3_banana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_banana/final_dalle3_banana.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_banana/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_banana/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_banana/intermediate_dalle3_banana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_banana/intermediate_dalle3_banana.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_dog/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_dog/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_dog/final_dalle3_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_dog/final_dalle3_dog.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_dog/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_dog/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/dalle3_dog/intermediate_dalle3_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/dalle3_dog/intermediate_dalle3_dog.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_attr_mod/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_attr_mod/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_attr_mod/final_indoor_scene.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_attr_mod/final_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_attr_mod/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_attr_mod/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_attr_mod/intermediate_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_attr_mod/intermediate_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_move/corrected_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_move/corrected_result.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_move/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_move/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_move/final_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_move/final_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_move/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_move/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_move/intermediate_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_move/intermediate_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_replace/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_replace/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_replace/final_indoor_scene.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_replace/final_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_replace/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_replace/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_replace/intermediate_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_replace/intermediate_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_resize/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_resize/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_resize/final_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_resize/final_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_resize/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_resize/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_resize/intermediate_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_resize/intermediate_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_swap/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_swap/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_swap/final_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_swap/final_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_swap/initial_image.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_swap/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_scene_swap/intermediate_indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_scene_swap/intermediate_indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_attr_mod/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_attr_mod/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_attr_mod/final_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_attr_mod/final_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_attr_mod/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_attr_mod/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_attr_mod/intermediate_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_attr_mod/intermediate_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_move/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_move/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_move/final_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_move/final_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_move/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_move/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_move/intermediate_indoor_table.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_move/intermediate_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_replace/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_replace/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_replace/final_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_replace/final_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_replace/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_replace/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_replace/intermediate_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_replace/intermediate_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_resize/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_resize/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_resize/final_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_resize/final_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_resize/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_resize/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_resize/intermediate_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_resize/intermediate_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_swap/det_result_obj.png: -------------------------------------------------------------------------------- 
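Each result folder listed here keeps the same four images per run: initial_image.png, det_result_obj.png (apparently the object-detection overlay), an intermediate_*.png, and a final_*.png. A small viewing helper, not part of the repository, that one might use to tile those stages side by side for quick comparison; the tile_stages name, the 512-pixel height, and the output filename are arbitrary choices here, and the sketch assumes the four PNGs exist in one results folder:

import glob
import os
from PIL import Image

def tile_stages(result_dir, out_path="stages.png", height=512):
    """Paste the initial / detection / intermediate / final images of one run side by side."""
    patterns = ["initial_image.png", "det_result_obj.png", "intermediate_*", "final_*"]
    paths = []
    for pattern in patterns:
        matches = sorted(glob.glob(os.path.join(result_dir, pattern)))
        if matches:
            paths.append(matches[0])
    images = [Image.open(p).convert("RGB") for p in paths]
    # Resize everything to a common height, then paste left to right.
    images = [im.resize((int(im.width * height / im.height), height)) for im in images]
    canvas = Image.new("RGB", (sum(im.width for im in images), height), "white")
    x = 0
    for im in images:
        canvas.paste(im, (x, 0))
        x += im.width
    canvas.save(out_path)

# e.g. tile_stages("demo/image_editing/results/indoor_scene_move")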
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_swap/det_result_obj.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_swap/final_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_swap/final_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_swap/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_swap/initial_image.png -------------------------------------------------------------------------------- /demo/image_editing/results/indoor_table_swap/intermediate_indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/results/indoor_table_swap/intermediate_indoor_table.png -------------------------------------------------------------------------------- /demo/image_editing/src_image/dalle3_banana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/src_image/dalle3_banana.png -------------------------------------------------------------------------------- /demo/image_editing/src_image/dalle3_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/src_image/dalle3_dog.png -------------------------------------------------------------------------------- /demo/image_editing/src_image/indoor_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/src_image/indoor_scene.png -------------------------------------------------------------------------------- /demo/image_editing/src_image/indoor_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/image_editing/src_image/indoor_table.png -------------------------------------------------------------------------------- /demo/self_correction/data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input_fname": "lmdplus_beach", 4 | "output_dir": "lmdplus_beach", 5 | "prompt": "An oil painting at the beach of a blue bicycle to the left of a bench and to the right of a palm tree with three seagulls in the sky", 6 | "generator": "lmdplus", 7 | "llm_parsed_prompt":{ 8 | "objects": [ 9 | ["bicycle", ["blue"]], 10 | ["palm tree", [null]], 11 | ["seagull", [null, null, null]], 12 | ["bench", [null]] 13 | ], 14 | "bg_prompt": "An oil painting at the beach", 15 | "neg_prompt": null 16 | }, 17 | "llm_layout_suggestions": [ 18 | ["blue bicycle #1", [0.185, 0.597, 0.359, 0.327]], 19 | ["palm tree #1", [0.141, 0.209, 
0.264, 0.544]], 20 | ["seagull #1", [0.389, 0.028, 0.219, 0.169]], 21 | ["bench #1", [0.479, 0.573, 0.519, 0.351]], 22 | ["seagull #2", [0.577, 0.053, 0.219, 0.169]], 23 | ["seagull #3", [0.777, 0.093, 0.219, 0.169]]] 24 | }, 25 | { 26 | "input_fname": "dalle3_beach", 27 | "output_dir": "dalle3_beach", 28 | "prompt": "An oil painting at the beach of a blue bicycle to the left of a bench and to the right of a palm tree with three seagulls in the sky", 29 | "generator": "dalle", 30 | "llm_parsed_prompt":{ 31 | "objects": [ 32 | ["bicycle", ["blue"]], 33 | ["palm tree", [null]], 34 | ["seagull", [null, null, null]], 35 | ["bench", [null]] 36 | ], 37 | "bg_prompt": "An oil painting at the beach", 38 | "neg_prompt": null 39 | }, 40 | "llm_layout_suggestions": [ 41 | ["blue bicycle #1", [0.2, 0.667, 0.423, 0.255]], 42 | ["palm tree #1", [0.0, 0.008, 0.515, 0.851]], 43 | ["seagull #1", [0.65, 0.361, 0.155, 0.056]], 44 | ["seagull #2", [0.567, 0.058, 0.146, 0.132]], 45 | ["seagull #3", [0.65, 0.238, 0.14, 0.076]], 46 | ["bench #1", [0.6, 0.725, 0.364, 0.173]] 47 | ] 48 | }, 49 | { 50 | "input_fname": "lmdplus_motor", 51 | "output_dir": "lmdplus_motor", 52 | "prompt": "A realistic photo with a monkey sitting above a green motorcycle on the left and another raccoon sittig above a blue motorcycle on the right", 53 | "generator": "lmdplus", 54 | "llm_parsed_prompt": { 55 | "objects": [ 56 | ["monkey", [null]], 57 | ["motorcycle", ["green", "blue"]], 58 | ["raccoon", [null]] 59 | ], 60 | "bg_prompt": "A realistic photo", 61 | "neg_prompt": null 62 | }, 63 | "llm_layout_suggestions": [ 64 | ["monkey #1", [0.022, 0.086, 0.39, 0.486]], 65 | ["green motorcycle #1", [0.008, 0.325, 0.472, 0.671]], 66 | ["blue motorcycle #1", [0.548, 0.394, 0.405, 0.568]], 67 | ["raccoon #1", [0.549, 0.026, 0.402, 0.483]] 68 | ] 69 | }, 70 | { 71 | "input_fname": "dalle3_motor", 72 | "output_dir": "dalle3_motor", 73 | "prompt": "A realistic photo with a monkey sitting above a green motorcycle on the left and another raccoon sitting above a blue motorcycle on the right", 74 | "generator": "dalle", 75 | "llm_parsed_prompt": { 76 | "objects": [ 77 | ["monkey", [null]], 78 | ["motorcycle", ["green", "blue"]], 79 | ["raccoon", [null]] 80 | ], 81 | "bg_prompt": "A realistic photo", 82 | "neg_prompt": null 83 | }, 84 | "llm_layout_suggestions": [ 85 | ["monkey #1", [0.009, 0.006, 0.481, 0.821]], 86 | ["green motorcycle #1", [0.016, 0.329, 0.506, 0.6]], 87 | ["blue motorcycle #1", [0.516, 0.329, 0.484, 0.6]], 88 | ["raccoon #1", [0.46, 0.123, 0.526, 0.62]] 89 | ] 90 | }, 91 | { 92 | "input_fname": "dalle3_snowwhite", 93 | "output_dir": "dalle3_snowwhite", 94 | "prompt": "A realistic cartoon-style painting with a princess and four dwarfs", 95 | "generator": "dalle", 96 | "llm_parsed_prompt":{ 97 | "objects": [ 98 | ["princess", [null]], 99 | ["dwarf", [null, null, null, null]] 100 | ], 101 | "bg_prompt": "A realistic cartoon-style painting", 102 | "neg_prompt": null 103 | }, 104 | "llm_layout_suggestions": [ 105 | ["princess #1", [0.002, 0.27, 0.448, 0.738]], 106 | ["dwarf #1", [0.179, 0.339, 0.286, 0.568]], 107 | ["dwarf #2", [0.424, 0.322, 0.227, 0.318]], 108 | ["dwarf #3", [0.681, 0.361, 0.311, 0.641]], 109 | ["dwarf #4", [0.617, 0.348, 0.25, 0.382]] 110 | ] 111 | }, 112 | { 113 | "input_fname": "dalle3_clown", 114 | "output_dir": "dalle3_clown", 115 | "prompt": "A vivid photo where a woman on the right and a clown on the left are walking in a dirty alley", 116 | "generator": "dalle", 117 | "llm_parsed_prompt":{ 118 | "objects": [ 
119 | ["woman", [null]], 120 | ["clown", [null]] 121 | ], 122 | "bg_prompt": "A vivid photo in a dirty alley", 123 | "neg_prompt": null 124 | }, 125 | "llm_layout_suggestions": [["woman #1", [0.542, 0.209, 0.237, 0.601]], ["clown #1", [0.206, 0.153, 0.244, 0.81]]] 126 | }, 127 | { 128 | "input_fname": "sdxl_beach", 129 | "output_dir": "sdxl_beach", 130 | "prompt": "An oil painting at the beach of a blue bicycle to the left of a bench and to the right of a palm tree with three seagulls in the sky", 131 | "generator": "sdxl", 132 | "llm_parsed_prompt":{ 133 | "objects": [ 134 | ["bicycle", ["blue"]], 135 | ["palm tree", [null]], 136 | ["seagull", [null, null, null]], 137 | ["bench", [null]] 138 | ], 139 | "bg_prompt": "An oil painting at the beach", 140 | "neg_prompt": null 141 | }, 142 | "llm_layout_suggestions": [["palm tree #1", [0.001, 0.002, 0.462, 0.908]], ["seagull #8", [0.386, 0.091, 0.253, 0.139]], ["seagull #2", [0.641, 0.532, 0.116, 0.061]], ["seagull #11", [0.316, 0.428, 0.266, 0.182]], ["bench #1", [0.444, 0.771, 0.368, 0.146]], ["blue bicycle #1", [0.2, 0.65, 0.3, 0.3]]] 143 | }, 144 | { 145 | "input_fname": "sdxl_motor", 146 | "output_dir": "sdxl_motor", 147 | "prompt": "A realistic photo with a monkey sitting above a green motorcycle on the left and another raccoon sitting above a blue motorcycle on the right", 148 | "generator": "sdxl", 149 | "llm_parsed_prompt": { 150 | "objects": [ 151 | ["monkey", [null]], 152 | ["motorcycle", ["green", "blue"]], 153 | ["raccoon", [null]] 154 | ], 155 | "bg_prompt": "A realistic photo", 156 | "neg_prompt": null 157 | }, 158 | "llm_layout_suggestions": [["green motorcycle #1", [0.025, 0.357, 0.448, 0.586]], ["blue motorcycle #1", [0.541, 0.365, 0.496, 0.602]], ["monkey #1", [0.06, 0.077, 0.422, 0.73]], ["raccoon #1", [0.571, 0.083, 0.39, 0.833]]] 159 | } 160 | ] -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_beach/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_beach/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_beach/final_dalle3_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_beach/final_dalle3_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_beach/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_beach/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_beach/intermediate_dalle3_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_beach/intermediate_dalle3_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_clown/det_result_obj.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_clown/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_clown/final_dalle3_clown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_clown/final_dalle3_clown.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_clown/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_clown/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_clown/intermediate_dalle3_clown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_clown/intermediate_dalle3_clown.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_motor/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_motor/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_motor/final_dalle3_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_motor/final_dalle3_motor.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_motor/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_motor/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_motor/intermediate_dalle3_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_motor/intermediate_dalle3_motor.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_snowwhite/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_snowwhite/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_snowwhite/final_dalle3_snowwhite.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_snowwhite/final_dalle3_snowwhite.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_snowwhite/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_snowwhite/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/dalle3_snowwhite/intermediate_dalle3_snowwhite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/dalle3_snowwhite/intermediate_dalle3_snowwhite.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_beach/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_beach/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_beach/final_lmdplus_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_beach/final_lmdplus_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_beach/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_beach/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_beach/intermediate_lmdplus_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_beach/intermediate_lmdplus_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_motor/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_motor/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_motor/final_lmdplus_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_motor/final_lmdplus_motor.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_motor/initial_image.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_motor/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/lmdplus_motor/intermediate_lmdplus_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/lmdplus_motor/intermediate_lmdplus_motor.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_beach/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_beach/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_beach/final_sdxl_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_beach/final_sdxl_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_beach/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_beach/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_beach/intermediate_sdxl_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_beach/intermediate_sdxl_beach.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_motor/det_result_obj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_motor/det_result_obj.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_motor/final_sdxl_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_motor/final_sdxl_motor.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_motor/initial_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_motor/initial_image.png -------------------------------------------------------------------------------- /demo/self_correction/results/sdxl_motor/intermediate_sdxl_motor.png: -------------------------------------------------------------------------------- 
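The src_image folders listed below hold the uncorrected generations that data.json refers to by input_fname. A short sketch, assuming each source image is simply src_image/<input_fname>.png (an inference from the matching names, not something the repository states), that overlays an entry's suggested layout boxes on its source image for inspection; draw_layout is a hypothetical helper:

import json
from PIL import Image, ImageDraw

def draw_layout(entry, src_dir="demo/self_correction/src_image"):
    """Overlay the normalized [x, y, w, h] layout suggestions on the source image."""
    image = Image.open(f"{src_dir}/{entry['input_fname']}.png").convert("RGB")
    draw = ImageDraw.Draw(image)
    W, H = image.size
    for name, (x, y, w, h) in entry["llm_layout_suggestions"]:
        box = (x * W, y * H, (x + w) * W, (y + h) * H)
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0] + 2, box[1] + 2), name, fill="red")
    return image

with open("demo/self_correction/data.json") as f:
    first = json.load(f)[0]
draw_layout(first).save("layout_preview.png")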
https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/results/sdxl_motor/intermediate_sdxl_motor.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/dalle3_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/dalle3_beach.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/dalle3_car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/dalle3_car.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/dalle3_clown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/dalle3_clown.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/dalle3_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/dalle3_motor.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/dalle3_snowwhite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/dalle3_snowwhite.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/lmdplus_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/lmdplus_beach.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/lmdplus_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/lmdplus_motor.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/sdxl_beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/sdxl_beach.png -------------------------------------------------------------------------------- /demo/self_correction/src_image/sdxl_motor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunghan-wu/SLD/8b730ef5b44195127e8bd64e8d188eadb685721f/demo/self_correction/src_image/sdxl_motor.png -------------------------------------------------------------------------------- /demo_config.ini: -------------------------------------------------------------------------------- 1 | [openai] 2 | organization = 
YOUR_OPENAI_ORGANIZATION_ID 3 | api_key = YOUR_OPENAI_API_KEY 4 | model = "gpt-4" 5 | 6 | [dalle] 7 | attr_detection_threshold = 0.6 8 | prim_detection_threshold = 0.2 9 | nms_threshold = 0.5 10 | 11 | [lmdplus] 12 | attr_detection_threshold = 0.6 13 | prim_detection_threshold = 0.2 14 | nms_threshold = 0.3 15 | 16 | [sdxl] 17 | attr_detection_threshold = 0.5 18 | prim_detection_threshold = 0.2 19 | nms_threshold = 0.5 20 | 21 | [SLD] 22 | default_seed = 78 23 | inv_seed = 37 24 | bg_seed = 42 25 | fg_seed = 9487 26 | attr_detection_threshold = 0.6 27 | prim_detection_threshold = 0.2 28 | nms_threshold = 0.5 29 | SAM_refine_dilate = 3 30 | diffedit_guidance_scale = 10.5 31 | diffedit_inpaint_strength = 0.8 32 | frozen_step_ratio = 0.5 -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import * 2 | -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | from transformers import Owlv2Processor, Owlv2ForObjectDetection 5 | 6 | def get_eval_info_from_prompt(prompt, prompt_type): 7 | if prompt_type.startswith("lmd"): 8 | from .lmd import get_eval_info_from_prompt_lmd 9 | return get_eval_info_from_prompt_lmd(prompt) 10 | raise ValueError(f"Unknown prompt type: {prompt_type}") 11 | 12 | def nms(bounding_boxes, confidence_score, labels, threshold, input_in_pixels=False, return_array=True): 13 | """ 14 | This NMS processes boxes of all labels. It not only removes the box with the same label. 15 | 16 | Adapted from https://github.com/amusi/Non-Maximum-Suppression/blob/master/nms.py 17 | """ 18 | # If no bounding boxes, return empty list 19 | if len(bounding_boxes) == 0: 20 | return np.array([]), np.array([]), np.array([]) 21 | 22 | # Bounding boxes 23 | boxes = np.array(bounding_boxes) 24 | 25 | # coordinates of bounding boxes 26 | start_x = boxes[:, 0] 27 | start_y = boxes[:, 1] 28 | end_x = boxes[:, 2] 29 | end_y = boxes[:, 3] 30 | 31 | # Confidence scores of bounding boxes 32 | score = np.array(confidence_score) 33 | 34 | # Picked bounding boxes 35 | picked_boxes = [] 36 | picked_score = [] 37 | picked_labels = [] 38 | 39 | # Compute areas of bounding boxes 40 | if input_in_pixels: 41 | areas = (end_x - start_x + 1) * (end_y - start_y + 1) 42 | else: 43 | areas = (end_x - start_x) * (end_y - start_y) 44 | 45 | # Sort by confidence score of bounding boxes 46 | order = np.argsort(score) 47 | 48 | # Iterate bounding boxes 49 | while order.size > 0: 50 | # The index of largest confidence score 51 | index = order[-1] 52 | 53 | # Pick the bounding box with largest confidence score 54 | picked_boxes.append(bounding_boxes[index]) 55 | picked_score.append(confidence_score[index]) 56 | picked_labels.append(labels[index]) 57 | 58 | # Compute ordinates of intersection-over-union(IOU) 59 | x1 = np.maximum(start_x[index], start_x[order[:-1]]) 60 | x2 = np.minimum(end_x[index], end_x[order[:-1]]) 61 | y1 = np.maximum(start_y[index], start_y[order[:-1]]) 62 | y2 = np.minimum(end_y[index], end_y[order[:-1]]) 63 | 64 | # Compute areas of intersection-over-union 65 | if input_in_pixels: 66 | w = np.maximum(0.0, x2 - x1 + 1) 67 | h = np.maximum(0.0, y2 - y1 + 1) 68 | else: 69 | w = np.maximum(0.0, x2 - x1) 70 | h = np.maximum(0.0, y2 - y1) 71 | intersection = w * h 72 | 73 | # Compute 
the ratio between intersection and union 74 | ratio = intersection / (areas[index] + areas[order[:-1]] - intersection) 75 | 76 | left = np.where(ratio < threshold) 77 | order = order[left] 78 | 79 | if return_array: 80 | picked_boxes, picked_score, picked_labels = np.array(picked_boxes), np.array(picked_score), np.array(picked_labels) 81 | 82 | return picked_boxes, picked_score, picked_labels 83 | 84 | def class_aware_nms(bounding_boxes, confidence_score, labels, threshold, input_in_pixels=False): 85 | """ 86 | This NMS processes boxes of each label individually. 87 | """ 88 | # If no bounding boxes, return empty list 89 | if len(bounding_boxes) == 0: 90 | return np.array([]), np.array([]), np.array([]) 91 | 92 | picked_boxes, picked_score, picked_labels = [], [], [] 93 | 94 | labels_unique = np.unique(labels) 95 | for label in labels_unique: 96 | bounding_boxes_label = [bounding_box for i, bounding_box in enumerate(bounding_boxes) if labels[i] == label] 97 | confidence_score_label = [confidence_score_item for i, confidence_score_item in enumerate(confidence_score) if labels[i] == label] 98 | labels_label = [label] * len(bounding_boxes_label) 99 | picked_boxes_label, picked_score_label, picked_labels_label = nms(bounding_boxes_label, confidence_score_label, labels_label, threshold=threshold, input_in_pixels=input_in_pixels, return_array=False) 100 | picked_boxes += picked_boxes_label 101 | picked_score += picked_score_label 102 | picked_labels += picked_labels_label 103 | 104 | picked_boxes, picked_score, picked_labels = np.array(picked_boxes), np.array(picked_score), np.array(picked_labels) 105 | 106 | return picked_boxes, picked_score, picked_labels 107 | 108 | def evaluate_with_boxes(boxes, eval_info, verbose=False): 109 | predicate = eval_info["predicate"] 110 | 111 | print("boxes:", boxes) 112 | 113 | return predicate(boxes, verbose) 114 | 115 | def to_gen_box_format(box, width, height): 116 | # Input: xyxy, ranging from 0 to 1 117 | # Output: xywh, unnormalized (in pixels) 118 | x_min, y_min, x_max, y_max = box 119 | return [x_min * width, y_min * height, (x_max - x_min) * width, (y_max - y_min) * height] 120 | 121 | @torch.no_grad() 122 | def eval_prompt(p, path, evaluator, prim_score_threshold = 0.2, attr_score_threshold=0.45, 123 | nms_threshold = 0.5, use_class_aware_nms=False, verbose=False, use_cuda=True): 124 | texts, eval_info = get_eval_info_from_prompt(p, "lmd") 125 | eval_type = eval_info["type"] 126 | if eval_type == "attribution": 127 | score_threshold = attr_score_threshold 128 | else: 129 | score_threshold = prim_score_threshold 130 | 131 | image = Image.open(path) 132 | inputs = evaluator.processor(text=texts, images=image, return_tensors="pt") 133 | if use_cuda: 134 | inputs = inputs.to("cuda") 135 | outputs = evaluator.model(**inputs) 136 | 137 | width, height = image.size 138 | 139 | # Target image sizes (height, width) to rescale box predictions [batch_size, 2] 140 | target_sizes = torch.Tensor([[height, width]]) 141 | if use_cuda: 142 | target_sizes = target_sizes.cuda() 143 | # Convert outputs (bounding boxes and class logits) to COCO API 144 | results = evaluator.processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes) 145 | 146 | i = 0 # Retrieve predictions for the first image for the corresponding text queries 147 | text = texts[i] 148 | boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] 149 | boxes = boxes.cpu() 150 | # xyxy ranging from 0 to 1 151 | boxes = np.array([[x_min / width, y_min / 
height, x_max / width, y_max / height] 152 | for (x_min, y_min, x_max, y_max), score in zip(boxes, scores) if score >= score_threshold]) 153 | labels = np.array([label.cpu().numpy() for label, score in zip( 154 | labels, scores) if score >= score_threshold]) 155 | scores = np.array([score.cpu().numpy() 156 | for score in scores if score >= score_threshold]) 157 | 158 | # print(f"Pre-NMS:") 159 | # for box, score, label in zip(boxes, scores, labels): 160 | # box = [round(i, 2) for i in box.tolist()] 161 | # print( 162 | # f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}") 163 | 164 | print("Post-NMS:") 165 | 166 | if use_class_aware_nms: 167 | boxes, scores, labels = class_aware_nms(boxes, scores, labels, nms_threshold) 168 | else: 169 | boxes, scores, labels = nms(boxes, scores, labels, nms_threshold) 170 | for box, score, label in zip(boxes, scores, labels): 171 | box = [round(i, 2) for i in box.tolist()] 172 | print(f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}") 173 | 174 | if verbose: 175 | print(f"prompt: {p}, texts: {texts}, boxes: {boxes}, labels: {labels}, eval_info: {eval_info}") 176 | 177 | det_boxes = [{"name": text[label], "bounding_box": to_gen_box_format(box, width, height), "score": score} for box, score, label in zip(boxes, scores, labels)] 178 | 179 | eval_success = evaluate_with_boxes(det_boxes, eval_info, verbose=verbose) 180 | 181 | return eval_type, eval_success 182 | 183 | 184 | 185 | class Evaluator: 186 | def __init__(self): 187 | self.processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble") 188 | self.model = Owlv2ForObjectDetection.from_pretrained( 189 | "google/owlv2-base-patch16-ensemble" 190 | ).cuda() -------------------------------------------------------------------------------- /eval/lmd.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from functools import partial 4 | from .utils import p, singular, predicate_numeracy, predicate_numeracy_2obj, locations_xywh, predicate_spatial, predicate_attribution, word_to_num_mapping 5 | 6 | prompt_prefix = "A realistic photo of a scene" 7 | 8 | evaluate_classes = ['backpack', 'book', 'bottle', 9 | 'bowl', 'car', 'cat', 'chair', 'cup', 'dog', 'laptop'] 10 | 11 | def get_eval_info_from_prompt_lmd(prompt): 12 | """ 13 | Note: object_name needs to be a substring of each item in texts to make `count` and `get_box` in the predicate work 14 | """ 15 | if 'without' in prompt: 16 | # negation 17 | 18 | pattern = f"without (.+)" 19 | match = re.search(pattern, prompt) 20 | object_name = match.group(1) 21 | object_name = singular(object_name) 22 | texts = [[f"image of {p.a(object_name)}"]] 23 | query_names = (object_name,) 24 | number = 0 25 | predicate = partial(predicate_numeracy, query_names, number) 26 | eval_info = {"type": "negation", "predicate": predicate} 27 | elif 'on the left' in prompt or 'on the right' in prompt or 'on the top' in prompt or 'on the bottom' in prompt: 28 | # spatial 29 | pattern = f"with (.+) on the (.+) and (.+) on the (.+)" 30 | match = re.search(pattern, prompt) 31 | print("prompt:", prompt) 32 | object_name1, location1 = match.group(1), match.group(2) 33 | object_name2, location2 = match.group(3), match.group(4) 34 | texts = [[f"image of {object_name1}", f"image of {object_name2}"]] 35 | query_names1, query_names2 = (object_name1, ), (object_name2, ) 36 | 37 | verify_fn = locations_xywh[(location1, 
location2)] 38 | 39 | predicate = partial(predicate_spatial, query_names1, query_names2, verify_fn) 40 | eval_info = {"type": "spatial", "location1": location1, "location2": location2, "predicate": predicate} 41 | elif 'and' in prompt: # no spatial keyword 42 | if 'one' in prompt or 'two' in prompt or 'three' in prompt or 'four' in prompt or 'five' in prompt: 43 | # numeracy 2obj 44 | pattern = f"with (.+) (.+) and (.+) (.+)" 45 | match = re.search(pattern, prompt) 46 | number1, object_name1 = match.group(1), match.group(2) 47 | number2, object_name2 = match.group(1), match.group(2) 48 | 49 | number1 = word_to_num_mapping[number1] if number1 in word_to_num_mapping else int(number1) 50 | number2 = word_to_num_mapping[number2] if number2 in word_to_num_mapping else int(number2) 51 | 52 | object_name1, object_name2 = singular(object_name1), singular(object_name2) 53 | texts = [[f"image of {p.a(object_name1)}", f"image of {p.a(object_name2)}"]] 54 | query_names1, query_names2 = (object_name1,), (object_name2,) 55 | 56 | predicate = partial(predicate_numeracy_2obj, query_names1, number1, query_names2, number2) 57 | 58 | eval_info = {"type": "numeracy_2obj", "object_name1": object_name1, "number1": number1, "object_name2": object_name2, "number2": number2, "predicate": predicate} 59 | else: 60 | # attribution 61 | # NOTE: we should match against other modifiers 62 | 63 | assert 'on the' not in prompt, prompt 64 | 65 | pattern = f"with (.+) and (.+)" 66 | match = re.search(pattern, prompt) 67 | # object_name has a modifier 68 | object_name1 = match.group(1) 69 | object_name2 = match.group(2) 70 | texts = [[f"image of {object_name1}", f"image of {object_name2}"]] 71 | query_names1, query_names2 = (object_name1, ), (object_name2, ) 72 | 73 | modifier1, modifier2, intended_count1, intended_count2 = None, None, 1, 1 74 | 75 | predicate = partial(predicate_attribution, query_names1, query_names2, modifier1, modifier2, intended_count1, intended_count2) 76 | eval_info = {"type": "attribution", "object_name1": object_name1, "object_name2": object_name2, "predicate": predicate} 77 | elif 'with' in prompt: # with number words 78 | # numeracy 79 | pattern = f"with (.+) (.+)" 80 | match = re.search(pattern, prompt) 81 | number, object_name = match.group(1), match.group(2) 82 | 83 | if number not in word_to_num_mapping: 84 | number = int(number) 85 | else: 86 | number = word_to_num_mapping[number] 87 | object_name = singular(object_name) 88 | texts = [[f"image of {p.a(object_name)}"]] 89 | query_names = (object_name,) 90 | 91 | predicate = partial(predicate_numeracy, query_names, number) 92 | eval_info = {"type": "numeracy", "object_name": object_name, "number": number, "predicate": predicate} 93 | else: 94 | raise ValueError(f"Unknown LMD prompt type: {prompt}") 95 | 96 | return texts, eval_info 97 | 98 | 99 | def get_prompt_predicates_negation(repeat=10): 100 | modifier = '' 101 | 102 | prompt_predicates = [] 103 | number = 0 104 | 105 | for object_name in evaluate_classes: 106 | if isinstance(object_name, tuple): 107 | query_names = object_name 108 | object_name = object_name[0] 109 | else: 110 | query_names = (object_name,) 111 | 112 | if prompt_prefix: 113 | prompt = f"{prompt_prefix} without{modifier} {p.plural(object_name)}" 114 | else: 115 | prompt = f"without{modifier} {p.plural(object_name)}" 116 | prompt = prompt.strip() 117 | 118 | # Shouldn't use lambda here since query_names (and number) might change. 
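        # For example, `lambda gen_boxes, verbose=False: predicate_numeracy(query_names, number, gen_boxes, verbose)`
        # would look up `query_names` only when the predicate is eventually called, so every
        # stored predicate would see the value from the final loop iteration; `partial`
        # binds the current values at creation time.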
119 | prompt_predicate = prompt, partial(predicate_numeracy, query_names, number) 120 | 121 | prompt_predicates += [prompt_predicate] * repeat 122 | 123 | return prompt_predicates 124 | 125 | def get_prompt_predicates_numeracy(min_num=1, max_num=5, repeat=2): 126 | modifier = '' 127 | 128 | prompt_predicates = [] 129 | 130 | for number in range(min_num, max_num + 1): 131 | for object_name in evaluate_classes: 132 | if isinstance(object_name, tuple): 133 | query_names = object_name 134 | object_name = object_name[0] 135 | else: 136 | query_names = (object_name,) 137 | 138 | if prompt_prefix: 139 | prompt = f"{prompt_prefix} with {p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}" 140 | else: 141 | prompt = f"{p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}" 142 | prompt = prompt.strip() 143 | 144 | prompt_predicate = prompt, partial(predicate_numeracy, query_names, number) 145 | 146 | prompt_predicates += [prompt_predicate] * repeat 147 | 148 | return prompt_predicates 149 | 150 | def process_object_name(object_name): 151 | if isinstance(object_name, tuple): 152 | query_names = object_name 153 | object_name = object_name[0] 154 | else: 155 | query_names = (object_name,) 156 | 157 | return object_name, query_names 158 | 159 | def get_prompt_predicates_attribution(num_prompts=100, repeat=1): 160 | prompt_predicates = [] 161 | 162 | intended_count1, intended_count2 = 1, 1 163 | 164 | evaluate_classes_np = np.array(evaluate_classes, dtype=object) 165 | 166 | modifiers = ['red', 'orange', 'yellow', 'green', 'blue', 167 | 'purple', 'pink', 'brown', 'black', 'white', 'gray'] 168 | 169 | for ind in range(num_prompts): 170 | np.random.seed(ind) 171 | modifier1, modifier2 = np.random.choice(modifiers, 2, replace=False) 172 | object_name1, object_name2 = np.random.choice( 173 | evaluate_classes_np, 2, replace=False) 174 | 175 | object_name1, query_names1 = process_object_name(object_name1) 176 | object_name2, query_names2 = process_object_name(object_name2) 177 | 178 | if prompt_prefix: 179 | prompt = f"{prompt_prefix} with {p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}" 180 | else: 181 | prompt = f"{p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}" 182 | prompt = prompt.strip() 183 | 184 | prompt_predicate = prompt, partial(predicate_attribution, query_names1, query_names2, modifier1, modifier2, intended_count1, intended_count2) 185 | 186 | prompt_predicates += [prompt_predicate] * repeat 187 | 188 | return prompt_predicates 189 | 190 | def get_prompt_predicates_spatial(num_prompts=25, left_right_only=False): 191 | 192 | prompt_predicates = [] 193 | 194 | repeat = 1 195 | 196 | evaluate_classes_np = np.array(evaluate_classes, dtype=object) 197 | 198 | # NOTE: the boxes are in (x, y, w, h) format, not (x_min, y_min, x_max, y_max), different from the predicates in eval.py 199 | locations = [ 200 | ('left', 'right', lambda box1, 201 | box2: box1[0] + box1[2]/2 < box2[0] + box2[2]/2), 202 | ('right', 'left', lambda box1, 203 | box2: box1[0] + box1[2]/2 > box2[0] + box2[2]/2), 204 | ] 205 | if not left_right_only: 206 | # NOTE: the boxes are in (x, y, w, h) format 207 | locations += [ 208 | ('top', 'bottom', lambda box1, 209 | box2: box1[1] + box1[3]/2 < box2[1] + box2[3]/2), 210 | ('bottom', 'top', lambda box1, 211 | box2: box1[1] + box1[3]/2 > box2[1] + box2[3]/2) 212 | ] 213 | 214 | for ind in range(num_prompts): 215 | 
np.random.seed(ind) 216 | for location1, location2, verify_fn in locations: 217 | object_name1, object_name2 = np.random.choice( 218 | evaluate_classes_np, 2, replace=False) 219 | 220 | object_name1, query_names1 = process_object_name(object_name1) 221 | object_name2, query_names2 = process_object_name(object_name2) 222 | 223 | if prompt_prefix: 224 | prompt = f"{prompt_prefix} with {p.a(object_name1)} on the {location1} and {p.a(object_name2)} on the {location2}" 225 | else: 226 | prompt = f"{p.a(object_name1)} on the {location1} and {p.a(object_name2)} on the {location2}" 227 | prompt = prompt.strip() 228 | 229 | prompt_predicate = prompt, partial(predicate_spatial, query_names1, query_names2, verify_fn) 230 | 231 | prompt_predicates += [prompt_predicate] * repeat 232 | 233 | return prompt_predicates 234 | 235 | def get_lmd_prompts(): 236 | # negation 237 | prompt_predicates_negation = get_prompt_predicates_negation(repeat=10) 238 | # numeracy 239 | prompt_predicates_numeracy = get_prompt_predicates_numeracy(max_num=5, repeat=2) 240 | # attribution 241 | prompt_predicates_attribution = get_prompt_predicates_attribution(num_prompts=100) 242 | # spatial 243 | prompt_predicates_spatial = get_prompt_predicates_spatial(num_prompts=25) 244 | 245 | prompts_negation = [prompt for prompt, _ in prompt_predicates_negation] 246 | prompts_numeracy = [prompt for prompt, _ in prompt_predicates_numeracy] 247 | prompts_attribution = [prompt for prompt, _ in prompt_predicates_attribution] 248 | prompts_spatial = [prompt for prompt, _ in prompt_predicates_spatial] 249 | 250 | prompts_all = prompts_negation + prompts_numeracy + prompts_attribution + prompts_spatial 251 | 252 | prompts = { 253 | 'lmd': prompts_all, 254 | 'lmd_negation': prompts_negation, 255 | 'lmd_numeracy': prompts_numeracy, 256 | 'lmd_attribution': prompts_attribution, 257 | 'lmd_spatial': prompts_spatial, 258 | } 259 | 260 | return prompts 261 | -------------------------------------------------------------------------------- /eval/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import inflect 3 | import re 4 | 5 | p = inflect.engine() 6 | 7 | 8 | # Credit: GPT 9 | def find_word_after(text, word): 10 | pattern = r"\b" + re.escape(word) + r"\s+(.+)" 11 | match = re.search(pattern, text) 12 | if match: 13 | return match.group(1) 14 | else: 15 | return None 16 | 17 | 18 | word_to_num_mapping = {p.number_to_words(i): i for i in range(1, 21)} 19 | 20 | # New predicates that use the center 21 | locations_xyxy = { 22 | ('left', 'right'): (lambda box1, box2: (box1[0] + box1[2]) < (box2[0] + box2[2])), 23 | ('right', 'left'): (lambda box1, box2: (box1[0] + box1[2]) > (box2[0] + box2[2])), 24 | ('top', 'bottom'): (lambda box1, box2: (box1[1] + box1[3]) < (box2[1] + box2[3])), 25 | ('bottom', 'top'): (lambda box1, box2: (box1[1] + box1[3]) > (box2[1] + box2[3])) 26 | } 27 | 28 | locations_xywh = { 29 | ('left', 'right'): (lambda box1, box2: box1[0] + box1[2]/2 < box2[0] + box2[2]/2), 30 | ('right', 'left'): (lambda box1, box2: box1[0] + box1[2]/2 > box2[0] + box2[2]/2), 31 | ('top', 'bottom'): (lambda box1, box2: box1[1] + box1[3]/2 < box2[1] + box2[3]/2), 32 | ('bottom', 'top'): (lambda box1, box2: box1[1] + box1[3]/2 > box2[1] + box2[3]/2) 33 | } 34 | 35 | 36 | def singular(noun): 37 | singular_noun = p.singular_noun(noun) 38 | if singular_noun is False: 39 | return noun 40 | return singular_noun 41 | 42 | 43 | def get_box(gen_boxes, name_include): 44 | # This prevents substring 
match on non-word boundaries: carrot vs car 45 | box_match = [any([((name_include_item + ' ') in box['name'] or box['name'].endswith(name_include_item)) 46 | for name_include_item in name_include]) for box in gen_boxes] 47 | 48 | if not any(box_match): 49 | return None 50 | 51 | box_ind = np.min(np.where(box_match)[0]) 52 | return gen_boxes[box_ind] 53 | 54 | 55 | def count(gen_boxes, name_include): 56 | return sum([ 57 | any([name_include_item in box['name'] for name_include_item in name_include]) for box in gen_boxes 58 | ]) 59 | 60 | 61 | def predicate_numeracy(query_names, intended_count, gen_boxes, verbose=False): 62 | # gen_boxes: dict with keys 'name' and 'bounding_box' 63 | object_count = count(gen_boxes, name_include=query_names) 64 | if verbose: 65 | print( 66 | f"object_count: {object_count}, intended_count: {intended_count} (gen_boxes: {gen_boxes}, query_names: {query_names})") 67 | 68 | return object_count == intended_count 69 | 70 | def predicate_numeracy_2obj(query_names1, intended_count1, query_names2, intended_count2, gen_boxes, verbose=False): 71 | # gen_boxes: dict with keys 'name' and 'bounding_box' 72 | object_count1 = count(gen_boxes, name_include=query_names1) 73 | object_count2 = count(gen_boxes, name_include=query_names2) 74 | 75 | if verbose: 76 | print( 77 | f"object_count1: {object_count1}, intended_count1: {intended_count1} (gen_boxes: {gen_boxes}, query_names1: {query_names1})") 78 | print( 79 | f"object_count2: {object_count2}, intended_count2: {intended_count2} (gen_boxes: {gen_boxes}, query_names2: {query_names2})") 80 | 81 | return object_count1 == intended_count1 and object_count2 == intended_count2 82 | 83 | 84 | def predicate_attribution(query_names1, query_names2, modifier1, modifier2, intended_count1, intended_count2, gen_boxes, verbose=False): 85 | # gen_boxes: dict with keys 'name' and 'bounding_box' 86 | if modifier1: 87 | query_names1 = [f"{modifier1} {item}" for item in query_names1] 88 | object_count1 = count(gen_boxes, name_include=query_names1) 89 | 90 | if query_names2 is not None: 91 | if modifier2: 92 | query_names2 = [f"{modifier2} {item}" for item in query_names2] 93 | object_count2 = count(gen_boxes, name_include=query_names2) 94 | 95 | if verbose: 96 | print(f"Count 1: {object_count1}, Count 2: {object_count2}") 97 | return object_count1 >= intended_count1 and object_count2 >= intended_count2 98 | else: 99 | if verbose: 100 | print(f"Count 1: {object_count1}") 101 | return object_count1 >= intended_count1 102 | 103 | 104 | def predicate_spatial(query_names1, query_names2, verify_fn, gen_boxes, verbose=False): 105 | # gen_boxes: dict with keys 'name' and 'bounding_box' 106 | 107 | object_box1 = get_box(gen_boxes, query_names1) 108 | object_box2 = get_box(gen_boxes, query_names2) 109 | 110 | if verbose: 111 | print( 112 | f"object_box1: {object_box1}, object_box2: {object_box2}") 113 | 114 | if object_box1 is None or object_box2 is None: 115 | return False 116 | 117 | return verify_fn(object_box1['bounding_box'], object_box2['bounding_box']) 118 | -------------------------------------------------------------------------------- /lmd_benchmark_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from eval.lmd import get_lmd_prompts 5 | from glob import glob 6 | from eval.eval import eval_prompt, Evaluator 7 | from tqdm import tqdm 8 | 9 | torch.set_grad_enabled(False) 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | 
parser.add_argument("--data_dir", type=str, required=True) 14 | parser.add_argument("--num_round", type=int, default=1) 15 | parser.add_argument("--prim_detection_score_threshold", default=0.20, type=float) 16 | parser.add_argument("--attr_detection_score_threshold", default=0.45, type=float) 17 | parser.add_argument("--nms_threshold", default=0.15, type=float) 18 | parser.add_argument("--class-aware-nms", action='store_true', default=True) 19 | parser.add_argument("--verbose", action='store_true') 20 | parser.add_argument("--no-cuda", action='store_true') 21 | args = parser.parse_args() 22 | 23 | prompts = get_lmd_prompts()["lmd"] 24 | print(f"Number of prompts: {len(prompts)}") 25 | 26 | evaluator = Evaluator() 27 | eval_success_counts = {} 28 | eval_all_counts = {} 29 | failure = [] 30 | 31 | for ind, prompt in enumerate(tqdm(prompts)): 32 | 33 | get_path = False 34 | for idx in range(args.num_round, 0, -1): 35 | path = os.path.join(args.data_dir, f"{ind:03d}", f"round{idx}.jpg") 36 | if os.path.exists(path): 37 | get_path = True 38 | break 39 | if not get_path: 40 | path = os.path.join(args.data_dir, f"{ind:03d}", "initial_image.jpg") 41 | print(f"Image path: {path}") 42 | 43 | eval_type, eval_success = eval_prompt(prompt, path, evaluator, 44 | prim_score_threshold=args.prim_detection_score_threshold, attr_score_threshold=args.attr_detection_score_threshold, 45 | nms_threshold=args.nms_threshold, use_class_aware_nms=args.class_aware_nms, use_cuda=True, verbose=args.verbose) 46 | 47 | print(f"Eval success (eval_type):", eval_success) 48 | if int(eval_success) < 1: 49 | failure.append(ind) 50 | if eval_type not in eval_all_counts: 51 | eval_success_counts[eval_type] = 0 52 | eval_all_counts[eval_type] = 0 53 | eval_success_counts[eval_type] += int(eval_success) 54 | eval_all_counts[eval_type] += 1 55 | summary = [] 56 | eval_success_conut, eval_all_count = 0, 0 57 | for k, v in eval_all_counts.items(): 58 | rate = eval_success_counts[k]/eval_all_counts[k] 59 | print( 60 | f"Eval type: {k}, success: {eval_success_counts[k]}/{eval_all_counts[k]}, rate: {round(rate, 2):.2f}") 61 | eval_success_conut += eval_success_counts[k] 62 | eval_all_count += eval_all_counts[k] 63 | summary.append(rate) 64 | print(failure) 65 | rate = eval_success_conut/eval_all_count 66 | print( 67 | f"Overall: success: {eval_success_conut}/{eval_all_count}, rate: {rate:.2f}") 68 | summary.append(rate) 69 | 70 | summary_str = '/'.join([f"{round(rate, 2):.2f}" for rate in summary]) 71 | print(f"Summary: {summary_str}") -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import maybe_allow_in_graph 21 | from .attention_processor import Attention 22 | from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings 23 | 24 | # https://github.com/gligen/diffusers/blob/23a9a0fab1b48752c7b9bcc98f6fe3b1d8fa7990/src/diffusers/models/attention.py 25 | class GatedSelfAttentionDense(nn.Module): 26 | def __init__(self, query_dim, context_dim, n_heads, d_head): 27 | super().__init__() 28 | 29 | # we need a linear projection since we need cat visual feature and obj feature 30 | self.linear = nn.Linear(context_dim, query_dim) 31 | 32 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 33 | self.ff = FeedForward(query_dim, activation_fn="geglu") 34 | 35 | self.norm1 = nn.LayerNorm(query_dim) 36 | self.norm2 = nn.LayerNorm(query_dim) 37 | 38 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) 39 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) 40 | 41 | self.enabled = True 42 | 43 | def forward(self, x, objs, fuser_attn_kwargs={}): 44 | if not self.enabled: 45 | return x 46 | 47 | n_visual = x.shape[1] 48 | objs = self.linear(objs) 49 | 50 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)), **fuser_attn_kwargs)[:, :n_visual, :] 51 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 52 | 53 | return x 54 | 55 | @maybe_allow_in_graph 56 | class BasicTransformerBlock(nn.Module): 57 | r""" 58 | A basic Transformer block. 59 | 60 | Parameters: 61 | dim (`int`): The number of channels in the input and output. 62 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 63 | attention_head_dim (`int`): The number of channels in each head. 64 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 65 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 66 | only_cross_attention (`bool`, *optional*): 67 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 68 | double_self_attention (`bool`, *optional*): 69 | Whether to use two self-attention layers. In this case no cross attention layers are used. 70 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 71 | num_embeds_ada_norm (: 72 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 73 | attention_bias (: 74 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
75 | """ 76 | 77 | def __init__( 78 | self, 79 | dim: int, 80 | num_attention_heads: int, 81 | attention_head_dim: int, 82 | dropout=0.0, 83 | cross_attention_dim: Optional[int] = None, 84 | activation_fn: str = "geglu", 85 | num_embeds_ada_norm: Optional[int] = None, 86 | attention_bias: bool = False, 87 | only_cross_attention: bool = False, 88 | double_self_attention: bool = False, 89 | upcast_attention: bool = False, 90 | norm_elementwise_affine: bool = True, 91 | norm_type: str = "layer_norm", 92 | final_dropout: bool = False, 93 | use_gated_attention: bool = False, 94 | ): 95 | super().__init__() 96 | self.only_cross_attention = only_cross_attention 97 | 98 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 99 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 100 | 101 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 102 | raise ValueError( 103 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 104 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 105 | ) 106 | 107 | # Define 3 blocks. Each block has its own normalization layer. 108 | # 1. Self-Attn 109 | if self.use_ada_layer_norm: 110 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 111 | elif self.use_ada_layer_norm_zero: 112 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 113 | else: 114 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 115 | self.attn1 = Attention( 116 | query_dim=dim, 117 | heads=num_attention_heads, 118 | dim_head=attention_head_dim, 119 | dropout=dropout, 120 | bias=attention_bias, 121 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 122 | upcast_attention=upcast_attention, 123 | ) 124 | 125 | # 2. Cross-Attn 126 | if cross_attention_dim is not None or double_self_attention: 127 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 128 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 129 | # the second cross attention block. 130 | self.norm2 = ( 131 | AdaLayerNorm(dim, num_embeds_ada_norm) 132 | if self.use_ada_layer_norm 133 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 134 | ) 135 | self.attn2 = Attention( 136 | query_dim=dim, 137 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 138 | heads=num_attention_heads, 139 | dim_head=attention_head_dim, 140 | dropout=dropout, 141 | bias=attention_bias, 142 | upcast_attention=upcast_attention, 143 | ) # is self-attn if encoder_hidden_states is none 144 | else: 145 | self.norm2 = None 146 | self.attn2 = None 147 | 148 | # 3. Feed-forward 149 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 150 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 151 | 152 | # 4. 
Fuser 153 | if use_gated_attention: 154 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 155 | 156 | def forward( 157 | self, 158 | hidden_states: torch.FloatTensor, 159 | attention_mask: Optional[torch.FloatTensor] = None, 160 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 161 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 162 | timestep: Optional[torch.LongTensor] = None, 163 | cross_attention_kwargs: Dict[str, Any] = None, 164 | class_labels: Optional[torch.LongTensor] = None, 165 | return_cross_attention_probs: bool = None, 166 | ): 167 | # Notice that normalization is always applied before the real computation in the following blocks. 168 | 169 | # 0. Prepare GLIGEN inputs 170 | if 'gligen' in cross_attention_kwargs: 171 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 172 | gligen_kwargs = cross_attention_kwargs.pop('gligen', None) 173 | else: 174 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 175 | gligen_kwargs = None 176 | 177 | # 1. Self-Attention 178 | if self.use_ada_layer_norm: 179 | norm_hidden_states = self.norm1(hidden_states, timestep) 180 | elif self.use_ada_layer_norm_zero: 181 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 182 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 183 | ) 184 | else: 185 | norm_hidden_states = self.norm1(hidden_states) 186 | 187 | attn_output = self.attn1( 188 | norm_hidden_states, 189 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 190 | attention_mask=attention_mask, 191 | **cross_attention_kwargs, 192 | ) 193 | if self.use_ada_layer_norm_zero: 194 | attn_output = gate_msa.unsqueeze(1) * attn_output 195 | hidden_states = attn_output + hidden_states 196 | 197 | # 1.5 GLIGEN Control 198 | if gligen_kwargs is not None: 199 | # print(gligen_kwargs) 200 | hidden_states = self.fuser(hidden_states, gligen_kwargs['objs'], fuser_attn_kwargs=gligen_kwargs.get("fuser_attn_kwargs", {})) 201 | # 1.5 ends 202 | 203 | # 2. Cross-Attention 204 | if self.attn2 is not None: 205 | norm_hidden_states = ( 206 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 207 | ) 208 | 209 | attn_output = self.attn2( 210 | norm_hidden_states, 211 | encoder_hidden_states=encoder_hidden_states, 212 | attention_mask=encoder_attention_mask, 213 | return_attntion_probs=return_cross_attention_probs, 214 | **cross_attention_kwargs, 215 | ) 216 | 217 | if return_cross_attention_probs: 218 | attn_output, cross_attention_probs = attn_output 219 | 220 | hidden_states = attn_output + hidden_states 221 | 222 | # 3. Feed-forward 223 | norm_hidden_states = self.norm3(hidden_states) 224 | 225 | if self.use_ada_layer_norm_zero: 226 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 227 | 228 | ff_output = self.ff(norm_hidden_states) 229 | 230 | if self.use_ada_layer_norm_zero: 231 | ff_output = gate_mlp.unsqueeze(1) * ff_output 232 | 233 | hidden_states = ff_output + hidden_states 234 | 235 | if return_cross_attention_probs and self.attn2 is not None: 236 | return hidden_states, cross_attention_probs 237 | return hidden_states 238 | 239 | 240 | class FeedForward(nn.Module): 241 | r""" 242 | A feed-forward layer. 243 | 244 | Parameters: 245 | dim (`int`): The number of channels in the input. 
246 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 247 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 248 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 249 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 250 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 251 | """ 252 | 253 | def __init__( 254 | self, 255 | dim: int, 256 | dim_out: Optional[int] = None, 257 | mult: int = 4, 258 | dropout: float = 0.0, 259 | activation_fn: str = "geglu", 260 | final_dropout: bool = False, 261 | ): 262 | super().__init__() 263 | inner_dim = int(dim * mult) 264 | dim_out = dim_out if dim_out is not None else dim 265 | 266 | if activation_fn == "gelu": 267 | act_fn = GELU(dim, inner_dim) 268 | if activation_fn == "gelu-approximate": 269 | act_fn = GELU(dim, inner_dim, approximate="tanh") 270 | elif activation_fn == "geglu": 271 | act_fn = GEGLU(dim, inner_dim) 272 | elif activation_fn == "geglu-approximate": 273 | act_fn = ApproximateGELU(dim, inner_dim) 274 | 275 | self.net = nn.ModuleList([]) 276 | # project in 277 | self.net.append(act_fn) 278 | # project dropout 279 | self.net.append(nn.Dropout(dropout)) 280 | # project out 281 | self.net.append(nn.Linear(inner_dim, dim_out)) 282 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 283 | if final_dropout: 284 | self.net.append(nn.Dropout(dropout)) 285 | 286 | def forward(self, hidden_states): 287 | for module in self.net: 288 | hidden_states = module(hidden_states) 289 | return hidden_states 290 | 291 | 292 | class GELU(nn.Module): 293 | r""" 294 | GELU activation function with tanh approximation support with `approximate="tanh"`. 295 | """ 296 | 297 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): 298 | super().__init__() 299 | self.proj = nn.Linear(dim_in, dim_out) 300 | self.approximate = approximate 301 | 302 | def gelu(self, gate): 303 | if gate.device.type != "mps": 304 | return F.gelu(gate, approximate=self.approximate) 305 | # mps: gelu is not implemented for float16 306 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) 307 | 308 | def forward(self, hidden_states): 309 | hidden_states = self.proj(hidden_states) 310 | hidden_states = self.gelu(hidden_states) 311 | return hidden_states 312 | 313 | 314 | class GEGLU(nn.Module): 315 | r""" 316 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 317 | 318 | Parameters: 319 | dim_in (`int`): The number of channels in the input. 320 | dim_out (`int`): The number of channels in the output. 
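        Note (illustrative): `self.proj` maps `dim_in` to `2 * dim_out` features; `forward` splits the result
        with `chunk(2, dim=-1)` and returns `hidden_states * gelu(gate)`, so the output has `dim_out` channels.
        For example, `GEGLU(320, 1280)` projects to 2560 features and returns a 1280-channel tensor.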
321 | """ 322 | 323 | def __init__(self, dim_in: int, dim_out: int): 324 | super().__init__() 325 | self.proj = nn.Linear(dim_in, dim_out * 2) 326 | 327 | def gelu(self, gate): 328 | if gate.device.type != "mps": 329 | return F.gelu(gate) 330 | # mps: gelu is not implemented for float16 331 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 332 | 333 | def forward(self, hidden_states): 334 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 335 | return hidden_states * self.gelu(gate) 336 | 337 | 338 | class ApproximateGELU(nn.Module): 339 | """ 340 | The approximate form of Gaussian Error Linear Unit (GELU) 341 | 342 | For more details, see section 2: https://arxiv.org/abs/1606.08415 343 | """ 344 | 345 | def __init__(self, dim_in: int, dim_out: int): 346 | super().__init__() 347 | self.proj = nn.Linear(dim_in, dim_out) 348 | 349 | def forward(self, x): 350 | x = self.proj(x) 351 | return x * torch.sigmoid(1.702 * x) 352 | 353 | 354 | class AdaLayerNorm(nn.Module): 355 | """ 356 | Norm layer modified to incorporate timestep embeddings. 357 | """ 358 | 359 | def __init__(self, embedding_dim, num_embeddings): 360 | super().__init__() 361 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 362 | self.silu = nn.SiLU() 363 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2) 364 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) 365 | 366 | def forward(self, x, timestep): 367 | emb = self.linear(self.silu(self.emb(timestep))) 368 | scale, shift = torch.chunk(emb, 2) 369 | x = self.norm(x) * (1 + scale) + shift 370 | return x 371 | 372 | 373 | class AdaLayerNormZero(nn.Module): 374 | """ 375 | Norm layer adaptive layer norm zero (adaLN-Zero). 376 | """ 377 | 378 | def __init__(self, embedding_dim, num_embeddings): 379 | super().__init__() 380 | 381 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) 382 | 383 | self.silu = nn.SiLU() 384 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 385 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 386 | 387 | def forward(self, x, timestep, class_labels, hidden_dtype=None): 388 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) 389 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 390 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 391 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 392 | 393 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPTextModel, CLIPTokenizer 3 | from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler 4 | from .unet_2d_condition import UNet2DConditionModel 5 | from easydict import EasyDict 6 | import numpy as np 7 | # For compatibility 8 | from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents 9 | from utils import torch_device 10 | 11 | # This is to be set in the `generate.py` 12 | sd_key = "" 13 | sd_version = "" 14 | model_dict = None 15 | 16 | def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True, use_dpm_multistep_scheduler=False, scheduler_cls=None): 17 | """ 18 | Keys: 19 | key = "CompVis/stable-diffusion-v1-4" 20 | key = "runwayml/stable-diffusion-v1-5" 21 | key = "stabilityai/stable-diffusion-2-1-base" 
22 | 23 | Unpack with: 24 | ``` 25 | model_dict = load_sd(key=key, use_fp16=use_fp16, **models.model_kwargs) 26 | vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype 27 | ``` 28 | 29 | use_fp16: fp16 might have degraded performance 30 | use_dpm_multistep_scheduler: DPMSolverMultistepScheduler 31 | """ 32 | 33 | # run final results in fp32 34 | if use_fp16: 35 | dtype = torch.float16 36 | revision = "fp16" 37 | else: 38 | dtype = torch.float 39 | revision = "main" 40 | 41 | vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device) 42 | tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision, torch_dtype=dtype) 43 | text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device) 44 | unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device) 45 | # print(unet) 46 | # exit() 47 | if scheduler_cls is None: # Default setting (for compatibility) 48 | if use_dpm_multistep_scheduler: 49 | scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype) 50 | else: 51 | scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype) 52 | else: 53 | print("Using scheduler:", scheduler_cls) 54 | assert not use_dpm_multistep_scheduler, "`use_dpm_multistep_scheduler` cannot be used with `scheduler_cls`" 55 | scheduler = scheduler_cls.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype) 56 | 57 | model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dtype=dtype) 58 | 59 | if load_inverse_scheduler: 60 | inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config) 61 | model_dict.inverse_scheduler = inverse_scheduler 62 | 63 | return model_dict 64 | 65 | def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False): 66 | if negative_prompt == "": 67 | print("Note that negative_prompt is an empty string") 68 | 69 | text_input = tokenizer( 70 | prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" 71 | ) 72 | 73 | max_length = text_input.input_ids.shape[-1] 74 | if one_uncond_input_only: 75 | num_uncond_input = 1 76 | else: 77 | num_uncond_input = len(prompts) 78 | uncond_input = tokenizer([negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, return_tensors="pt") 79 | 80 | with torch.no_grad(): 81 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 82 | cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 83 | 84 | if one_uncond_input_only: 85 | return uncond_embeddings, cond_embeddings 86 | 87 | text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) 88 | 89 | if return_full_only: 90 | return text_embeddings 91 | return text_embeddings, uncond_embeddings, cond_embeddings 92 | 93 | def process_input_embeddings(input_embeddings): 94 | assert isinstance(input_embeddings, (tuple, list)) 95 | if len(input_embeddings) == 3: 96 | # input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings 97 | # Assume `uncond_embeddings` is full (has batch size the same as cond_embeddings) 98 
| _, uncond_embeddings, cond_embeddings = input_embeddings 99 | assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}" 100 | return input_embeddings 101 | elif len(input_embeddings) == 2: 102 | # input_embeddings: uncond_embeddings, cond_embeddings 103 | # uncond_embeddings may have only one item 104 | uncond_embeddings, cond_embeddings = input_embeddings 105 | if uncond_embeddings.shape[0] == 1: 106 | uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape) 107 | # We follow the convention: negative (unconditional) prompt comes first 108 | text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0) 109 | return text_embeddings, uncond_embeddings, cond_embeddings 110 | else: 111 | raise ValueError(f"input_embeddings length: {len(input_embeddings)}") 112 | -------------------------------------------------------------------------------- /models/sam.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from models import torch_device 7 | from transformers import SamModel, SamProcessor 8 | import utils 9 | from utils import vis 10 | import cv2 11 | from scipy import ndimage 12 | 13 | 14 | def load_sam(): 15 | sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(torch_device) 16 | sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") 17 | 18 | sam_model_dict = dict(sam_model=sam_model, sam_processor=sam_processor) 19 | 20 | return sam_model_dict 21 | 22 | 23 | # Not fully backward compatible with the previous implementation 24 | # Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb 25 | def sam( 26 | sam_model_dict, 27 | image, 28 | input_points=None, 29 | input_boxes=None, 30 | target_mask_shape=None, 31 | return_numpy=True, 32 | ): 33 | """target_mask_shape: (h, w)""" 34 | sam_model, sam_processor = ( 35 | sam_model_dict["sam_model"], 36 | sam_model_dict["sam_processor"], 37 | ) 38 | 39 | if not isinstance(input_boxes, torch.Tensor): 40 | if input_boxes and isinstance(input_boxes[0], tuple): 41 | # Convert tuple to list 42 | input_boxes = [list(input_box) for input_box in input_boxes] 43 | 44 | if input_boxes and input_boxes[0] and isinstance(input_boxes[0][0], tuple): 45 | # Convert tuple to list 46 | input_boxes = [ 47 | [list(input_box) for input_box in input_boxes_item] 48 | for input_boxes_item in input_boxes 49 | ] 50 | 51 | with torch.no_grad(): 52 | with torch.autocast(torch_device): 53 | inputs = sam_processor( 54 | image, 55 | input_points=input_points, 56 | input_boxes=input_boxes, 57 | return_tensors="pt", 58 | ).to(torch_device) 59 | outputs = sam_model(**inputs) 60 | masks = sam_processor.image_processor.post_process_masks( 61 | outputs.pred_masks.cpu().float(), 62 | inputs["original_sizes"].cpu(), 63 | inputs["reshaped_input_sizes"].cpu(), 64 | ) 65 | conf_scores = outputs.iou_scores.cpu().numpy()[0, 0] 66 | del inputs, outputs 67 | 68 | # Uncomment if experiencing out-of-memory error: 69 | utils.free_memory() 70 | if return_numpy: 71 | masks = [ 72 | F.interpolate( 73 | masks_item.type(torch.float), target_mask_shape, mode="bilinear" 74 | ) 75 | .type(torch.bool) 76 | .numpy() 77 | for masks_item in masks 78 | ] 79 | else: 80 | masks = [ 81 | F.interpolate( 82 | masks_item.type(torch.float), target_mask_shape, mode="bilinear" 83 | ).type(torch.bool) 84 | for 
masks_item in masks 85 | ] 86 | 87 | return masks, conf_scores 88 | 89 | 90 | def sam_point_input(sam_model_dict, image, input_points, **kwargs): 91 | return sam(sam_model_dict, image, input_points=input_points, **kwargs) 92 | 93 | 94 | def sam_box_input(sam_model_dict, image, input_boxes, **kwargs): 95 | return sam(sam_model_dict, image, input_boxes=input_boxes, **kwargs) 96 | 97 | 98 | def get_iou_with_resize(mask, masks, masks_shape): 99 | masks = np.array( 100 | [ 101 | cv2.resize( 102 | mask.astype(np.uint8) * 255, masks_shape[::-1], cv2.INTER_LINEAR 103 | ).astype(bool) 104 | for mask in masks 105 | ] 106 | ) 107 | return utils.iou(mask, masks) 108 | 109 | 110 | def select_mask( 111 | masks, 112 | conf_scores, 113 | coarse_ious=None, 114 | rule="largest_over_conf", 115 | discourage_mask_below_confidence=0.85, 116 | discourage_mask_below_coarse_iou=0.2, 117 | verbose=False, 118 | ): 119 | """masks: numpy bool array""" 120 | mask_sizes = masks.sum(axis=(1, 2)) 121 | 122 | # Another possible rule: iou with the attention mask 123 | if rule == "largest_over_conf": 124 | # Use the largest segmentation 125 | # Discourage selecting masks with conf too low or coarse iou is too low 126 | max_mask_size = np.max(mask_sizes) 127 | if coarse_ious is not None: 128 | scores = ( 129 | mask_sizes 130 | - (conf_scores < discourage_mask_below_confidence) * max_mask_size 131 | - (coarse_ious < discourage_mask_below_coarse_iou) * max_mask_size 132 | ) 133 | else: 134 | scores = ( 135 | mask_sizes 136 | - (conf_scores < discourage_mask_below_confidence) * max_mask_size 137 | ) 138 | if verbose: 139 | print(f"mask_sizes: {mask_sizes}, scores: {scores}") 140 | else: 141 | raise ValueError(f"Unknown rule: {rule}") 142 | 143 | mask_id = np.argmax(scores) 144 | mask = masks[mask_id] 145 | 146 | selection_conf = conf_scores[mask_id] 147 | 148 | if coarse_ious is not None: 149 | selection_coarse_iou = coarse_ious[mask_id] 150 | else: 151 | selection_coarse_iou = None 152 | 153 | if verbose: 154 | # print(f"Confidences: {conf_scores}") 155 | print( 156 | f"Selected a mask with confidence: {selection_conf}, coarse_iou: {selection_coarse_iou}" 157 | ) 158 | 159 | if verbose >= 2: 160 | plt.figure(figsize=(10, 8)) 161 | # plt.suptitle("After SAM") 162 | for ind in range(3): 163 | plt.subplot(1, 3, ind + 1) 164 | # This is obtained before resize. 
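                # (Annotation) These subplots show SAM's three mask proposals together with the selection
                # scores computed above. The title below formats `coarse_ious[ind]` with ':.2f', which
                # assumes `coarse_ious` is not None; both callers in this file (sam_refine_attn and
                # sam_refine_boxes) do pass it.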
165 | plt.title( 166 | f"Mask {ind}, score {scores[ind]}, conf {conf_scores[ind]:.2f}, iou {coarse_ious[ind] if coarse_ious is not None else None:.2f}" 167 | ) 168 | plt.imshow(masks[ind]) 169 | plt.tight_layout() 170 | plt.show() 171 | plt.close() 172 | 173 | return mask, selection_conf 174 | 175 | 176 | def preprocess_mask(token_attn_np_smooth, mask_th, n_erode_dilate_mask=0): 177 | token_attn_np_smooth_normalized = token_attn_np_smooth - token_attn_np_smooth.min() 178 | token_attn_np_smooth_normalized /= token_attn_np_smooth_normalized.max() 179 | mask_thresholded = token_attn_np_smooth_normalized > mask_th 180 | 181 | if n_erode_dilate_mask: 182 | mask_thresholded = ndimage.binary_erosion( 183 | mask_thresholded, iterations=n_erode_dilate_mask 184 | ) 185 | mask_thresholded = ndimage.binary_dilation( 186 | mask_thresholded, iterations=n_erode_dilate_mask 187 | ) 188 | 189 | return mask_thresholded 190 | 191 | 192 | # The overall pipeline to refine the attention mask 193 | def sam_refine_attn( 194 | sam_input_image, 195 | token_attn_np, 196 | model_dict, 197 | height, 198 | width, 199 | H, 200 | W, 201 | use_box_input, 202 | gaussian_sigma, 203 | mask_th_for_box, 204 | n_erode_dilate_mask_for_box, 205 | mask_th_for_point, 206 | discourage_mask_below_confidence, 207 | discourage_mask_below_coarse_iou, 208 | verbose, 209 | ): 210 | # token_attn_np is for visualizations 211 | token_attn_np_smooth = ndimage.gaussian_filter( 212 | token_attn_np.astype(float), sigma=gaussian_sigma 213 | ) 214 | 215 | if verbose >= 2: 216 | # Visualize one token only 217 | vis.visualize_arrays( 218 | [ 219 | (token_attn_np, f"token_attn_np"), 220 | (token_attn_np_smooth, f"token_attn_np_smooth"), 221 | ], 222 | colorbar_index=1, 223 | ) 224 | 225 | # (w, h) 226 | mask_size_scale = ( 227 | height // token_attn_np_smooth.shape[1], 228 | width // token_attn_np_smooth.shape[0], 229 | ) 230 | 231 | if use_box_input: 232 | # box input 233 | mask_binary = preprocess_mask( 234 | token_attn_np_smooth, 235 | mask_th_for_box, 236 | n_erode_dilate_mask=n_erode_dilate_mask_for_box, 237 | ) 238 | 239 | input_boxes = utils.binary_mask_to_box( 240 | mask_binary, w_scale=mask_size_scale[0], h_scale=mask_size_scale[1] 241 | ) 242 | input_boxes = [input_boxes] 243 | 244 | masks, conf_scores = sam_box_input( 245 | model_dict, 246 | image=sam_input_image, 247 | input_boxes=input_boxes, 248 | target_mask_shape=(H, W), 249 | ) 250 | else: 251 | # point input 252 | mask_binary = preprocess_mask( 253 | token_attn_np_smooth, mask_th_for_point, n_erode_dilate_mask=0 254 | ) 255 | 256 | # Uses the max coordinate only 257 | max_coord = np.unravel_index( 258 | token_attn_np_smooth.argmax(), token_attn_np_smooth.shape 259 | ) 260 | # print("max_coord:", max_coord) 261 | input_points = [ 262 | [[max_coord[1] * mask_size_scale[1], max_coord[0] * mask_size_scale[0]]] 263 | ] 264 | 265 | masks, conf_scores = sam_point_input( 266 | model_dict, 267 | image=sam_input_image, 268 | input_points=input_points, 269 | target_mask_shape=(H, W), 270 | ) 271 | 272 | if verbose >= 2: 273 | plt.title("Coarse binary mask (for getting the box with box input and for iou)") 274 | plt.imshow(mask_binary) 275 | plt.show() 276 | 277 | # Assuming one image, one three-masks per image (so we have indexing twice) 278 | three_masks = masks[0][0] 279 | 280 | coarse_ious = get_iou_with_resize( 281 | mask_binary, three_masks, masks_shape=mask_binary.shape 282 | ) 283 | 284 | mask_selected, conf_score_selected = select_mask( 285 | three_masks, 286 | conf_scores, 287 | 
coarse_ious=coarse_ious, 288 | rule="largest_over_conf", 289 | discourage_mask_below_confidence=discourage_mask_below_confidence, 290 | discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou, 291 | verbose=True, 292 | ) 293 | 294 | return mask_selected, conf_score_selected 295 | 296 | 297 | def sam_refine_box(sam_input_image, box, *args, **kwargs): 298 | # One image with one box 299 | 300 | sam_input_images, boxes = [sam_input_image], [[box]] 301 | mask_selected_batched_list, conf_score_selected_batched_list = sam_refine_boxes( 302 | sam_input_images, boxes, *args, **kwargs 303 | ) 304 | 305 | return mask_selected_batched_list[0][0], conf_score_selected_batched_list[0][0] 306 | 307 | 308 | def sam_refine_boxes( 309 | sam_input_images, 310 | boxes, 311 | model_dict, 312 | height, 313 | width, 314 | H, 315 | W, 316 | discourage_mask_below_confidence, 317 | discourage_mask_below_coarse_iou, 318 | verbose, 319 | ): 320 | # (w, h) 321 | input_boxes = [ 322 | [utils.scale_proportion(box, H=height, W=width) for box in boxes_item] 323 | for boxes_item in boxes 324 | ] 325 | 326 | masks, conf_scores = sam_box_input( 327 | model_dict, 328 | image=sam_input_images, 329 | input_boxes=input_boxes, 330 | target_mask_shape=(H, W), 331 | ) 332 | 333 | mask_selected_batched_list, conf_score_selected_batched_list = [], [] 334 | 335 | for boxes_item, masks_item in zip(boxes, masks): 336 | mask_selected_list, conf_score_selected_list = [], [] 337 | for box, three_masks in zip(boxes_item, masks_item): 338 | mask_binary = utils.proportion_to_mask(box, H, W, return_np=True) 339 | if verbose >= 2: 340 | # Also the box is the input for SAM 341 | plt.title("Binary mask from input box (for iou)") 342 | plt.imshow(mask_binary) 343 | plt.show() 344 | 345 | coarse_ious = get_iou_with_resize( 346 | mask_binary, three_masks, masks_shape=mask_binary.shape 347 | ) 348 | 349 | mask_selected, conf_score_selected = select_mask( 350 | three_masks, 351 | conf_scores, 352 | coarse_ious=coarse_ious, 353 | rule="largest_over_conf", 354 | discourage_mask_below_confidence=discourage_mask_below_confidence, 355 | discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou, 356 | verbose=False, 357 | ) 358 | 359 | mask_selected_list.append(mask_selected) 360 | conf_score_selected_list.append(conf_score_selected) 361 | mask_selected_batched_list.append(mask_selected_list) 362 | conf_score_selected_batched_list.append(conf_score_selected_list) 363 | 364 | return mask_selected_batched_list, conf_score_selected_batched_list 365 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | 3 | # Basic packages 4 | tqdm==4.64.1 5 | easydict==1.11 6 | inflect==7.0.0 7 | matplotlib==3.8.2 8 | scikit-image==0.21.0 9 | opencv-python==4.8.0.74 10 | 11 | # Pytorch 12 | torch==2.0.1 13 | torchvision==0.15.2 14 | accelerate==0.26.1 15 | 16 | # diffusers + transformers 17 | transformers==4.36.2 18 | diffusers==0.20.2 19 | openai==1.12.0 20 | -------------------------------------------------------------------------------- /sld/llm_chat.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import time 3 | from openai import OpenAI 4 | 5 | 6 | def get_key_objects(message, config): 7 | """ 8 | Retrieves key objects and additional negative prompt from a given message using a specified model. 
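    Note: the parsing logic below relies on the "Objects:", "Background:", and "Negation:" markers produced by
    `spot_object_template` in sld/llm_template.py. An illustrative response fragment that parses correctly
    (content borrowed from Example 1 of that template):

        Objects: [('horse', ['brown']), ('dog', ['black']), ('cat', ['orange'])]
        Background: A realistic image
        Negation: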
9 | 10 | Parameters: 11 | message (str): The message to process. 12 | config (configparser.ConfigParser): Configuration providing the OpenAI organization, API key, and model name. 13 | 14 | Returns: 15 | tuple: A tuple containing the parsed result (key objects, background prompt, and negative prompt) and the complete raw response. 16 | """ 17 | # Reading configuration from file 18 | 19 | organization = config.get("openai", "organization") 20 | api_key = config.get("openai", "api_key") 21 | model = config.get("openai", "model") 22 | 23 | # Alternatively, reading configuration from environment variables 24 | # organization = os.environ.get('OPENAI_ORGANIZATION') 25 | # api_key = os.environ.get('OPENAI_API_KEY') 26 | 27 | messages = [{"role": "user", "content": message}] 28 | 29 | while True: 30 | try: 31 | client = OpenAI(organization=organization, api_key=api_key) 32 | response = client.chat.completions.create(model=model, messages=messages) 33 | raw_response = response.choices[0].message.content 34 | print(f"ChatGPT: {raw_response}") 35 | 36 | # Extracting key objects 37 | key_objects_part = raw_response.split("Objects:")[1] 38 | start_index = key_objects_part.index("[") 39 | end_index = key_objects_part.rindex("]") + 1 40 | objects_str = key_objects_part[start_index:end_index] 41 | 42 | # Converting string to list 43 | parsed_objects = ast.literal_eval(objects_str) 44 | 45 | # Extracting additional negative prompt 46 | bg_prompt = raw_response.split("Background:")[1].split("\n")[0].strip() 47 | negative_prompt = raw_response.split("Negation:")[1].strip() 48 | break 49 | except Exception as e: 50 | print(f"Error occurred when calling LLM API: {e}") 51 | time.sleep(5) 52 | 53 | parsed_result = { 54 | "objects": parsed_objects, 55 | "bg_prompt": bg_prompt, 56 | "neg_prompt": negative_prompt, 57 | } 58 | return parsed_result, raw_response 59 | 60 | 61 | def get_updated_layout(message, config): 62 | """ 63 | Retrieves a list of objects with updated bounding box coordinates from a given message using a specified model. 64 | 65 | Parameters: 66 | message (str): The message containing information to process. 67 | config (configparser.ConfigParser): Configuration providing the OpenAI organization, API key, and model name. 68 | 69 | Returns: 70 | tuple: A tuple containing the list of objects with updated bounding boxes and the complete raw response. 
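    Note: the parsing logic below expects the "Updated Objects:" marker used by `spot_difference_template` and
    `image_edit_template` in sld/llm_template.py. An illustrative response fragment (boxes borrowed from
    Example 1 of those templates):

        Reasoning: ...
        Updated Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('bird #1', [0.385, 0.054, 0.186, 0.130])]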
71 | """ 72 | # Reading configuration from file 73 | organization = config.get("openai", "organization") 74 | api_key = config.get("openai", "api_key") 75 | model = config.get("openai", "model") 76 | 77 | messages = [{"role": "user", "content": message}] 78 | 79 | while True: 80 | try: 81 | client = OpenAI(organization=organization, api_key=api_key) 82 | response = client.chat.completions.create(model=model, messages=messages) 83 | 84 | raw_response = response.choices[0].message.content 85 | print(f"ChatGPT: {raw_response}") 86 | 87 | # Extracting bounding box data 88 | bbox_data = raw_response.split("Updated Objects:")[1] 89 | start_index = bbox_data.index("[") 90 | end_index = bbox_data.rindex("]") + 1 91 | bbox_str = bbox_data[start_index:end_index] 92 | 93 | # Converting string to list 94 | updated_bboxes = ast.literal_eval(bbox_str) 95 | 96 | break 97 | except Exception as e: 98 | print(f"Error occured when calling LLM API: {e}") 99 | time.sleep(5) 100 | 101 | return updated_bboxes, raw_response 102 | -------------------------------------------------------------------------------- /sld/llm_template.py: -------------------------------------------------------------------------------- 1 | # Template for self correction tasks --> parse the prompt 2 | spot_object_template = """# Your Role: Excellent Parser 3 | 4 | ## Objective: Analyze scene descriptions to identify objects and their attributes. 5 | 6 | ## Process Steps 7 | 1. Read the user prompt (scene description). 8 | 2. Identify all objects mentioned with quantities. 9 | 3. Extract attributes of each object (color, size, material, etc.). 10 | 4. If the description mentions objects that shouldn't be in the image, take note at the negation part. 11 | 5. Explain your understanding (reasoning) and then format your result (answer / negation) as shown in the examples. 12 | 6. Importance of Extracting Attributes: Attributes provide specific details about the objects. This helps differentiate between similar objects and gives a clearer understanding of the scene. 13 | 14 | ## Examples 15 | 16 | - Example 1 17 | User prompt: A brown horse is beneath a black dog. Another orange cat is beneath a brown horse. 18 | Reasoning: The description talks about three objects: a brown horse, a black dog, and an orange cat. We report the color attribute thoroughly. No specified negation terms. No background is mentioned and thus fill in the default one. 19 | Objects: [('horse', ['brown']), ('dog', ['black']), ('cat', ['orange'])] 20 | Background: A realistic image 21 | Negation: 22 | 23 | - Example 2 24 | User prompt: There's a white car and a yellow airplane in a garage. They're in front of two dogs and behind a cat. The car is small. Another yellow car is outside the garage. 25 | Reasoning: The scene has two cars, one airplane, two dogs, and a cat. The car and airplane have colors. The first car also has a size. No specified negation terms. The background is a garage. 26 | Objects: [('car', ['white and small', 'yellow']), ('airplane', ['yellow']), ('dog', [None, None]), ('cat', [None])] 27 | Background: A realistic image in a garage 28 | Negation: 29 | 30 | - Example 3 31 | User prompt: A car and a dog are on top of an airplane and below a red chair. There's another dog sitting on the mentioned chair. 32 | Reasoning: Four objects are described: one car, airplane, two dog, and a chair. The chair is red color. No specified negation terms. No background is mentioned and thus fill in the default one. 
33 | Objects: [('car', [None]), ('airplane', [None]), ('dog', [None, None]), ('chair', ['red'])] 34 | Background: A realistic image 35 | Negation: 36 | 37 | - Example 4 38 | User prompt: An oil painting at the beach of a blue bicycle to the left of a bench and to the right of a palm tree with five seagulls in the sky. 39 | Reasoning: Here, there are five seagulls, one blue bicycle, one palm tree, and one bench. No specified negation terms. The background is an oil painting at the beach. 40 | Objects: [('bicycle', ['blue']), ('palm tree', [None]), ('seagull', [None, None, None, None, None]), ('bench', [None])] 41 | Background: An oil painting at the beach 42 | Negation: 43 | 44 | - Example 5 45 | User prompt: An animated-style image of a scene without backpacks. 46 | Reasoning: The description clearly states no backpacks, so this must be acknowledged. The user provides the negative prompt of backpacks. The background is an animated-style image. 47 | Objects: [('backpacks', [None])] 48 | Background: An animated-style image 49 | Negation: backpacks 50 | 51 | - Example 6 52 | User Prompt: Make the dog a sleeping dog and remove all shadows in an image of a grassland. 53 | Reasoning: The user prompt specifies a sleeping dog on the image and a shadow to be removed. The background is a realistic image of a grassland. 54 | Objects: [('dog', ['sleeping']), ['shadow', [None]]] 55 | Background: A realistic image of a grassland 56 | Negation: shadows 57 | 58 | Your Current Task: Follow the steps closely and accurately identify objects based on the given prompt. Ensure adherence to the above output format. 59 | 60 | """ 61 | 62 | # Template for self correction tasks --> adjust the bounding boxes 63 | spot_difference_template = """# Your Role: Expert Bounding Box Adjuster 64 | 65 | ## Objective: Manipulate bounding boxes in square images according to the user prompt while maintaining visual accuracy. 66 | 67 | ## Bounding Box Specifications and Manipulations 68 | 1. Image Coordinates: Define square images with top-left at [0, 0] and bottom-right at [1, 1]. 69 | 2. Box Format: [Top-left x, Top-left y, Width, Height] 70 | 3. Operations: Include addition, deletion, repositioning, and attribute modification. 71 | 72 | ## Key Guidelines 73 | 1. Alignment: Follow the user's prompt, keeping the specified object count and attributes. Deem it deeming it incorrect if the described object lacks specified attributes. 74 | 2. Boundary Adherence: Keep bounding box coordinates within [0, 1]. 75 | 3. Minimal Modifications: Change bounding boxes only if they don't match the user's prompt (i.e., don't modify matched objects). 76 | 4. Overlap Reduction: Minimize intersections in new boxes and remove the smallest, least overlapping objects. 77 | 78 | ## Process Steps 79 | 1. Interpret prompts: Read and understand the user's prompt. 80 | 2. Implement Changes: Review and adjust current bounding boxes to meet user specifications. 81 | 3. Explain Adjustments: Justify the reasons behind each alteration and ensure every adjustment abides by the key guidelines. 82 | 4. Output the Result: Present the reasoning first, followed by the updated objects section, which should include a list of bounding boxes in Python format. 
83 | 84 | ## Examples 85 | 86 | - Example 1 87 | User prompt: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky 88 | Current Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.368, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176])] 89 | Reasoning: To add a bird in the sky as per the prompt, ensuring all coordinates and dimensions remain within [0, 1]. 90 | Updated Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.369, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176]), ('bird #1', [0.385, 0.054, 0.186, 0.130])] 91 | 92 | - Example 2 93 | User prompt: A realistic image of landscape scene depicting a green car parking on the right of a blue truck, with a red air balloon and a bird in the sky 94 | Current Output Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.369, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176])] 95 | Reasoning: The relative positions of the green car and blue truck do not match the prompt. Swap positions of the green car and blue truck to match the prompt, while keeping all coordinates and dimensions within [0, 1]. 96 | Updated Objects: [('green car #1', [0.350, 0.369, 0.275, 0.207]), ('blue truck #1', [0.027, 0.365, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176]), ('bird #1', [0.485, 0.054, 0.186, 0.130])] 97 | 98 | - Example 3 99 | User prompt: An oil painting of a pink dolphin jumping on the left of a steam boat on the sea 100 | Current Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('pink dolphin #1', [0.027, 0.324, 0.246, 0.160]), ('blue dolphin #1', [0.158, 0.454, 0.376, 0.290])] 101 | Reasoning: The prompt mentions only one dolphin, but two are present. Thus, remove one dolphin to match the prompt, ensuring all coordinates and dimensions stay within [0, 1]. 102 | Updated Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('pink dolphin #1', [0.027, 0.324, 0.246, 0.160])] 103 | 104 | - Example 4 105 | User prompt: An oil painting of a pink dolphin jumping on the left of a steam boat on the sea 106 | Current Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('dolphin #1', [0.027, 0.324, 0.246, 0.160])] 107 | Reasoning: The prompt specifies a pink dolphin, but there's only a generic one. The attribute needs to be changed. 108 | Updated Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('pink dolphin #1', [0.027, 0.324, 0.246, 0.160])] 109 | 110 | - Example 5 111 | User prompt: A realistic photo of a scene with a brown bowl on the right and a gray dog on the left 112 | Current Objects: [('gray dog #1', [0.186, 0.592, 0.449, 0.408]), ('brown bowl #1', [0.376, 0.194, 0.624, 0.502])] 113 | Reasoning: The leftmost coordinate (0.186) of the gray dog's bounding box is positioned to the left of the leftmost coordinate (0.376) of the brown bowl, while the rightmost coordinate (0.186 + 0.449) of the bounding box has not extended beyond the rightmost coordinate of the bowl. Thus, the image aligns with the user's prompt, requiring no further modifications. 114 | Updated Objects: [('gray dog #1', [0.186, 0.592, 0.449, 0.408]), ('brown bowl #1', [0.376, 0.194, 0.624, 0.502])] 115 | 116 | Your Current Task: Carefully follow the provided guidelines and steps to adjust bounding boxes in accordance with the user's prompt. Ensure adherence to the above output format. 
117 | 118 | """ 119 | 120 | 121 | image_edit_template = """# Your Role: Expert Bounding Box Adjuster 122 | 123 | ## Objective: Manipulate bounding boxes in square images according to user instructions while maintaining visual accuracy and avoiding boundary exceedance. 124 | 125 | ## Bounding Box Specifications and Manipulations 126 | 1. Image Coordinates: Define square images with top-left at [0, 0] and bottom-right at [1, 1]. 127 | 2. Box Format: [Top-left x, Top-left y, Width, Height] 128 | 3. Operations: Include addition, deletion, repositioning, and attribute modification. 129 | 130 | ## Key Guidelines 131 | 1. Alignment: Follow the user's prompt, keeping the specified object count and attributes. Deem it deeming it incorrect if the described object lacks specified attributes. 132 | 2. Boundary Adherence: Keep bounding box coordinates within [0, 1]. 133 | 3. Minimal Modifications: Change bounding boxes only if they don't match the user's prompt (i.e., don't modify matched objects). 134 | 4. Overlap Reduction: Minimize intersections in new boxes and remove the smallest, least overlapping objects. 135 | 136 | ## Process Steps 137 | 1. Interpret prompts: Read and understand the user's prompt. 138 | 2. Implement Changes: Review and adjust current bounding boxes to meet user specifications. 139 | 3. Explain Adjustments: Justify the reasons behind each alteration and ensure every adjustment abides by the key guidelines. 140 | 4. Output the Result: Present the reasoning first, followed by the updated prompts and objects section, which should include a list of bounding boxes in Python format. 141 | 142 | ## Examples: 143 | 144 | - Example 1 145 | User prompt: Move the green car to the right and make the blue truck larger in the image. 146 | Current Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.368, 0.272, 0.208])] 147 | Reasoning: To move the green car rightward, its x-coordinate needs to be increased from 0.027. The dimensions (height and width) of the blue truck must be enlarged. While adjusting bounding boxes, ensure they do not overlap excessively. All other elements remain unchanged. 148 | Updated Objects: [('green car #1', [0.327, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.369, 0.472, 0.408])] 149 | 150 | - Example 2 151 | User prompt: Swap the positions of a green car and a blue truck in this landscape scene with an air balloon. 152 | Current Output Objects: [('green car #1', [0.350, 0.369, 0.275, 0.207]), ('blue truck #1', [0.027, 0.365, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176])] 153 | Reasoning: Exchange locations of the car and truck to align the bottom right part; other objects remain unchanged. 154 | Updated Objects: [('green car #1', [0.027, 0.365, 0.275, 0.207]), ('blue truck #1', [0.350, 0.364, 0.272, 0.208]), ('red air balloon #1', [0.086, 0.010, 0.189, 0.176])] 155 | 156 | - Example 3 157 | User prompt: Change the color of the dolphin from blue to pink in this oil painting of a dolphin and a steamboat. 158 | Current Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('blue dolphin #1', [0.027, 0.324, 0.246, 0.160])] 159 | Reasoning: Alter only the dolphin's color from blue to pink, without modifying other elements. 160 | Updated Objects: [('steam boat #1', [0.302, 0.293, 0.335, 0.194]), ('pink dolphin #1', [0.027, 0.324, 0.246, 0.160])] 161 | 162 | - Example 4 163 | User prompt: Remove the leftmost bowl in this photo with two bowls and a dog. 
164 | Current Objects: [('dog #1', [0.186, 0.592, 0.449, 0.408]), ('bowl #1', [0.376, 0.194, 0.324, 0.324]), ('bowl #2', [0.676, 0.494, 0.324, 0.324])] 165 | Reasoning: There are two bowls in the image and bowl #1 is identified as the leftmost one because its x coordinates (0.376) is smaller than that of bowl #2 (0.676).Thus, eliminate bowl #1 without modifying any remaining instances. 166 | Updated Objects: [('dog #1', [0.186, 0.592, 0.449, 0.408]), ('bowl #2', [0.676, 0.494, 0.324, 0.324])] 167 | 168 | - Example 5 169 | User prompt: Add a pink bowl between two existing bowls in this photo. 170 | Current Objects: [('bowl #1', [0.076, 0.494, 0.324, 0.324]), ('bowl #2', [0.676, 0.494, 0.324, 0.324])] 171 | Reasoning: There are two bowls in the image. To add a pink bowl between the two, the x coordinates should be placed between 0.076 and 0.676 and the y coordinates should be between 0.494 and 0.494. When adding the object, be sure to prevent overlapping between existing objects and make sure the [top-left x-coordinate, top-left y-coordinate, top-left x-coordinate+box width, top-left y-coordinate+box height] lie between 0 and 1. 172 | Updated Objects: [('bowl #1', [0.076, 0.494, 0.324, 0.324]), ('bowl #2', [0.676, 0.494, 0.324, 0.324]), ('bowl #3', [0.376, 0.494, 0.324, 0.324])] 173 | 174 | Your Current Task: Carefully follow the provided guidelines and steps to adjust bounding boxes in accordance with the user's prompt. Ensure adherence to the above output format. 175 | 176 | """ -------------------------------------------------------------------------------- /sld/sdxl_refine.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from diffusers import StableDiffusionXLImg2ImgPipeline 4 | 5 | 6 | def sdxl_refine(prompt, input_fname, output_fname): 7 | torch.set_float32_matmul_precision("high") 8 | pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( 9 | "stabilityai/stable-diffusion-xl-refiner-1.0", 10 | torch_dtype=torch.float32, 11 | ) 12 | 13 | pipe = pipe.to("cuda") 14 | 15 | init_image = Image.open(input_fname) 16 | init_image = init_image.resize((1024, 1024), Image.LANCZOS) 17 | image = pipe( 18 | prompt, 19 | image=init_image, 20 | strength=0.3, 21 | aesthetic_score=7.0, 22 | num_inference_steps=50, 23 | ).images 24 | image[0].save(output_fname) 25 | -------------------------------------------------------------------------------- /sld/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from models import sam 6 | from models import pipelines 7 | 8 | 9 | DEFAULT_SO_NEGATIVE_PROMPT = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate, two, many, group, occlusion, occluded, side, border, collate" 10 | DEFAULT_OVERALL_NEGATIVE_PROMPT = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate" 11 | 12 | 13 | 14 | def get_all_latents(img_np, models, inv_seed=1): 15 | generator = torch.cuda.manual_seed(inv_seed) 16 | cln_latents = pipelines.encode(models.model_dict, img_np, generator) 17 | # Magic prompt 18 | # Have tried using the parsed bg prompt from the LLM, but it doesn't work well 19 | prompt = "A realistic photo of a scene" 20 | input_embeddings = models.encode_prompts( 21 | prompts=[prompt], 22 | tokenizer=models.model_dict.tokenizer, 23 | 
text_encoder=models.model_dict.text_encoder, 24 | negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT, 25 | one_uncond_input_only=False, 26 | ) 27 | # Get all hidden latents 28 | all_latents = pipelines.invert( 29 | models.model_dict, 30 | cln_latents, 31 | input_embeddings, 32 | num_inference_steps=50, 33 | guidance_scale=2.5, 34 | ) 35 | return all_latents, input_embeddings 36 | 37 | 38 | def run_sam(bbox, image_source, models): 39 | H, W, _ = image_source.shape 40 | box_xyxy = torch.Tensor( 41 | [ 42 | bbox[0], 43 | bbox[1], 44 | bbox[2] + bbox[0], 45 | bbox[3] + bbox[1], 46 | ] 47 | ) * torch.Tensor([W, H, W, H]) 48 | box_xyxy = box_xyxy.unsqueeze(0).unsqueeze(0) 49 | masks, _ = sam.sam( 50 | models.model_dict, 51 | image_source, 52 | input_boxes=box_xyxy, 53 | target_mask_shape=(H, W), 54 | ) 55 | masks = masks[0][0].transpose(1, 2, 0).astype(bool) 56 | return masks 57 | 58 | 59 | def run_sam_postprocess(remove_mask, H, W, config): 60 | remove_mask = np.mean(remove_mask, axis=2) 61 | remove_mask[remove_mask > 0.05] = 1.0 62 | k_size = int(config.get("SLD", "SAM_refine_dilate")) 63 | kernel = np.ones((k_size, k_size), np.uint8) 64 | dilated_mask = cv2.dilate( 65 | (remove_mask * 255).astype(np.uint8), kernel, iterations=1 66 | ) 67 | # Resize the mask from the image size to the latent size 68 | remove_region = cv2.resize( 69 | dilated_mask.astype(np.int64), 70 | dsize=(W // 8, H // 8), 71 | interpolation=cv2.INTER_NEAREST, 72 | ) 73 | return remove_region 74 | 75 | 76 | def calculate_scale_ratio(region_a_param, region_b_param): 77 | _, _, a_width, a_height = region_a_param 78 | _, _, b_width, b_height = region_b_param 79 | scale_ratio_width = b_width / a_width 80 | scale_ratio_height = b_height / a_height 81 | return min(scale_ratio_width, scale_ratio_height) 82 | 83 | 84 | def resize_image(image, region_a_param, region_b_param): 85 | """ 86 | Resizes the image based on the scaling ratio between two regions and performs cropping or padding. 87 | """ 88 | old_h, old_w, _ = image.shape 89 | scale_ratio = calculate_scale_ratio(region_a_param, region_b_param) 90 | 91 | new_size = (int(old_w * scale_ratio), int(old_h * scale_ratio)) 92 | 93 | resized_image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA) 94 | new_h, new_w, _ = resized_image.shape 95 | region_a_param_real = [ 96 | int(region_a_param[0] * new_h), 97 | int(region_a_param[1] * new_w), 98 | int(region_a_param[2] * new_h), 99 | int(region_a_param[3] * new_w), 100 | ] 101 | if scale_ratio >= 1: # Cropping 102 | new_xmin = min(region_a_param_real[0], int(new_h - old_h)) 103 | new_ymin = min(region_a_param_real[1], int(new_w - old_w)) 104 | 105 | new_img = resized_image[ 106 | new_ymin : new_ymin + old_w, new_xmin : new_xmin + old_h 107 | ] 108 | 109 | new_param = [ 110 | (region_a_param_real[0] - new_xmin) / old_h, 111 | (region_a_param_real[1] - new_ymin) / old_w, 112 | region_a_param[2] * scale_ratio, 113 | region_a_param[3] * scale_ratio, 114 | ] 115 | else: # Padding 116 | new_img = np.ones((old_h, old_w, 3), dtype=np.uint8) * 255 117 | new_img[:new_h, :new_w] = resized_image 118 | new_param = [region_a_param[i] * scale_ratio for i in range(4)] 119 | 120 | return new_img, new_param 121 | 122 | 123 | def nms( 124 | bounding_boxes, 125 | confidence_score, 126 | labels, 127 | threshold, 128 | input_in_pixels=False, 129 | return_array=True, 130 | ): 131 | """ 132 | This NMS processes boxes of all labels. It not only removes the box with the same label. 
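    (Clarification) Suppression here is class-agnostic: a kept high-confidence box can suppress overlapping
    boxes even when their labels differ.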
133 | 134 | Adapted from https://github.com/amusi/Non-Maximum-Suppression/blob/master/nms.py 135 | """ 136 | # If no bounding boxes, return empty list 137 | if len(bounding_boxes) == 0: 138 | return np.array([]), np.array([]), np.array([]) 139 | 140 | # Bounding boxes 141 | boxes = np.array(bounding_boxes) 142 | 143 | # coordinates of bounding boxes 144 | start_x = boxes[:, 0] 145 | start_y = boxes[:, 1] 146 | end_x = boxes[:, 2] 147 | end_y = boxes[:, 3] 148 | 149 | # Confidence scores of bounding boxes 150 | score = np.array(confidence_score) 151 | 152 | # Picked bounding boxes 153 | picked_boxes = [] 154 | picked_score = [] 155 | picked_labels = [] 156 | 157 | # Compute areas of bounding boxes 158 | if input_in_pixels: 159 | areas = (end_x - start_x + 1) * (end_y - start_y + 1) 160 | else: 161 | areas = (end_x - start_x) * (end_y - start_y) 162 | 163 | # Sort by confidence score of bounding boxes 164 | order = np.argsort(score) 165 | 166 | # Iterate bounding boxes 167 | while order.size > 0: 168 | # The index of largest confidence score 169 | index = order[-1] 170 | 171 | # Pick the bounding box with largest confidence score 172 | picked_boxes.append(bounding_boxes[index]) 173 | picked_score.append(confidence_score[index]) 174 | picked_labels.append(labels[index]) 175 | 176 | # Compute ordinates of intersection-over-union(IOU) 177 | x1 = np.maximum(start_x[index], start_x[order[:-1]]) 178 | x2 = np.minimum(end_x[index], end_x[order[:-1]]) 179 | y1 = np.maximum(start_y[index], start_y[order[:-1]]) 180 | y2 = np.minimum(end_y[index], end_y[order[:-1]]) 181 | 182 | # Compute areas of intersection-over-union 183 | if input_in_pixels: 184 | w = np.maximum(0.0, x2 - x1 + 1) 185 | h = np.maximum(0.0, y2 - y1 + 1) 186 | else: 187 | w = np.maximum(0.0, x2 - x1) 188 | h = np.maximum(0.0, y2 - y1) 189 | intersection = w * h 190 | 191 | # Compute the ratio between intersection and union 192 | ratio = intersection / (areas[index] + areas[order[:-1]] - intersection) 193 | 194 | left = np.where(ratio < threshold) 195 | order = order[left] 196 | 197 | if return_array: 198 | picked_boxes, picked_score, picked_labels = ( 199 | np.array(picked_boxes), 200 | np.array(picked_score), 201 | np.array(picked_labels), 202 | ) 203 | 204 | return picked_boxes, picked_score, picked_labels 205 | 206 | 207 | def post_process(box): 208 | new_box = [] 209 | for item in box: 210 | item = min(1.0, max(0.0, item)) 211 | new_box.append(round(item, 3)) 212 | return new_box -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /utils/attn.py: -------------------------------------------------------------------------------- 1 | # visualization-related functions are in vis 2 | import numbers 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | import utils 8 | 9 | def get_token_attnv2(token_id, saved_attns, attn_key, attn_aggregation_step_start=10, input_ca_has_condition_only=False, return_np=False): 10 | """ 11 | saved_attns: a list of saved_attn (list is across timesteps) 12 | 13 | moves to cpu by default 14 | """ 15 | saved_attns = saved_attns[attn_aggregation_step_start:] 16 | 17 | saved_attns = [saved_attn[attn_key].cpu() for saved_attn in saved_attns] 18 | 19 | attn = torch.stack(saved_attns, dim=0).mean(dim=0) 20 | 21 | # print("attn shape", 
attn.shape) 22 | 23 | # attn: (batch, head, spatial, text) 24 | 25 | if not input_ca_has_condition_only: 26 | assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items" 27 | attn = attn[1] 28 | else: 29 | assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items" 30 | attn = attn[0] 31 | attn = attn.mean(dim=0)[:, token_id] 32 | H = W = int(math.sqrt(attn.shape[0])) 33 | attn = attn.reshape((H, W)) 34 | 35 | if return_np: 36 | return attn.numpy() 37 | 38 | return attn 39 | 40 | def shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, horizontal_shift_only=False): 41 | """ 42 | `horizontal_shift_only`: only shift horizontally. If you use `offset` from `compose_latents_with_alignment` with `horizontal_shift_only=True`, the `offset` already has y_offset = 0 and this option is not needed. 43 | """ 44 | x_offset, y_offset = offset 45 | if horizontal_shift_only: 46 | y_offset = 0. 47 | 48 | new_saved_attns_item = {} 49 | for k in guidance_attn_keys: 50 | attn_map = saved_attns_item[k] 51 | 52 | attn_size = attn_map.shape[-2] 53 | attn_h = attn_w = int(math.sqrt(attn_size)) 54 | # Example dimensions: [batch_size, num_heads, 8, 8, num_tokens] 55 | attn_map = attn_map.unflatten(2, (attn_h, attn_w)) 56 | attn_map = utils.shift_tensor( 57 | attn_map, x_offset, y_offset, 58 | offset_normalized=True, ignore_last_dim=True 59 | ) 60 | attn_map = attn_map.flatten(2, 3) 61 | 62 | new_saved_attns_item[k] = attn_map 63 | 64 | return new_saved_attns_item 65 | 66 | def shift_saved_attns(saved_attns, offset, guidance_attn_keys, **kwargs): 67 | # Iterate over timesteps 68 | shifted_saved_attns = [shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, **kwargs) for saved_attns_item in saved_attns] 69 | 70 | return shifted_saved_attns 71 | 72 | 73 | class GaussianSmoothing(nn.Module): 74 | """ 75 | Apply gaussian smoothing on a 76 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 77 | in the input using a depthwise convolution. 78 | Arguments: 79 | channels (int, sequence): Number of channels of the input tensors. Output will 80 | have this number of channels as well. 81 | kernel_size (int, sequence): Size of the gaussian kernel. 82 | sigma (float, sequence): Standard deviation of the gaussian kernel. 83 | dim (int, optional): The number of dimensions of the data. 84 | Default value is 2 (spatial). 85 | 86 | Credit: https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 87 | """ 88 | 89 | def __init__(self, channels, kernel_size, sigma, dim=2): 90 | super(GaussianSmoothing, self).__init__() 91 | if isinstance(kernel_size, numbers.Number): 92 | kernel_size = [kernel_size] * dim 93 | if isinstance(sigma, numbers.Number): 94 | sigma = [sigma] * dim 95 | 96 | # The gaussian kernel is the product of the 97 | # gaussian function of each dimension. 98 | kernel = 1 99 | meshgrids = torch.meshgrid( 100 | [ 101 | torch.arange(size, dtype=torch.float32) 102 | for size in kernel_size 103 | ] 104 | ) 105 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 106 | mean = (size - 1) / 2 107 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 108 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 109 | 110 | # Make sure sum of values in gaussian kernel equals 1. 
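        # (Annotation) Normalizing the kernel preserves the overall magnitude of the smoothed map. With the
        # values used in utils/boxdiff.py (channels=1, kernel_size=3, sigma=0.5), this yields a single 3x3
        # kernel whose entries sum to 1 after the division below.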
111 | kernel = kernel / torch.sum(kernel) 112 | 113 | # Reshape to depthwise convolutional weight 114 | kernel = kernel.view(1, 1, *kernel.size()) 115 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 116 | 117 | self.register_buffer('weight', kernel) 118 | self.groups = channels 119 | 120 | if dim == 1: 121 | self.conv = F.conv1d 122 | elif dim == 2: 123 | self.conv = F.conv2d 124 | elif dim == 3: 125 | self.conv = F.conv3d 126 | else: 127 | raise RuntimeError( 128 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( 129 | dim) 130 | ) 131 | 132 | def forward(self, input): 133 | """ 134 | Apply gaussian filter to input. 135 | Arguments: 136 | input (torch.Tensor): Input to apply gaussian filter on. 137 | Returns: 138 | filtered (torch.Tensor): Filtered output. 139 | """ 140 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 141 | -------------------------------------------------------------------------------- /utils/boxdiff.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a reimplementation of the BoxDiff baseline for reference and comparison. It is not used in the Web UI and is not enabled by default, since the current attention guidance implementation (in `guidance`), which uses attention maps from multiple levels and attention transfer, seems to be more robust and coherent. 3 | 4 | Credit: https://github.com/showlab/BoxDiff/blob/master/pipeline/sd_pipeline_boxdiff.py 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import math 10 | import warnings 11 | import gc 12 | from collections.abc import Iterable 13 | import utils 14 | from . import guidance 15 | from .attn import GaussianSmoothing 16 | 17 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 18 | 19 | 20 | def _compute_max_attention_per_index(attention_maps: torch.Tensor, 21 | object_positions: List[List[int]], 22 | smooth_attentions: bool = False, 23 | sigma: float = 0.5, 24 | kernel_size: int = 3, 25 | normalize_eot: bool = False, 26 | bboxes: List[List[int]] = None, 27 | P: float = 0.2, 28 | L: int = 1, 29 | ) -> List[torch.Tensor]: 30 | """ Computes the maximum attention value for each of the tokens we wish to alter.
""" 31 | last_idx = -1 32 | assert not normalize_eot, "normalize_eot is unimplemented" 33 | 34 | attention_for_text = attention_maps[:, :, 1:last_idx] 35 | attention_for_text *= 100 36 | attention_for_text = F.softmax(attention_for_text, dim=-1) 37 | 38 | # Extract the maximum values 39 | max_indices_list_fg = [] 40 | max_indices_list_bg = [] 41 | dist_x = [] 42 | dist_y = [] 43 | 44 | for obj_idx, text_positions_per_obj in enumerate(object_positions): 45 | for text_position_per_obj in text_positions_per_obj: 46 | # Shift indices since we removed the first token 47 | image = attention_for_text[:, :, text_position_per_obj - 1] 48 | H, W = image.shape 49 | 50 | obj_mask = torch.zeros_like(image) 51 | corner_mask_x = torch.zeros( 52 | (W,), device=obj_mask.device, dtype=obj_mask.dtype) 53 | corner_mask_y = torch.zeros( 54 | (H,), device=obj_mask.device, dtype=obj_mask.dtype) 55 | 56 | obj_boxes = bboxes[obj_idx] 57 | 58 | # We support two level (one box per phrase) and three level (multiple boxes per phrase) 59 | if not isinstance(obj_boxes[0], Iterable): 60 | obj_boxes = [obj_boxes] 61 | 62 | for obj_box in obj_boxes: 63 | x_min, y_min, x_max, y_max = utils.scale_proportion( 64 | obj_box, H=H, W=W) 65 | obj_mask[y_min: y_max, x_min: x_max] = 1 66 | 67 | corner_mask_x[max(x_min - L, 0): min(x_min + L + 1, W)] = 1. 68 | corner_mask_x[max(x_max - L, 0): min(x_max + L + 1, W)] = 1. 69 | corner_mask_y[max(y_min - L, 0): min(y_min + L + 1, H)] = 1. 70 | corner_mask_y[max(y_max - L, 0): min(y_max + L + 1, H)] = 1. 71 | 72 | bg_mask = 1 - obj_mask 73 | 74 | if smooth_attentions: 75 | smoothing = GaussianSmoothing( 76 | channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() 77 | input = F.pad(image.unsqueeze(0).unsqueeze(0), 78 | (1, 1, 1, 1), mode='reflect') 79 | image = smoothing(input).squeeze(0).squeeze(0) 80 | 81 | # Inner-Box constraint 82 | k = (obj_mask.sum() * P).long() 83 | max_indices_list_fg.append( 84 | (image * obj_mask).reshape(-1).topk(k)[0].mean()) 85 | 86 | # Outer-Box constraint 87 | k = (bg_mask.sum() * P).long() 88 | max_indices_list_bg.append( 89 | (image * bg_mask).reshape(-1).topk(k)[0].mean()) 90 | 91 | # Corner Constraint 92 | gt_proj_x = torch.max(obj_mask, dim=0).values 93 | gt_proj_y = torch.max(obj_mask, dim=1).values 94 | 95 | # create gt according to the number L 96 | dist_x.append((F.l1_loss(image.max(dim=0)[ 97 | 0], gt_proj_x, reduction='none') * corner_mask_x).mean()) 98 | dist_y.append((F.l1_loss(image.max(dim=1)[ 99 | 0], gt_proj_y, reduction='none') * corner_mask_y).mean()) 100 | 101 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y 102 | 103 | 104 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor], 105 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 106 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """ 107 | losses_fg = [max(0, 1. 
- curr_max) 108 | for curr_max in max_attention_per_index_fg] 109 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] 110 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) 111 | 112 | # print(f"{losses_fg}, {losses_bg}, {dist_x}, {dist_y}, {loss}") 113 | 114 | if return_losses: 115 | return max(losses_fg), losses_fg 116 | else: 117 | return max(losses_fg), loss 118 | 119 | 120 | def compute_ca_loss_boxdiff(saved_attn, bboxes, object_positions, guidance_attn_keys, ref_ca_saved_attns=None, ref_ca_last_token_only=True, ref_ca_word_token_only=False, word_token_indices=None, index=None, ref_ca_loss_weight=1.0, verbose=False, **kwargs): 121 | """ 122 | v3 is equivalent to v2 but with new dictionary format for attention maps. 123 | The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss. 124 | `AttnProcessor` will put attention maps into the `save_attn_to_dict`. 125 | 126 | `index` is the timestep. 127 | `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token). 128 | `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens. 129 | """ 130 | loss = torch.tensor(0).float().cuda() 131 | object_number = len(bboxes) 132 | if object_number == 0: 133 | return loss 134 | 135 | attn_map_list = [] 136 | 137 | for attn_key in guidance_attn_keys: 138 | # We only have 1 cross attention for mid. 139 | attn_map_integrated = saved_attn[attn_key] 140 | if not attn_map_integrated.is_cuda: 141 | attn_map_integrated = attn_map_integrated.cuda() 142 | # Example dimension: [20, 64, 77] 143 | attn_map = attn_map_integrated.squeeze(dim=0) 144 | attn_map_list.append(attn_map) 145 | # This averages both across layers and across attention heads 146 | attn_map = torch.cat(attn_map_list, dim=0).mean(dim=0) 147 | loss = add_ca_loss_per_attn_map_to_loss_boxdiff( 148 | loss, attn_map, object_number, bboxes, object_positions, verbose=verbose, **kwargs) 149 | 150 | if ref_ca_saved_attns is not None: 151 | warnings.warn('Attention reference loss is enabled in boxdiff mode. 
The original boxdiff does not have attention reference loss.') 152 | 153 | ref_loss = torch.tensor(0).float().cuda() 154 | ref_loss = guidance.add_ref_ca_loss_per_attn_map_to_lossv2( 155 | ref_loss, saved_attn=saved_attn, object_number=object_number, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, 156 | ref_ca_saved_attns=ref_ca_saved_attns, ref_ca_last_token_only=ref_ca_last_token_only, ref_ca_word_token_only=ref_ca_word_token_only, word_token_indices=word_token_indices, verbose=verbose, index=index, loss_weight=ref_ca_loss_weight 157 | ) 158 | print(f"loss {loss.item():.3f}, reference attention loss (weighted) {ref_loss.item():.3f}") 159 | loss += ref_loss 160 | 161 | return loss 162 | 163 | 164 | def add_ca_loss_per_attn_map_to_loss_boxdiff(original_loss, attention_maps, object_number, bboxes, object_positions, P=0.2, L=1, smooth_attentions=True, sigma=0.5, kernel_size=3, normalize_eot=False, verbose=False): 165 | # NOTE: normalize_eot is enabled in SD v2.1 in boxdiff 166 | i, j = attention_maps.shape 167 | H = W = int(math.sqrt(i)) 168 | 169 | attention_maps = attention_maps.view(H, W, j) 170 | # attention_maps is aggregated cross attn map across layers and steps 171 | # attention_maps shape: [H, W, 77] 172 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = _compute_max_attention_per_index( 173 | attention_maps=attention_maps, 174 | object_positions=object_positions, 175 | smooth_attentions=smooth_attentions, 176 | sigma=sigma, 177 | kernel_size=kernel_size, 178 | normalize_eot=normalize_eot, 179 | bboxes=bboxes, 180 | P=P, 181 | L=L 182 | ) 183 | 184 | _, loss = _compute_loss(max_attention_per_index_fg, 185 | max_attention_per_index_bg, dist_x, dist_y) 186 | 187 | return original_loss + loss 188 | 189 | 190 | def latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, amp_loss_scale=10, latent_scale=20, scale_range=(1., 0.5), max_index_step=25, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, **kwargs): 191 | """ 192 | amp_loss_scale: this scales the loss but will de-scale before applying for latents. This is to prevent overflow/underflow with amp, not to adjust the update step size. 193 | latent_scale: this scales the step size for update (scale_factor in boxdiff). 
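    In short (as implemented below): grad is taken w.r.t. the amp-scaled loss, and the update is latents <- latents - latent_scale * scale / amp_loss_scale * grad, so the amp scaling cancels out. `scale` follows the sqrt schedule interpolated over `scale_range`; the sigma/alphas_cumprod branches below are currently unreachable because of the `if True:`.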
194 | """ 195 | 196 | if index < max_index_step: 197 | saved_attn = {} 198 | full_cross_attention_kwargs = { 199 | 'save_attn_to_dict': saved_attn, 200 | 'save_keys': guidance_attn_keys, 201 | } 202 | 203 | if cross_attention_kwargs is not None: 204 | full_cross_attention_kwargs.update(cross_attention_kwargs) 205 | 206 | latents.requires_grad_(True) 207 | latent_model_input = latents 208 | latent_model_input = scheduler.scale_model_input(latent_model_input, t) 209 | 210 | unet(latent_model_input, t, encoder_hidden_states=cond_embeddings, 211 | return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs) 212 | 213 | # TODO: could return the attention maps for the required blocks only and not necessarily the final output 214 | # update latents with guidance 215 | loss = compute_ca_loss_boxdiff(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, 216 | ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * amp_loss_scale 217 | 218 | if torch.isnan(loss): 219 | print("**Loss is NaN**") 220 | 221 | del full_cross_attention_kwargs, saved_attn 222 | # call gc.collect() here may release some memory 223 | 224 | grad_cond = torch.autograd.grad( 225 | loss.requires_grad_(True), [latents])[0] 226 | 227 | latents.requires_grad_(False) 228 | 229 | if True: 230 | warnings.warn("Using guidance scaled with sqrt scale") 231 | # According to boxdiff's implementation: https://github.com/Sierkinhane/BoxDiff/blob/16ffb677a9128128e04553a0200870a526731be0/pipeline/sd_pipeline_boxdiff.py#L616 232 | scale = (scale_range[0] + (scale_range[1] - scale_range[0]) 233 | * index / (len(scheduler.timesteps) - 1)) ** (0.5) 234 | latents = latents - latent_scale * scale / amp_loss_scale * grad_cond 235 | elif hasattr(scheduler, 'sigmas'): 236 | warnings.warn("Using guidance scaled with sigmas") 237 | scale = scheduler.sigmas[index] ** 2 238 | latents = latents - grad_cond * scale 239 | elif hasattr(scheduler, 'alphas_cumprod'): 240 | warnings.warn("Using guidance scaled with alphas_cumprod") 241 | # Scaling with classifier guidance 242 | alpha_prod_t = scheduler.alphas_cumprod[t] 243 | # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf 244 | # DDIM: https://arxiv.org/pdf/2010.02502.pdf 245 | scale = (1 - alpha_prod_t) ** (0.5) 246 | latents = latents - latent_scale * scale / amp_loss_scale * grad_cond 247 | else: 248 | warnings.warn("No scaling in guidance is performed") 249 | scale = 1 250 | latents = latents - grad_cond 251 | 252 | gc.collect() 253 | torch.cuda.empty_cache() 254 | 255 | if verbose: 256 | print( 257 | f"time index {index}, loss: {loss.item() / amp_loss_scale:.3f} (de-scaled with scale {amp_loss_scale:.1f}), latent grad scale: {scale:.3f}") 258 | 259 | return latents, loss 260 | -------------------------------------------------------------------------------- /utils/parse.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from matplotlib.patches import Polygon, Rectangle 3 | from matplotlib.collections import PatchCollection 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import warnings 7 | import inflect 8 | 9 | p = inflect.engine() 10 | user_error = ValueError 11 | 12 | img_dir = "imgs" 13 | objects_text = "Objects: " 14 | bg_prompt_text = "Background prompt: " 15 | bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip() 16 | neg_prompt_text = "Negative prompt: " 17 | neg_prompt_text_no_trailing_space = 
neg_prompt_text.rstrip() 18 | 19 | # h, w 20 | box_scale = (512, 512) 21 | size = box_scale 22 | size_h, size_w = size 23 | print(f"Using box scale: {box_scale}") 24 | 25 | 26 | def parse_input(text=None, no_input=False): 27 | warnings.warn("Parsing input without negative prompt is deprecated.") 28 | 29 | if not text: 30 | if no_input: 31 | raise user_error(f'No input parsed in "{text}".') 32 | 33 | text = input("Enter the response: ") 34 | if objects_text in text: 35 | text = text.split(objects_text)[1] 36 | 37 | text_split = text.split(bg_prompt_text_no_trailing_space) 38 | if len(text_split) == 2: 39 | gen_boxes, bg_prompt = text_split 40 | elif len(text_split) == 1: 41 | if no_input: 42 | raise user_error(f"Invalid input (no background prompt): {text}") 43 | gen_boxes = text 44 | bg_prompt = "" 45 | while not bg_prompt: 46 | # Ignore the empty lines in the response 47 | bg_prompt = input("Enter the background prompt: ").strip() 48 | if bg_prompt_text_no_trailing_space in bg_prompt: 49 | bg_prompt = bg_prompt.split(bg_prompt_text_no_trailing_space)[1] 50 | else: 51 | raise user_error( 52 | f"Invalid input (possibly multiple background prompts): {text}" 53 | ) 54 | try: 55 | gen_boxes = ast.literal_eval(gen_boxes) 56 | except SyntaxError as e: 57 | # Sometimes the response is in plain text 58 | if "No objects" in gen_boxes: 59 | gen_boxes = [] 60 | else: 61 | raise e 62 | bg_prompt = bg_prompt.strip() 63 | 64 | return gen_boxes, bg_prompt 65 | 66 | 67 | def parse_input_with_negative(text=None, no_input=False): 68 | # no_input: should not request interactive input 69 | 70 | if not text: 71 | if no_input: 72 | raise user_error(f'No input parsed in "{text}".') 73 | 74 | text = input("Enter the response: ") 75 | if objects_text in text: 76 | text = text.split(objects_text)[1] 77 | 78 | text_split = text.split(bg_prompt_text_no_trailing_space) 79 | if len(text_split) == 2: 80 | gen_boxes, text_rem = text_split 81 | elif len(text_split) == 1: 82 | if no_input: 83 | raise user_error(f"Invalid input (no background prompt): {text}") 84 | gen_boxes = text 85 | text_rem = "" 86 | while not text_rem: 87 | # Ignore the empty lines in the response 88 | text_rem = input("Enter the background prompt: ").strip() 89 | if bg_prompt_text_no_trailing_space in text_rem: 90 | text_rem = text_rem.split(bg_prompt_text_no_trailing_space)[1] 91 | else: 92 | raise user_error( 93 | f"Invalid input (possibly multiple background prompts): {text}" 94 | ) 95 | 96 | text_split = text_rem.split(neg_prompt_text_no_trailing_space) 97 | 98 | if len(text_split) == 2: 99 | bg_prompt, neg_prompt = text_split 100 | elif len(text_split) == 1: 101 | bg_prompt = text_rem 102 | # Negative prompt is optional: if it's not provided, we default to empty string 103 | neg_prompt = "" 104 | if not no_input: 105 | # Ignore the empty lines in the response 106 | neg_prompt = input("Enter the negative prompt: ").strip() 107 | if neg_prompt_text_no_trailing_space in neg_prompt: 108 | neg_prompt = neg_prompt.split(neg_prompt_text_no_trailing_space)[1] 109 | else: 110 | raise user_error(f"Invalid input (possibly multiple negative prompts): {text}") 111 | 112 | try: 113 | gen_boxes = ast.literal_eval(gen_boxes) 114 | except SyntaxError as e: 115 | # Sometimes the response is in plain text 116 | if "No objects" in gen_boxes or gen_boxes.strip() == "": 117 | gen_boxes = [] 118 | else: 119 | raise e 120 | bg_prompt = bg_prompt.strip() 121 | neg_prompt = neg_prompt.strip() 122 | 123 | # LLM may return "None" to mean no negative prompt provided. 
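    # e.g., a response ending with "Negative prompt: None" yields neg_prompt == "None" after the split and strip above; it is mapped to an empty string here.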
124 | if neg_prompt == "None": 125 | neg_prompt = "" 126 | 127 | return gen_boxes, bg_prompt, neg_prompt 128 | 129 | 130 | def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3): 131 | if gen_boxes is None: 132 | return [] 133 | 134 | if len(gen_boxes) == 0: 135 | return [] 136 | 137 | box_dict_format = False 138 | gen_boxes_new = [] 139 | for gen_box in gen_boxes: 140 | if isinstance(gen_box, dict): 141 | if not gen_box["bounding_box"]: 142 | continue 143 | name, [bbox_x, bbox_y, bbox_w, bbox_h] = ( 144 | gen_box["name"], 145 | gen_box["bounding_box"], 146 | ) 147 | box_dict_format = True 148 | else: 149 | if not gen_box[1]: 150 | continue 151 | name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box 152 | if bbox_w <= 0 or bbox_h <= 0: 153 | # Empty boxes 154 | continue 155 | if ignore_background: 156 | if ( 157 | (bbox_w >= size[1] and bbox_h >= size[0]) 158 | or bbox_x > size[1] 159 | or bbox_y > size[0] 160 | ): 161 | # Ignore the background boxes 162 | continue 163 | 164 | if ( 165 | bbox_x < 0 166 | or bbox_y < 0 167 | or bbox_x + bbox_w > size[1] 168 | or bbox_y + bbox_h > size[0] 169 | ): 170 | # Out of bounds boxes exist: we need to scale and shift all the boxes 171 | print( 172 | f"**Some boxes are out of bounds: {gen_box}, scaling all the boxes to fit**" 173 | ) 174 | scale_boxes = True 175 | 176 | gen_boxes_new.append(gen_box) 177 | 178 | gen_boxes = gen_boxes_new 179 | 180 | if len(gen_boxes) == 0: 181 | return [] 182 | 183 | filtered_gen_boxes = [] 184 | if box_dict_format: 185 | # For compatibility 186 | bbox_left_x_min = min([gen_box["bounding_box"][0] for gen_box in gen_boxes]) 187 | bbox_right_x_max = max( 188 | [ 189 | gen_box["bounding_box"][0] + gen_box["bounding_box"][2] 190 | for gen_box in gen_boxes 191 | ] 192 | ) 193 | bbox_top_y_min = min([gen_box["bounding_box"][1] for gen_box in gen_boxes]) 194 | bbox_bottom_y_max = max( 195 | [ 196 | gen_box["bounding_box"][1] + gen_box["bounding_box"][3] 197 | for gen_box in gen_boxes 198 | ] 199 | ) 200 | else: 201 | bbox_left_x_min = min([gen_box[1][0] for gen_box in gen_boxes]) 202 | bbox_right_x_max = max([gen_box[1][0] + gen_box[1][2] for gen_box in gen_boxes]) 203 | bbox_top_y_min = min([gen_box[1][1] for gen_box in gen_boxes]) 204 | bbox_bottom_y_max = max( 205 | [gen_box[1][1] + gen_box[1][3] for gen_box in gen_boxes] 206 | ) 207 | 208 | # All boxes are empty 209 | if (bbox_right_x_max - bbox_left_x_min) == 0: 210 | return [] 211 | 212 | # Used if scale_boxes is True 213 | shift = -bbox_left_x_min 214 | # Make sure the boxes fit horizontally and vertically 215 | scale_w = size_w / (bbox_right_x_max - bbox_left_x_min) 216 | scale_h = size_h / (bbox_bottom_y_max - bbox_top_y_min) 217 | 218 | scale = min(scale_w, scale_h, max_scale) 219 | 220 | for gen_box in gen_boxes: 221 | if box_dict_format: 222 | name, [bbox_x, bbox_y, bbox_w, bbox_h] = ( 223 | gen_box["name"], 224 | gen_box["bounding_box"], 225 | ) 226 | else: 227 | name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box 228 | 229 | if scale_boxes: 230 | # Vertical: move the boxes if out of bound 231 | # Horizontal: move and scale the boxes so it spans the horizontal line 232 | 233 | bbox_x = (bbox_x + shift) * scale 234 | bbox_y = bbox_y * scale 235 | bbox_w, bbox_h = bbox_w * scale, bbox_h * scale 236 | # TODO: verify this makes the y center not moving 237 | bbox_y_offset = 0 238 | if bbox_top_y_min * scale + bbox_y_offset < 0: 239 | bbox_y_offset -= bbox_top_y_min * scale 240 | if bbox_bottom_y_max * scale + bbox_y_offset >= size_h: 241 | 
bbox_y_offset -= bbox_bottom_y_max * scale - size_h 242 | bbox_y += bbox_y_offset 243 | 244 | if bbox_y < 0: 245 | bbox_y, bbox_h = 0, bbox_h - bbox_y 246 | 247 | name = name.rstrip(".") 248 | bounding_box = ( 249 | int(np.round(bbox_x)), 250 | int(np.round(bbox_y)), 251 | int(np.round(bbox_w)), 252 | int(np.round(bbox_h)), 253 | ) 254 | if box_dict_format: 255 | gen_box = {"name": name, "bounding_box": bounding_box} 256 | else: 257 | gen_box = (name, bounding_box) 258 | 259 | filtered_gen_boxes.append(gen_box) 260 | 261 | return filtered_gen_boxes 262 | 263 | 264 | def draw_boxes(ax, anns): 265 | ax.set_autoscale_on(False) 266 | polygons = [] 267 | color = [] 268 | for ann in anns: 269 | c = np.random.random((1, 3)) * 0.6 + 0.4 270 | [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"] 271 | poly = [ 272 | [bbox_x, bbox_y], 273 | [bbox_x, bbox_y + bbox_h], 274 | [bbox_x + bbox_w, bbox_y + bbox_h], 275 | [bbox_x + bbox_w, bbox_y], 276 | ] 277 | np_poly = np.array(poly).reshape((4, 2)) 278 | polygons.append(Polygon(np_poly)) 279 | color.append(c) 280 | 281 | # print(ann) 282 | name = ann["name"] if "name" in ann else str(ann["category_id"]) 283 | ax.text( 284 | bbox_x, 285 | bbox_y, 286 | name, 287 | style="italic", 288 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, 289 | ) 290 | 291 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) 292 | ax.add_collection(p) 293 | 294 | 295 | def show_boxes( 296 | gen_boxes, 297 | additional_boxes=None, # New parameter for the second set of boxes 298 | bg_prompt=None, 299 | neg_prompt=None, 300 | show=False, 301 | save=False, 302 | img=None, 303 | fname=None, 304 | ): 305 | if len(gen_boxes) == 0 and (additional_boxes is None or len(additional_boxes) == 0): 306 | return 307 | 308 | def prepare_annotations(boxes): 309 | if isinstance(boxes[0], dict): 310 | return [{"name": box["name"], "bbox": box["bounding_box"]} for box in boxes] 311 | else: 312 | return [ 313 | {"name": box[0], "bbox": [int(x * 512) for x in box[1]]} 314 | for box in boxes 315 | ] 316 | 317 | anns = prepare_annotations(gen_boxes) 318 | additional_anns = prepare_annotations(additional_boxes) if additional_boxes else [] 319 | 320 | # Create a figure with two subplots 321 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 5)) 322 | 323 | # Plot for gen_boxes 324 | ax1.imshow(img) 325 | ax1.axis("off") 326 | ax1.set_title("Curr Layout", pad=20) 327 | draw_boxes(ax1, anns) 328 | 329 | # Plot for additional_boxes 330 | ax2.imshow(np.ones((512, 512, 3), dtype=np.uint8) * 255) 331 | ax2.axis("off") 332 | ax2.set_title("New Layout", pad=20) 333 | draw_boxes(ax2, additional_anns) 334 | 335 | # Add background prompt if present 336 | if bg_prompt is not None: 337 | for ax in [ax1, ax2]: 338 | ax.text( 339 | 0, 340 | 0, 341 | bg_prompt + f" (Neg: {neg_prompt})" if neg_prompt else bg_prompt, 342 | style="italic", 343 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, 344 | fontsize=8, 345 | ) 346 | 347 | if show: 348 | plt.show() 349 | else: 350 | print("Saved boxes visualizations to", f"{fname}") 351 | plt.savefig(fname) 352 | plt.clf() 353 | 354 | 355 | def show_masks(masks): 356 | masks_to_show = np.zeros((*size, 3), dtype=np.float32) 357 | for mask in masks: 358 | c = np.random.random((3,)) * 0.6 + 0.4 359 | 360 | masks_to_show += mask[..., None] * c[None, None, :] 361 | plt.imshow(masks_to_show) 362 | plt.savefig(f"{img_dir}/masks.png") 363 | plt.show() 364 | plt.clf() 365 | 366 | 367 | def convert_box(box, height, width): 368 | # box: x, y, w, h (in 512 
format) -> x_min, y_min, x_max, y_max 369 | x_min, y_min = box[0] / width, box[1] / height 370 | w_box, h_box = box[2] / width, box[3] / height 371 | 372 | x_max, y_max = x_min + w_box, y_min + h_box 373 | 374 | return x_min, y_min, x_max, y_max 375 | 376 | 377 | def convert_spec(spec, height, width, include_counts=True, verbose=False): 378 | # Infer from spec 379 | prompt, gen_boxes, bg_prompt = spec["prompt"], spec["gen_boxes"], spec["bg_prompt"] 380 | 381 | # This ensures the same objects appear together because flattened `overall_phrases_bboxes` should EXACTLY correspond to `so_prompt_phrase_box_list`. 382 | gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0]) 383 | 384 | gen_boxes = [ 385 | (name, convert_box(box, height=height, width=width)) for name, box in gen_boxes 386 | ] 387 | 388 | # NOTE: so phrase should include all the words associated to the object (otherwise "an orange dog" may be recognized as "an orange" by the model generating the background). 389 | # so word should have one token that includes the word to transfer cross attention (the object name). 390 | # Currently using the last word of the object name as word. 391 | if bg_prompt: 392 | so_prompt_phrase_word_box_list = [ 393 | (f"{bg_prompt} with {name}", name, name.split(" ")[-1], box) 394 | for name, box in gen_boxes 395 | ] 396 | else: 397 | so_prompt_phrase_word_box_list = [ 398 | (f"{name}", name, name.split(" ")[-1], box) for name, box in gen_boxes 399 | ] 400 | 401 | objects = [gen_box[0] for gen_box in gen_boxes] 402 | 403 | objects_unique, objects_count = np.unique(objects, return_counts=True) 404 | 405 | num_total_matched_boxes = 0 406 | overall_phrases_words_bboxes = [] 407 | for ind, object_name in enumerate(objects_unique): 408 | bboxes = [box for name, box in gen_boxes if name == object_name] 409 | 410 | if objects_count[ind] > 1: 411 | phrase = p.plural_noun(object_name.replace("an ", "").replace("a ", "")) 412 | if include_counts: 413 | phrase = p.number_to_words(objects_count[ind]) + " " + phrase 414 | else: 415 | phrase = object_name 416 | # Currently using the last word of the phrase as word. 
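        # e.g., three boxes named "a brown dog" become the phrase "three brown dogs" with word "dogs"; a single "a red car" keeps phrase "a red car" with word "car".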
417 | word = phrase.split(" ")[-1] 418 | 419 | num_total_matched_boxes += len(bboxes) 420 | overall_phrases_words_bboxes.append((phrase, word, bboxes)) 421 | 422 | assert num_total_matched_boxes == len( 423 | gen_boxes 424 | ), f"{num_total_matched_boxes} != {len(gen_boxes)}" 425 | 426 | objects_str = ", ".join([phrase for phrase, _, _ in overall_phrases_words_bboxes]) 427 | if objects_str: 428 | if bg_prompt: 429 | overall_prompt = f"{bg_prompt} with {objects_str}" 430 | else: 431 | overall_prompt = objects_str 432 | else: 433 | overall_prompt = bg_prompt 434 | 435 | if verbose: 436 | print("so_prompt_phrase_word_box_list:", so_prompt_phrase_word_box_list) 437 | print("overall_prompt:", overall_prompt) 438 | print("overall_phrases_words_bboxes:", overall_phrases_words_bboxes) 439 | 440 | return so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes 441 | -------------------------------------------------------------------------------- /utils/schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | def get_fast_schedule(original_timesteps, fast_after_steps, fast_rate): 5 | if fast_after_steps >= len(original_timesteps) - 1: 6 | return original_timesteps 7 | new_timesteps = torch.cat((original_timesteps[:fast_after_steps], original_timesteps[fast_after_steps+1::fast_rate]), dim=0) 8 | return new_timesteps 9 | 10 | def dynamically_adjust_inference_steps(scheduler, index, t): 11 | prev_t = scheduler.timesteps[index+1] if index+1 < len(scheduler.timesteps) else -1 12 | scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t) 13 | if index+1 < len(scheduler.timesteps): 14 | if scheduler.config.num_train_timesteps // scheduler.num_inference_steps != t - prev_t: 15 | warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate") 16 | else: 17 | # as long as we hit the final cumprod, it should be fine.
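        # (in this branch prev_t == -1, i.e., we are past the last timestep, so we only warn if the configured stride overshoots the remaining t - prev_t)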
18 | if scheduler.config.num_train_timesteps // scheduler.num_inference_steps > t - prev_t: 19 | warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate") 20 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import ImageDraw 3 | import numpy as np 4 | import gc 5 | 6 | torch_device = "cuda" 7 | 8 | 9 | def draw_box(pil_img, bboxes, phrases): 10 | draw = ImageDraw.Draw(pil_img) 11 | # font = ImageFont.truetype('./FreeMono.ttf', 25) 12 | 13 | for obj_bbox, phrase in zip(bboxes, phrases): 14 | x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3] 15 | draw.rectangle( 16 | [int(x_0 * 512), int(y_0 * 512), int(x_1 * 512), int(y_1 * 512)], 17 | outline="red", 18 | width=5, 19 | ) 20 | draw.text( 21 | (int(x_0 * 512) + 5, int(y_0 * 512) + 5), 22 | phrase, 23 | font=None, 24 | fill=(255, 0, 0), 25 | ) 26 | 27 | return pil_img 28 | 29 | 30 | def get_centered_box( 31 | box, 32 | horizontal_center_only=True, 33 | vertical_placement="centered", 34 | vertical_center=0.5, 35 | floor_padding=None, 36 | ): 37 | x_min, y_min, x_max, y_max = box 38 | w = x_max - x_min 39 | 40 | x_min_new = 0.5 - w / 2 41 | x_max_new = 0.5 + w / 2 42 | 43 | if horizontal_center_only: 44 | return [x_min_new, y_min, x_max_new, y_max] 45 | 46 | h = y_max - y_min 47 | 48 | if vertical_placement == "centered": 49 | assert ( 50 | floor_padding is None 51 | ), "Set vertical_placement to floor_padding to use floor padding" 52 | 53 | y_min_new = vertical_center - h / 2 54 | y_max_new = vertical_center + h / 2 55 | elif vertical_placement == "floor_padding": 56 | # Ignores `vertical_center` 57 | 58 | y_max_new = 1 - floor_padding 59 | y_min_new = y_max_new - h 60 | else: 61 | raise ValueError(f"Unknown vertical placement: {vertical_placement}") 62 | 63 | return [x_min_new, y_min_new, x_max_new, y_max_new] 64 | 65 | 66 | # NOTE: this changes the behavior of the function 67 | def proportion_to_mask(obj_box, H, W, use_legacy=False, return_np=False): 68 | x_min, y_min, x_max, y_max = scale_proportion(obj_box, H, W, use_legacy) 69 | if return_np: 70 | mask = np.zeros((H, W)) 71 | else: 72 | mask = torch.zeros(H, W).to(torch_device) 73 | mask[y_min:y_max, x_min:x_max] = 1.0 74 | 75 | return mask 76 | 77 | 78 | def scale_proportion(obj_box, H, W, use_legacy=False): 79 | if use_legacy: 80 | # Bias towards the top-left corner 81 | x_min, y_min, x_max, y_max = ( 82 | int(obj_box[0] * W), 83 | int(obj_box[1] * H), 84 | int(obj_box[2] * W), 85 | int(obj_box[3] * H), 86 | ) 87 | else: 88 | # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5". 
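        # e.g., with H = W = 64 and obj_box = (0.25, 0.1, 0.75, 0.9): x_min = 16 and box_w = round(0.5 * 64) = 32, so the box width stays 32 even if the box is shifted horizontally.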
89 | x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H) 90 | box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round( 91 | (obj_box[3] - obj_box[1]) * H 92 | ) 93 | x_max, y_max = x_min + box_w, y_min + box_h 94 | 95 | x_min, y_min = max(x_min, 0), max(y_min, 0) 96 | x_max, y_max = min(x_max, W), min(y_max, H) 97 | 98 | return x_min, y_min, x_max, y_max 99 | 100 | 101 | def binary_mask_to_box(mask, enlarge_box_by_one=True, w_scale=1, h_scale=1): 102 | if isinstance(mask, torch.Tensor): 103 | mask_loc = torch.where(mask) 104 | else: 105 | mask_loc = np.where(mask) 106 | height, width = mask.shape 107 | if len(mask_loc[0]) == 0: 108 | raise ValueError("The mask is empty") 109 | if enlarge_box_by_one: 110 | ymin, ymax = max(min(mask_loc[0]) - 1, 0), min(max(mask_loc[0]) + 1, height) 111 | xmin, xmax = max(min(mask_loc[1]) - 1, 0), min(max(mask_loc[1]) + 1, width) 112 | else: 113 | ymin, ymax = min(mask_loc[0]), max(mask_loc[0]) 114 | xmin, xmax = min(mask_loc[1]), max(mask_loc[1]) 115 | box = [xmin * w_scale, ymin * h_scale, xmax * w_scale, ymax * h_scale] 116 | 117 | return box 118 | 119 | 120 | def binary_mask_to_box_mask(mask, to_device=True): 121 | box = binary_mask_to_box(mask) 122 | x_min, y_min, x_max, y_max = box 123 | 124 | H, W = mask.shape 125 | mask = torch.zeros(H, W) 126 | if to_device: 127 | mask = mask.to(torch_device) 128 | mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0 129 | 130 | return mask 131 | 132 | 133 | def binary_mask_to_center(mask, normalize=False): 134 | """ 135 | This computes the mass center of the mask. 136 | normalize: the coords range from 0 to 1 137 | 138 | Reference: https://stackoverflow.com/a/66184125 139 | """ 140 | h, w = mask.shape 141 | 142 | total = mask.sum() 143 | if isinstance(mask, torch.Tensor): 144 | x_coord = ((mask.sum(dim=0) @ torch.arange(w)) / total).item() 145 | y_coord = ((mask.sum(dim=1) @ torch.arange(h)) / total).item() 146 | else: 147 | x_coord = (mask.sum(axis=0) @ np.arange(w)) / total 148 | y_coord = (mask.sum(axis=1) @ np.arange(h)) / total 149 | 150 | if normalize: 151 | x_coord, y_coord = x_coord / w, y_coord / h 152 | return x_coord, y_coord 153 | 154 | 155 | def iou(mask, masks, eps=1e-6): 156 | # mask: [h, w], masks: [n, h, w] 157 | mask = mask[None].astype(bool) 158 | masks = masks.astype(bool) 159 | i = (mask & masks).sum(axis=(1, 2)) 160 | u = (mask | masks).sum(axis=(1, 2)) 161 | 162 | return i / (u + eps) 163 | 164 | 165 | def free_memory(): 166 | gc.collect() 167 | torch.cuda.empty_cache() 168 | 169 | 170 | def expand_overall_bboxes(overall_bboxes): 171 | """ 172 | Expand overall bboxes from a 3d list to 2d list: 173 | Input: [[box 1 for phrase 1, box 2 for phrase 1], ...] 174 | Output: [box 1, box 2, ...]
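    Example: [[box_1a, box_1b], [box_2a]] -> [box_1a, box_1b, box_2a] (flattened with sum(..., start=[]) below).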
175 | """ 176 | return sum(overall_bboxes, start=[]) 177 | 178 | 179 | def shift_tensor( 180 | tensor, 181 | x_offset, 182 | y_offset, 183 | base_w=8, 184 | base_h=8, 185 | offset_normalized=False, 186 | ignore_last_dim=False, 187 | ): 188 | """base_w and base_h: make sure the shift is aligned in the latent and multiple levels of cross attention""" 189 | if ignore_last_dim: 190 | tensor_h, tensor_w = tensor.shape[-3:-1] 191 | else: 192 | tensor_h, tensor_w = tensor.shape[-2:] 193 | if offset_normalized: 194 | assert ( 195 | tensor_h % base_h == 0 and tensor_w % base_w == 0 196 | ), f"{tensor_h, tensor_w} is not a multiple of {base_h, base_w}" 197 | scale_from_base_h, scale_from_base_w = tensor_h // base_h, tensor_w // base_w 198 | x_offset, y_offset = ( 199 | round(x_offset * base_w) * scale_from_base_w, 200 | round(y_offset * base_h) * scale_from_base_h, 201 | ) 202 | new_tensor = torch.zeros_like(tensor) 203 | 204 | overlap_w = tensor_w - abs(x_offset) 205 | overlap_h = tensor_h - abs(y_offset) 206 | 207 | if y_offset >= 0: 208 | y_src_start = 0 209 | y_dest_start = y_offset 210 | else: 211 | y_src_start = -y_offset 212 | y_dest_start = 0 213 | 214 | if x_offset >= 0: 215 | x_src_start = 0 216 | x_dest_start = x_offset 217 | else: 218 | x_src_start = -x_offset 219 | x_dest_start = 0 220 | 221 | if ignore_last_dim: 222 | # For cross attention maps, the third to last and the second to last are the 2D dimensions after unflatten. 223 | new_tensor[ 224 | ..., 225 | y_dest_start : y_dest_start + overlap_h, 226 | x_dest_start : x_dest_start + overlap_w, 227 | :, 228 | ] = tensor[ 229 | ..., 230 | y_src_start : y_src_start + overlap_h, 231 | x_src_start : x_src_start + overlap_w, 232 | :, 233 | ] 234 | else: 235 | new_tensor[ 236 | ..., 237 | y_dest_start : y_dest_start + overlap_h, 238 | x_dest_start : x_dest_start + overlap_w, 239 | ] = tensor[ 240 | ..., 241 | y_src_start : y_src_start + overlap_h, 242 | x_src_start : x_src_start + overlap_w, 243 | ] 244 | 245 | return new_tensor 246 | 247 | 248 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import math 3 | import utils 4 | from PIL import Image, ImageDraw 5 | import numpy as np 6 | from . 
import parse 7 | 8 | save_ind = 0 9 | 10 | 11 | def visualize(image, title, colorbar=False, show_plot=True, **kwargs): 12 | plt.title(title) 13 | plt.imshow(image, **kwargs) 14 | if colorbar: 15 | plt.colorbar() 16 | if show_plot: 17 | plt.show() 18 | 19 | 20 | def visualize_arrays( 21 | image_title_pairs, 22 | colorbar_index=-1, 23 | show_plot=True, 24 | figsize=None, 25 | fname="vis_array.jpg", 26 | **kwargs, 27 | ): 28 | if figsize is not None: 29 | plt.figure(figsize=figsize) 30 | num_subplots = len(image_title_pairs) 31 | for idx, image_title_pair in enumerate(image_title_pairs): 32 | plt.subplot(1, num_subplots, idx + 1) 33 | if isinstance(image_title_pair, (list, tuple)): 34 | image, title = image_title_pair 35 | else: 36 | image, title = image_title_pair, None 37 | 38 | if title is not None: 39 | plt.title(title) 40 | 41 | plt.imshow(image, **kwargs) 42 | if idx == colorbar_index: 43 | plt.colorbar() 44 | 45 | # if show_plot: 46 | # plt.show() 47 | plt.savefig(fname) 48 | 49 | 50 | def visualize_masked_latents( 51 | latents_all, masked_latents, timestep_T=False, timestep_0=True 52 | ): 53 | if timestep_T: 54 | # from T to 0 55 | latent_idx = 0 56 | 57 | plt.subplot(1, 2, 1) 58 | plt.title("latents_all (t=T)") 59 | plt.imshow( 60 | ( 61 | latents_all[latent_idx, 0, :3] 62 | .cpu() 63 | .permute(1, 2, 0) 64 | .numpy() 65 | .astype(float) 66 | / 1.5 67 | ).clip(0.0, 1.0), 68 | cmap="gray", 69 | ) 70 | 71 | plt.subplot(1, 2, 2) 72 | plt.title("mask latents (t=T)") 73 | plt.imshow( 74 | ( 75 | masked_latents[latent_idx, 0, :3] 76 | .cpu() 77 | .permute(1, 2, 0) 78 | .numpy() 79 | .astype(float) 80 | / 1.5 81 | ).clip(0.0, 1.0), 82 | cmap="gray", 83 | ) 84 | 85 | plt.show() 86 | 87 | if timestep_0: 88 | latent_idx = -1 89 | plt.subplot(1, 2, 1) 90 | plt.title("latents_all (t=0)") 91 | plt.imshow( 92 | ( 93 | latents_all[latent_idx, 0, :3] 94 | .cpu() 95 | .permute(1, 2, 0) 96 | .numpy() 97 | .astype(float) 98 | / 1.5 99 | ).clip(0.0, 1.0), 100 | cmap="gray", 101 | ) 102 | 103 | plt.subplot(1, 2, 2) 104 | plt.title("mask latents (t=0)") 105 | plt.imshow( 106 | ( 107 | masked_latents[latent_idx, 0, :3] 108 | .cpu() 109 | .permute(1, 2, 0) 110 | .numpy() 111 | .astype(float) 112 | / 1.5 113 | ).clip(0.0, 1.0), 114 | cmap="gray", 115 | ) 116 | 117 | plt.show() 118 | 119 | 120 | # This function has not been adapted to new `saved_attn`. 
121 | def visualize_attn( 122 | token_map, 123 | cross_attention_probs_tensors, 124 | stage_id, 125 | block_id, 126 | visualize_step_start=10, 127 | input_ca_has_condition_only=False, 128 | ): 129 | """ 130 | Visualize cross attention: `stage_id`th downsampling block, mean over all timesteps starting from step start, `block_id`th Transformer block, second item (conditioned), mean over heads, show each token 131 | cross_attention_probs_tensors: 132 | One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors` 133 | stage_id: index of downsampling/mid/upsaming block 134 | block_id: index of the transformer block 135 | """ 136 | 137 | plt.figure(figsize=(20, 8)) 138 | 139 | for token_id in range(len(token_map)): 140 | token = token_map[token_id] 141 | plt.subplot(1, len(token_map), token_id + 1) 142 | plt.title(token) 143 | attn = cross_attention_probs_tensors[stage_id][visualize_step_start:].mean( 144 | dim=0 145 | )[block_id] 146 | 147 | if not input_ca_has_condition_only: 148 | assert ( 149 | attn.shape[0] == 2 150 | ), f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items" 151 | attn = attn[1] 152 | else: 153 | assert ( 154 | attn.shape[0] == 1 155 | ), f"Expect to have 1 item (cond only), but found {attn.shape[0]} items" 156 | attn = attn[0] 157 | 158 | attn = attn.mean(dim=0)[:, token_id] 159 | H = W = int(math.sqrt(attn.shape[0])) 160 | attn = attn.reshape((H, W)) 161 | plt.imshow(attn.cpu().numpy()) 162 | 163 | plt.show() 164 | 165 | 166 | # This function has not been adapted to new `saved_attn`. 167 | def visualize_across_timesteps( 168 | token_id, 169 | cross_attention_probs_tensors, 170 | stage_id, 171 | block_id, 172 | visualize_step_start=10, 173 | input_ca_has_condition_only=False, 174 | ): 175 | """ 176 | Visualize cross attention for one token, across timesteps: `stage_id`th downsampling block, mean over all timesteps starting from step start, `block_id`th Transformer block, second item (conditioned), mean over heads, show each token 177 | cross_attention_probs_tensors: 178 | One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors` 179 | stage_id: index of downsampling/mid/upsaming block 180 | block_id: index of the transformer block 181 | 182 | `visualize_step_start` is not used. We visualize all timesteps. 
183 | """ 184 | plt.figure(figsize=(50, 8)) 185 | 186 | attn_stage = cross_attention_probs_tensors[stage_id] 187 | num_inference_steps = attn_stage.shape[0] 188 | 189 | for t in range(num_inference_steps): 190 | plt.subplot(1, num_inference_steps, t + 1) 191 | plt.title(f"t: {t}") 192 | 193 | attn = attn_stage[t][block_id] 194 | 195 | if not input_ca_has_condition_only: 196 | assert ( 197 | attn.shape[0] == 2 198 | ), f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items" 199 | attn = attn[1] 200 | else: 201 | assert ( 202 | attn.shape[0] == 1 203 | ), f"Expect to have 1 item (cond only), but found {attn.shape[0]} items" 204 | attn = attn[0] 205 | 206 | attn = attn.mean(dim=0)[:, token_id] 207 | H = W = int(math.sqrt(attn.shape[0])) 208 | attn = attn.reshape((H, W)) 209 | plt.imshow(attn.cpu().numpy()) 210 | plt.axis("off") 211 | plt.tight_layout() 212 | 213 | plt.show() 214 | 215 | 216 | def visualize_bboxes(bboxes, H, W): 217 | num_boxes = len(bboxes) 218 | for ind, bbox in enumerate(bboxes): 219 | plt.subplot(1, num_boxes, ind + 1) 220 | fg_mask = utils.proportion_to_mask(bbox, H, W) 221 | plt.title(f"transformed bbox ({ind})") 222 | plt.imshow(fg_mask.cpu().numpy()) 223 | plt.show() 224 | 225 | 226 | def reset_save_ind(): 227 | global save_ind 228 | save_ind = 0 229 | 230 | 231 | def display(image, save_prefix="", ind=None, save_ind_in_filename=True): 232 | """ 233 | save_ind_in_filename: This adds a global index to the filename so that two calls to this function will not save to the same file and overwrite the previous image. 234 | """ 235 | global save_ind 236 | if save_prefix != "": 237 | save_prefix = save_prefix + "_" 238 | if save_ind_in_filename: 239 | ind = f"{ind}_" if ind is not None else "" 240 | path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png" 241 | else: 242 | ind = f"{ind}" if ind is not None else "" 243 | path = f"{parse.img_dir}/{save_prefix}{ind}.png" 244 | 245 | print(f"Saved to {path}") 246 | 247 | if isinstance(image, np.ndarray): 248 | image = Image.fromarray(image) 249 | 250 | image.save(path) 251 | save_ind = save_ind + 1 252 | 253 | 254 | def draw_bounding_boxes(entry, vis_fname): 255 | """ 256 | Draw bounding boxes on a PIL image. 
257 | 258 | :param entry: dict for one generation round; the function reads entry["output"][-1] (path of the image to annotate), entry["det_results"] (detected boxes), 259 |     entry["llm_suggestion"] (LLM-updated boxes), and entry["instructions"] (the prompt, shown as the figure title) 260 |     Each box is (class_name, (x, y, width, height)) with coordinates given as fractions of the image size; boxes are drawn in red with line width 2 261 | :param vis_fname: path to save the side-by-side visualization (current layout on the image vs. suggested layout on a blank canvas) 262 | :return: None; the figure is saved to vis_fname 263 | """ 264 | image = Image.open(entry["output"][-1]) 265 | initial_bboxes = entry["det_results"] 266 | updated_bboxes = entry["llm_suggestion"] 267 | w, h = image.size 268 | draw = ImageDraw.Draw(image) 269 | for bbox in initial_bboxes: 270 | class_name = bbox[0] 271 | coords = bbox[1] 272 | x, y, width, height = coords 273 | x, y, width, height = int(x * w), int(y * h), int(width * w), int(height * h) 274 | print(x, y, width, height, class_name) 275 | draw.rectangle([x, y, x + width, y + height], outline="red", width=2) 276 | draw.text((x, y), class_name, fill="red") 277 | # Another image 278 | blank_image = Image.new("RGB", (w, h), color="white") 279 | draw_new = ImageDraw.Draw(blank_image) 280 | for bbox in updated_bboxes: 281 | class_name = bbox[0] 282 | coords = bbox[1] 283 | x, y, width, height = coords 284 | x, y, width, height = int(x * w), int(y * h), int(width * w), int(height * h) 285 | print(x, y, width, height, class_name) 286 | draw_new.rectangle([x, y, x + width, y + height], outline="red", width=2) 287 | draw_new.text((x, y), class_name, fill="red") 288 | fig, axs = plt.subplots(1, 2, figsize=(12, 6)) 289 | axs[0].imshow(image) 290 | axs[1].imshow(blank_image) 291 | # axs[0].axis("off") 292 | # axs[1].axis("off") 293 | prompt = entry["instructions"] 294 | fig.suptitle(f"{prompt}", fontsize=9) 295 | plt.tight_layout() 296 | plt.savefig(vis_fname) 297 | plt.clf() 298 | --------------------------------------------------------------------------------