├── .gitignore ├── README.md ├── assets ├── example_video_lvd_gligen_zeroscope.gif ├── example_video_lvd_zeroscope.gif └── example_video_zeroscope_baseline.gif ├── cache ├── cache_demo_v0.1_gpt-4-1106-preview.json ├── cache_lvd_v0.1_gpt-3.5-turbo.json └── cache_lvd_v0.1_gpt-4-1106-preview.json ├── generate.py ├── generation ├── lvd.py ├── lvd_gligen.py ├── lvd_plus.py ├── modelscope_dpm.py └── zeroscope_dpm.py ├── models ├── __init__.py ├── attention.py ├── attention_processor.py ├── controllable_pipeline_text_to_video_synth.py ├── models.py ├── pipelines.py ├── transformer_2d.py ├── transformer_temporal.py ├── unet_2d_blocks.py ├── unet_2d_condition.py ├── unet_3d_blocks.py └── unet_3d_condition.py ├── prompt.py ├── prompt_batch.py ├── requirements.txt ├── scripts ├── eval_owl_vit.py ├── eval_stage_one.py └── upsample.py └── utils ├── __init__.py ├── attn.py ├── cache.py ├── eval ├── __init__.py ├── eval.py ├── lvd.py └── utils.py ├── guidance.py ├── latents.py ├── llm.py ├── parse.py ├── schedule.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | api_key.py 163 | *_ignored* 164 | 165 | img_generations 166 | 167 | .ruff_cache 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM-grounded Video Diffusion Models 2 | [Long Lian](https://tonylian.com/), [Baifeng Shi](https://bfshi.github.io/), [Adam Yala](https://www.adamyala.org/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Boyi Li](https://sites.google.com/site/boyilics/home) at UC Berkeley/UCSF. 3 | 4 | ***International Conference on Learning Representations (ICLR) 2024*** 5 | 6 | [Paper PDF](https://openreview.net/pdf?id=exKHibougU) | [Arxiv](https://arxiv.org/abs/2309.17444) | [Project Page](https://llm-grounded-video-diffusion.github.io/) | [Related Project: LMD](https://llm-grounded-diffusion.github.io/) | [Citation](#citation) 7 | 8 | **TL;DR**: Text Prompt -> LLM Spatiotemporal Planner -> Dynamic Scene Layouts (DSLs) -> DSL-grounded Video Generator -> Video. This gives you much better **prompt understanding capabilities** for text-to-video generation. 9 | 10 | ![Comparisons with our baseline](https://llm-grounded-video-diffusion.github.io/teaser.jpg) 11 | 12 | ## Updates 13 | **[2024.4]** From the IGLIGEN project that offers a modern GLIGEN training codebase, several GLIGEN adapters are trained for image and video generation! This can be directly plugged into our DSL-grounded video generator and is compatible with the cross-attention control introduced in our paper. We offer huggingface diffusers integration that works off-the-shelf (**no repo clone or additional packages needed**)! Check out [this demo colab notebook](https://colab.research.google.com/drive/17He4bFAF8lXmT9Nfv-Sg29iKtPelDUNZ) for box-conditioned Modelscope with IGLIGEN adapters. See more details and comparisons for how to use it with this repo [here](#example-videos). 14 | 15 | # LLM-grounded Video Diffusion Models (LVD) 16 | The codebase is based on [LLM-grounded Diffusion (LMD)](https://github.com/TonyLianLong/LLM-groundedDiffusion) repo, so please also check out the instructions and FAQs there if your issues are not covered in this repo. 17 | 18 | ## Installation 19 | Install the dependencies: 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## Stage 1: Generating Dynamic Scene Layouts (DSLs) from text 25 | **Note that we have uploaded the layout caches for the benchmark onto this repo so that you can skip this step if you don't need layouts for new prompts (i.e., just want to use LLM-generated layouts for benchmarking stage 2).** 26 | 27 | Since we have cached the layout generation (which will be downloaded when you clone the repo), **you need to remove the cache in `cache` directory if you want to re-generate the layout with the same prompts**. 28 | 29 | **Our layout generation format:** The LLM takes in a text prompt describing the image and outputs dynamic scene layouts that consist of three elements: **1.** a reasoning statement for analysis, **2.** a captioned box for each object in each frame, and **3.** a background prompt. The template and example prompts are in [prompt.py](prompt.py). You can edit the template, the example prompts, and the parsing function to ask the LLM to generate additional things or even perform chain-of-thought for better generation. 30 | 31 | ### Automated query from OpenAI API 32 | Again, if you just want to evaluate stage 2 (layout to video stage), you can skip stage 1 as we have uploaded the layout caches onto this repo. **You don't need an OpenAI API key in stage 2.** 33 | 34 | If you have an [OpenAI API key](https://openai.com/blog/openai-api), you can put the API key in `utils/api_key.py` or set `OPENAI_API_KEY` environment variable. Then you can use OpenAI's API for batch text-to-layout generation by querying an LLM, with GPT-4 as an example: 35 | ```shell 36 | python prompt_batch.py --prompt-type demo --model gpt-4-1106-preview --auto-query --always-save --template_version v0.1 37 | ``` 38 | `--prompt-type demo` includes a few prompts for demonstrations. You can change them in [prompt.py](prompt.py). The layout generation will be cached so it does not query the LLM again with the same prompt (lowers the cost). 39 | 40 | You can visualize the dynamic scene layouts in the form of bounding boxes in `img_generations/imgs_demo_templatev0.1`. They are saved as `gif` files. For horizontal video generation with zeroscope, the square layout will be scaled according to the video aspect ratio. 41 | 42 | ### Run our benchmark on text-to-layout generation evaluation 43 | We provide a benchmark that applies both to stage 1 and stage 2. This benchmarks includes a set of prompts with five tasks (numeracy, attribution, visibility, dynamic satial, and sequential) as well as unified benchmarking code for all implemented methods and both stages. 44 | 45 | This will generate layouts from the prompts in the benchmark (with `--prompt-type lvd`) and evaluate the results: 46 | ```shell 47 | python prompt_batch.py --prompt-type lvd --model gpt-4-1106-preview --auto-query --always-save --template_version v0.1 48 | python scripts/eval_stage_one.py --prompt-type lvd --model gpt-4-1106-preview --template_version v0.1 49 | ``` 50 |
51 | Our reference benchmark results (stage 1, evaluating the generated layouts only) 52 | 53 | | Method | Numeracy | Attribution | Visibility | Dynamics | Sequential | Overall | 54 | | -------- | -------- | ----------- | ---------- | -------- | ---------- | ---------- | 55 | | GPT-3.5 | 100 | 100 | 100 | 71 | 16 | 77% | 56 | | GPT-3.5* | 100 | 100 | 100 | 73 | 15 | 78% | 57 | | GPT-4 | 100 | 100 | 100 | 100 | 88 | **98%** | 58 | 59 | \* The generated cache in this repo comes from this rerun. It differs marginally (1%) from the run that we reported. 60 |
61 | 62 | ## Stage 2: Generating Videos from Dynamic Scene Layouts 63 | Note that since we provide caches for stage 1, you don't need to run stage 1 on your own for cached prompts that we provide (i.e., you don't need an OpenAI API key or to query an LLM). 64 | 65 | Run layout-to-video generation using the gpt-4 cache and **LVD with Zeroscope** (resolution 576x320, 24 frames\*): 66 | ```shell 67 | # Zeroscope (horizontal videos) 68 | python generate.py --model gpt-4-1106-preview --run-model lvd_zeroscope --prompt-type demo --save-suffix weak_guidance --template_version v0.1 --seed_offset 0 --repeats 10 --loss_scale 2.5 --loss_threshold 350. --max_iter 1 --max_index_step 10 --fg_top_p 0.25 --bg_top_p 0.25 --fg_weight 1.0 --bg_weight 2.0 --num_frames 24 69 | ``` 70 | 71 | Run video generation with **LVD with Modelscope** (resolution 256x256, 16 frames\*): 72 | ```shell 73 | # Modelscope (square videos) 74 | python generate.py --model gpt-4-1106-preview --run-model lvd_modelscope256 --prompt-type demo --save-suffix weak_guidance --template_version v0.1 --seed_offset 0 --repeats 10 --loss_scale 2.5 --loss_threshold 250. --max_iter 1 --max_index_step 10 --fg_top_p 0.25 --bg_top_p 0.25 --fg_weight 1.0 --bg_weight 2.0 --num_frames 16 75 | ``` 76 | 77 | \* For context, [Zeroscope](https://huggingface.co/cerspense/zeroscope_v2_576w) is a model fine-tuned from [Modelscope](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) that generates horizontal videos without watermark. In contrast, Modelscope generates square videos, and Modelscope's generated videos often come with watermark. 78 | 79 | ### Update: LVD with GLIGEN adapters for lower inference costs 80 | Similar to LMD+, you can also integrate GLIGEN adapters trained in the [IGLIGEN project](https://github.com/TonyLianLong/igligen) with Modelscope in stage 2. Using GLIGEN adapters on Modelscope **requires less memory and is faster than training-free cross-attention control**. It could run on a T4 GPU on Colab. To use IGLIGEN Modelscope/Zeroscope adapters as our stage 2, you can use this command: 81 | ```shell 82 | # Zeroscope (horizontal videos) 83 | python generate.py --model gpt-4-1106-preview --run-model lvd-gligen_zeroscope --prompt-type demo --save-suffix lvd_gligen --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 24 --gligen_scheduled_sampling_beta 0.4 84 | 85 | # Modelscope (square videos) 86 | python generate.py --model gpt-4-1106-preview --run-model lvd-gligen_modelscope256 --prompt-type demo --save-suffix lvd_gligen --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 16 --gligen_scheduled_sampling_beta 0.4 87 | ``` 88 | 89 | Training-based methods such as GLIGEN typiclly have better spatial control, but sometimes can lead to different interpretations of words w.r.t. the base diffusion model or limited diversity for rare objects. **We recommend trying out both this variant and the original cross-attention-based control variant to get the best of both variants.** 90 | 91 | ### Baselines 92 |
93 | Run baselines to compare the generated video 94 | 95 | Run Zeroscope baseline: 96 | ```shell 97 | python generate.py --model gpt-4-1106-preview --run-model zeroscope --prompt-type demo --save-suffix baseline --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 24 98 | ``` 99 | 100 | Run Modelscope baseline: 101 | ```shell 102 | python generate.py --model gpt-4-1106-preview --run-model modelscope_256 --prompt-type demo --save-suffix baseline --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 16 103 | ``` 104 |
105 | 106 | You can use `--save-suffix` to specify the suffix added to the name of the run. `--run-model` specifies the method to run. You can set to LVD/LVD-GLIGEN or the implemented baselines. 107 | 108 | **Note:** there is a tradeoff between the strength of the control and the overall quality of the image with cross-attention control. The commands above with suffix `weak_guidance` trades control strength for lower level of artifacts. See the note in the benchmark section in this README for details. 109 | 110 | ### Saving formats 111 | The generated videos are saved in `img_generations/imgs_demo_templatev0.1_gpt-4-1106-preview_lvd_zeroscope_weak_guidance` (or other directories in `img_generations` if you run other commands). Note that the script saves the video in two formats: **1.** a `gif` file is saved for each video for quick and easy visualizations and debugging, **2.** a `joblib` file that has the original uncompressed numpy array is saved. **Note that `gif` files only support up to 256 colors, so `joblib` is recommended for future processing/export.** You can easily load with `joblib` and export to `mp4` or do other analysis/visualizations. 112 | 113 | ### Upsampling to get high-resolution videos 114 | This repo supports both [Zeroscope v2 XL](https://huggingface.co/cerspense/zeroscope_v2_XL) and [Stable Diffusion XL refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) upsampling. 115 | 116 | Note that SDXL performs upsampling per frame, so it's likely to have some jittering. We recommend using Zeroscope v2 XL upsampler to upsample videos generated by Zeroscope and LVD with Zeroscope. 117 | 118 | ```shell 119 | python scripts/upsample.py --videos path_to_joblib_video --prompts "prompt for the video" --use_zsxl --horizontal --output-mp4 120 | ``` 121 | 122 | You can use `--output-mp4` to also output `mp4` video. Otherwise only `gif` and `joblib` files will be saved. 123 | 124 | ### Example videos 125 | Example videos generated with the scripts above with prompt `A bear walking from the left to the right`: 126 | | LVD (Ours, using cross-attention control) | LVD-GLIGEN (using IGLIGEN adapters) | Baseline | 127 | | ---- | -------- | ---- | 128 | | ![Example Video Demo: LVD on Zeroscope](assets/example_video_lvd_zeroscope.gif) | ![Example Video Demo: LVD on Zeroscope](assets/example_video_lvd_gligen_zeroscope.gif) | ![Example Video Demo: Zeroscope baseline](assets/example_video_zeroscope_baseline.gif) | 129 | 130 | As you can see, our method leads to improved prompt understanding, while the baseline does not understand "walking from the left to the right" correctly. 131 | 132 | ### Run our benchmark on layout-to-image generation evaluation 133 | We use a unified evaluation metric as stage 1 in stage 2 (`--prompt-type lvd`). Since we have layout boxes for stage 1 but only images for stage 2, we use OWL-ViT in order to detect the objects and ensure they are generated (or not generated in negation) in the right number, with the right attributes, and in the right place. 134 | 135 | This runs generation with LVD and evaluate the generation (resolution 256x256): 136 | ```shell 137 | # Use GPT-4 layouts as an example 138 | 139 | ## Generation 1 140 | python generate.py --model "gpt-4-1106-preview" --force_run_ind 0 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 0 --com_loss_scale 0.03 141 | python scripts/eval_owl_vit.py --run_base_path ./img_generations/imgs_lvd_templatev0.1_gpt-4-1106-preview_lvd_modelscope256/run0 --skip_first_prompts 0 --verbose --detection_score_threshold 0.1 --nms_threshold 0.1 --class-aware-nms --num_eval_frames 6 --prompt-type lvd 142 | 143 | ## Generation 2 144 | python generate.py --model "gpt-4-1106-preview" --force_run_ind 1 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 500 --com_loss_scale 0.03 145 | python scripts/eval_owl_vit.py --run_base_path ./img_generations/imgs_lvd_templatev0.1_gpt-4-1106-preview_lvd_modelscope256/run1 --skip_first_prompts 0 --verbose --detection_score_threshold 0.1 --nms_threshold 0.1 --class-aware-nms --num_eval_frames 6 --prompt-type lvd 146 | ``` 147 | 148 | Each generation run will lead to 500 videos. Then, please take the average of the two generation runs, **taking 1000 videos into account** (i.e., two videos per generated layout). **The script supports resuming**, and you can simply rerun the same command to resume generation with `--force_run_ind` set (otherwise it will save into a new directory). You can also run these two runs in parallel. Check out `--skip_first_prompts` and `--num_prompts` in [generate.py](generate.py) to parallelize each run across GPUs. 149 | 150 | **Note:** there is a tradeoff between the strength of the control and the overall quality of the image with cross-attention control. The hyperparams used in benchmark generation are tuned for the benchmarks, and we observe quality degredation for some of the generated videos, despite with the correct objects generated. We recommend using lower guidance strength for non-benchmarking purpose so that even though some cases the objects may not align closely to the layout, the composition often turns out to be more natural. Similarly, for LVD with IGLIGEN adapters, `gligen_scheduled_sampling_beta` controls the strength of conditioning. 151 | 152 | ##### Our reference benchmark results 153 | | Method | Numeracy | Attribution | Visibility | Dynamics | Sequential | Overall | 154 | | ------------- | -------- | ----------- | ---------- | -------- | ---------- | ---------- | 155 | | ModelScope | 32 | 54 | 8 | 21 | 0 | 23.0 | 156 | | LVD (GPT-3.5) | 52 | 79 | 64 | 37 | 2 | 46.4% | 157 | | LVD (GPT-4) | 41 | 64 | 55 | 51 | 38 | **49.4%** | 158 | 159 | ## Contact us 160 | Please contact Long (Tony) Lian if you have any questions: `longlian@berkeley.edu`. 161 | 162 | ## Acknowledgements 163 | This repo is based on [LMD repo](https://github.com/TonyLianLong/LLM-groundedDiffusion), which is based on [diffusers](https://huggingface.co/docs/diffusers/index) and references [GLIGEN](https://github.com/gligen/GLIGEN), [layout-guidance](https://github.com/silent-chen/layout-guidance). This repo uses the same license as LMD. 164 | 165 | ## Citation 166 | If you use our work or our implementation in this repo, or find them helpful, please consider giving a citation. 167 | ``` 168 | @inproceedings{lian2023llmgroundedvideo, 169 | title={LLM-grounded Video Diffusion Models}, 170 | author={Lian, Long and Shi, Baifeng and Yala, Adam and Darrell, Trevor and Li, Boyi}, 171 | booktitle={The Twelfth International Conference on Learning Representations}, 172 | year={2023} 173 | } 174 | 175 | @article{lian2023llmgrounded, 176 | title={LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models}, 177 | author={Lian, Long and Li, Boyi and Yala, Adam and Darrell, Trevor}, 178 | journal={arXiv preprint arXiv:2305.13655}, 179 | year={2023} 180 | } 181 | ``` 182 | -------------------------------------------------------------------------------- /assets/example_video_lvd_gligen_zeroscope.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_lvd_gligen_zeroscope.gif -------------------------------------------------------------------------------- /assets/example_video_lvd_zeroscope.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_lvd_zeroscope.gif -------------------------------------------------------------------------------- /assets/example_video_zeroscope_baseline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_zeroscope_baseline.gif -------------------------------------------------------------------------------- /cache/cache_demo_v0.1_gpt-4-1106-preview.json: -------------------------------------------------------------------------------- 1 | { 2 | "A bear walks from the left to the right": [ 3 | "Reasoning: A bear walking from left to right will have its x-coordinate increasing with each frame, while its y-coordinate should remain relatively constant as it is walking on a flat surface.\n\nFrame 1: [{'id': 0, 'name': 'bear', 'box': [0, 256, 100, 100]}]\nFrame 2: [{'id': 0, 'name': 'bear', 'box': [85, 256, 100, 100]}]\nFrame 3: [{'id': 0, 'name': 'bear', 'box': [170, 256, 100, 100]}]\nFrame 4: [{'id': 0, 'name': 'bear', 'box': [255, 256, 100, 100]}]\nFrame 5: [{'id': 0, 'name': 'bear', 'box': [340, 256, 100, 100]}]\nFrame 6: [{'id': 0, 'name': 'bear', 'box': [425, 256, 100, 100]}]\nBackground keyword: forest" 4 | ] 5 | } -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from utils import parse, vis, cache 2 | from utils.llm import get_full_model_name, model_names, get_parsed_layout 3 | from utils.parse import show_video_boxes, size 4 | from tqdm import tqdm 5 | import os 6 | from prompt import get_prompts, template_versions 7 | import matplotlib.pyplot as plt 8 | import traceback 9 | import bdb 10 | import time 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--save-suffix", default=None, type=str) 15 | parser.add_argument( 16 | "--model", 17 | choices=model_names, 18 | required=True, 19 | help="LLM model to load the cache from", 20 | ) 21 | parser.add_argument( 22 | "--repeats", default=1, type=int, help="Number of samples for each prompt" 23 | ) 24 | parser.add_argument( 25 | "--regenerate", 26 | default=1, 27 | type=int, 28 | help="Number of regenerations. Different from repeats, regeneration happens after everything is generated", 29 | ) 30 | parser.add_argument( 31 | "--force_run_ind", 32 | default=None, 33 | type=int, 34 | help="If this is enabled, we use this run_ind and skips generated images. If this is not enabled, we create a new run after existing runs.", 35 | ) 36 | parser.add_argument( 37 | "--skip_first_prompts", 38 | default=0, 39 | type=int, 40 | help="Skip the first prompts in generation (useful for parallel generation)", 41 | ) 42 | parser.add_argument( 43 | "--seed_offset", 44 | default=0, 45 | type=int, 46 | help="Offset to the seed (seed starts from this number)", 47 | ) 48 | parser.add_argument( 49 | "--num_prompts", 50 | default=None, 51 | type=int, 52 | help="The number of prompts to generate (useful for parallel generation)", 53 | ) 54 | parser.add_argument( 55 | "--run-model", 56 | default="lvd", 57 | choices=[ 58 | "lvd", 59 | "lvd_zeroscope", 60 | "lvd_modelscope256", 61 | "lvd-gligen_modelscope256", 62 | "lvd-gligen_zeroscope", 63 | "lvd-plus_modelscope256", 64 | "lvd_modelscope512", 65 | "modelscope", 66 | "modelscope_256", 67 | "zeroscope", 68 | "zeroscope_xl", 69 | ], 70 | help="The model to use (modelscope has the option to generate with resolution 256x256)", 71 | ) 72 | parser.add_argument("--visualize", action="store_true") 73 | parser.add_argument("--no-continue-on-error", action="store_true") 74 | parser.add_argument("--prompt-type", type=str, default="demo") 75 | parser.add_argument("--template_version", choices=template_versions, required=True) 76 | parser.add_argument("--dry-run", action="store_true", help="skip the generation") 77 | 78 | float_args = [ 79 | "fg_top_p", 80 | "bg_top_p", 81 | "fg_weight", 82 | "bg_weight", 83 | "loss_threshold", 84 | "loss_scale", 85 | "boxdiff_loss_scale", 86 | "com_loss_scale", 87 | "gligen_scheduled_sampling_beta", 88 | ] 89 | for float_arg in float_args: 90 | parser.add_argument("--" + float_arg, default=None, type=float) 91 | 92 | # `use_ratio_based_loss` should be 0 or 1 (as it is a bool) 93 | int_args = [ 94 | "num_inference_steps", 95 | "max_iter", 96 | "max_index_step", 97 | "num_frames", 98 | "use_ratio_based_loss", 99 | "boxdiff_normed", 100 | ] 101 | for int_arg in int_args: 102 | parser.add_argument("--" + int_arg, default=None, type=int) 103 | 104 | str_args = [] 105 | for str_arg in str_args: 106 | parser.add_argument("--" + str_arg, default=None, type=str) 107 | 108 | args = parser.parse_args() 109 | 110 | 111 | if not args.dry_run: 112 | run_model = args.run_model 113 | baseline = run_model in [ 114 | "modelscope", 115 | "zeroscope", 116 | "modelscope_256", 117 | "zeroscope_xl", 118 | ] 119 | 120 | if "_" in run_model: 121 | option = run_model.split("_")[1] 122 | else: 123 | option = "" 124 | 125 | if run_model.startswith("lvd-plus"): 126 | import generation.lvd_plus as generation 127 | 128 | base_model = option if option else "modelscope" 129 | H, W = generation.init(base_model=base_model) 130 | elif run_model.startswith("lvd-gligen"): 131 | import generation.lvd_gligen as generation 132 | 133 | base_model = option if option else "modelscope" 134 | H, W = generation.init(base_model=base_model) 135 | elif (run_model == "lvd") or (run_model.startswith("lvd_")): 136 | import generation.lvd as generation 137 | 138 | # Use modelscope as the default model 139 | base_model = option if option else "modelscope" 140 | H, W = generation.init(base_model=base_model) 141 | elif run_model == "modelscope" or run_model == "modelscope_256": 142 | import generation.modelscope_dpm as generation 143 | 144 | H, W = generation.init(option=option) 145 | elif run_model == "zeroscope" or run_model == "zeroscope_xl": 146 | import generation.zeroscope_dpm as generation 147 | 148 | H, W = generation.init(option=option) 149 | else: 150 | raise ValueError(f"Unknown model: {run_model}") 151 | 152 | if "zeroscope" in run_model and ( 153 | (args.num_frames is not None and args.num_frames < 24) 154 | or ((not baseline) and args.num_frames is None) 155 | ): 156 | # num_frames is 16 by default in non-baseline models (as it uses modelscope by default) 157 | raise ValueError( 158 | "Running zeroscope with fewer than 24 frames. This may lead to suboptimal results. Comment this out if you still want to run." 159 | ) 160 | 161 | version = generation.version 162 | assert ( 163 | version == args.run_model.split("_")[0] 164 | ), f"{version} != {args.run_model.split('_')[0]}" 165 | run = generation.run 166 | else: 167 | version = "dry_run" 168 | run = None 169 | generation = argparse.Namespace() 170 | 171 | # set visualizations to no-op in batch generation 172 | for k in vis.__dict__.keys(): 173 | if k.startswith("visualize"): 174 | vis.__dict__[k] = lambda *args, **kwargs: None 175 | 176 | 177 | ## Visualize 178 | def visualize_layout(parsed_layout): 179 | H, W = size 180 | condition = parse.parsed_layout_to_condition( 181 | parsed_layout, tokenizer=None, height=H, width=W, verbose=True 182 | ) 183 | 184 | show_video_boxes(condition, ind=ind, save=True) 185 | 186 | print(f"Visualize masks at {parse.img_dir}") 187 | 188 | 189 | # close the figure when plt.show is called 190 | plt.show = plt.close 191 | 192 | prompt_type = args.prompt_type 193 | template_version = args.template_version 194 | json_template = "json" in template_version 195 | 196 | # Use cache 197 | model = get_full_model_name(model=args.model) 198 | 199 | if not baseline: 200 | cache.cache_format = "json" 201 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}_{template_version}_{model}.json' 202 | print(f"Loading LLM responses from cache {cache.cache_path}") 203 | cache.init_cache(allow_nonexist=False) 204 | 205 | prompts = get_prompts(prompt_type) 206 | 207 | save_suffix = ("_" + args.save_suffix) if args.save_suffix else "" 208 | repeats = args.repeats 209 | seed_offset = args.seed_offset 210 | 211 | model_in_base_save_dir = "" if model == "gpt-4" else f"_{model}" 212 | base_save_dir = f"img_generations/imgs_{prompt_type}_template{args.template_version}{model_in_base_save_dir}_{run_model}{save_suffix}" 213 | 214 | run_kwargs = {} 215 | 216 | argnames = float_args + int_args + str_args 217 | 218 | for argname in argnames: 219 | argvalue = getattr(args, argname) 220 | if argvalue is not None: 221 | run_kwargs[argname] = argvalue 222 | 223 | is_notebook = False 224 | 225 | if args.force_run_ind is not None: 226 | run_ind = args.force_run_ind 227 | save_dir = f"{base_save_dir}/run{run_ind}" 228 | else: 229 | run_ind = 0 230 | while True: 231 | save_dir = f"{base_save_dir}/run{run_ind}" 232 | if not os.path.exists(save_dir): 233 | break 234 | run_ind += 1 235 | 236 | if hasattr(generation, "use_autocast") and generation.use_autocast: 237 | save_dir += "_amp" 238 | 239 | print(f"Save dir: {save_dir}") 240 | 241 | LARGE_CONSTANT = 123456789 242 | LARGE_CONSTANT2 = 56789 243 | LARGE_CONSTANT3 = 6789 244 | 245 | ind = 0 246 | if args.regenerate > 1: 247 | # Need to fix the ind 248 | assert args.skip_first_prompts == 0 249 | 250 | for regenerate_ind in range(args.regenerate): 251 | print("regenerate_ind:", regenerate_ind) 252 | if not baseline: 253 | cache.reset_cache_access() 254 | for prompt_ind, prompt in enumerate(tqdm(prompts, desc=f"Run: {save_dir}")): 255 | if prompt_ind < args.skip_first_prompts: 256 | ind += 1 257 | continue 258 | if args.num_prompts is not None and prompt_ind >= ( 259 | args.skip_first_prompts + args.num_prompts 260 | ): 261 | ind += 1 262 | continue 263 | 264 | # get prompt from prompts, if prompt is a list, then prompt includes both the prompt and kwargs 265 | if isinstance(prompt, list): 266 | prompt, kwargs = prompt 267 | else: 268 | kwargs = {} 269 | 270 | prompt = prompt.strip().rstrip(".") 271 | 272 | ind_override = kwargs.get("seed", None) 273 | 274 | # Load from cache 275 | if baseline: 276 | resp = None 277 | else: 278 | resp = cache.get_cache(prompt) 279 | 280 | if resp is None: 281 | print(f"Cache miss, skipping prompt: {prompt}") 282 | ind += 1 283 | continue 284 | 285 | print(f"***run: {run_ind}***") 286 | print(f"prompt: {prompt}, resp: {resp}") 287 | parse.img_dir = f"{save_dir}/{ind}" 288 | # Skip if image is already generared 289 | if not ( 290 | os.path.exists(parse.img_dir) 291 | and len( 292 | [ 293 | img 294 | for img in os.listdir(parse.img_dir) 295 | if img.startswith("video") and img.endswith("joblib") 296 | ] 297 | ) 298 | >= args.repeats 299 | ): 300 | os.makedirs(parse.img_dir, exist_ok=True) 301 | try: 302 | if baseline: 303 | parsed_layout = {"Prompt": prompt} 304 | else: 305 | parsed_layout, _ = get_parsed_layout( 306 | prompt, 307 | max_partial_response_retries=1, 308 | override_response=resp, 309 | json_template=json_template, 310 | ) 311 | 312 | print("parsed_layout:", parsed_layout) 313 | 314 | if args.dry_run: 315 | # Skip generation 316 | ind += 1 317 | continue 318 | 319 | if args.visualize: 320 | assert ( 321 | not baseline 322 | ), "baseline methods do not have layouts from the LLM to visualize" 323 | visualize_layout(parsed_layout) 324 | 325 | original_ind_base = ( 326 | ind_override + regenerate_ind * LARGE_CONSTANT2 327 | if ind_override is not None 328 | else ind 329 | ) 330 | 331 | for repeat_ind in range(repeats): 332 | ind_offset = repeat_ind * LARGE_CONSTANT3 + seed_offset 333 | run( 334 | parsed_layout, 335 | seed=original_ind_base + ind_offset, 336 | repeat_ind=repeat_ind, 337 | **run_kwargs, 338 | ) 339 | 340 | except (KeyboardInterrupt, bdb.BdbQuit) as e: 341 | print(e) 342 | exit() 343 | except RuntimeError: 344 | print( 345 | "***RuntimeError: might run out of memory, skipping the current one***" 346 | ) 347 | print(traceback.format_exc()) 348 | time.sleep(10) 349 | except Exception as e: 350 | print(f"***Error: {e}***") 351 | print(traceback.format_exc()) 352 | if args.no_continue_on_error: 353 | raise e 354 | else: 355 | print(f"Image exists at {parse.img_dir}, skipping") 356 | ind += 1 357 | 358 | if not baseline and cache.values_accessed() != len(prompts): 359 | print( 360 | f"**Cache is hit {cache.values_accessed()} time(s) but we have {len(prompts)} prompts. There may be cache misses or inconsistencies between the prompts and the cache such as extra items in the cache.**" 361 | ) 362 | -------------------------------------------------------------------------------- /generation/lvd.py: -------------------------------------------------------------------------------- 1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline 2 | from diffusers import DPMSolverMultistepScheduler 3 | from models.unet_3d_condition import UNet3DConditionModel 4 | from utils import parse, vis 5 | from prompt import negative_prompt 6 | import utils 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | import os 11 | 12 | version = "lvd" 13 | 14 | # %% 15 | # H, W are generation H and W. box_W and box_W are for scaling the boxes to [0, 1]. 16 | pipe, base_attn_dim, H, W, box_H, box_W = None, None, None, None, None, None 17 | 18 | 19 | def init(base_model): 20 | global pipe, base_attn_dim, H, W, box_H, box_W 21 | if base_model == "modelscope512": 22 | model_key = "damo-vilab/text-to-video-ms-1.7b" 23 | base_attn_dim = (64, 64) 24 | H, W = 512, 512 25 | box_H, box_W = parse.size 26 | elif base_model == "modelscope256": 27 | model_key = "damo-vilab/text-to-video-ms-1.7b" 28 | base_attn_dim = (32, 32) 29 | H, W = 256, 256 30 | box_H, box_W = parse.size 31 | elif base_model == "zeroscope": 32 | model_key = "cerspense/zeroscope_v2_576w" 33 | base_attn_dim = (40, 72) 34 | H, W = 320, 576 35 | box_H, box_W = parse.size 36 | else: 37 | raise ValueError(f"Unknown base model: {base_model}") 38 | 39 | unet = UNet3DConditionModel.from_pretrained(model_key, subfolder="unet").to( 40 | torch.float16 41 | ) 42 | pipe = TextToVideoSDPipeline.from_pretrained( 43 | model_key, unet=unet, torch_dtype=torch.float16 44 | ) 45 | # The default one is DDIMScheduler 46 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 47 | pipe.to("cuda") 48 | pipe.enable_vae_slicing() 49 | 50 | # No auxiliary guidance 51 | pipe.guidance_models = None 52 | 53 | return H, W 54 | 55 | 56 | # %% 57 | upsample_scale, upsample_mode = 1, "bilinear" 58 | 59 | # %% 60 | # For visualizations: set to False to not save 61 | return_guidance_saved_attn = False 62 | # This is the main attn, not the attn for guidance. 63 | save_keys = [] 64 | 65 | # %% 66 | overall_guidance_attn_keys = [ 67 | ("down", 1, 0, 0), 68 | ("down", 2, 0, 0), 69 | ("down", 2, 1, 0), 70 | ("up", 1, 0, 0), 71 | ("up", 1, 1, 0), 72 | ("up", 2, 2, 0), 73 | ] 74 | 75 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn 76 | cross_attention_kwargs = { 77 | "save_attn_to_dict": {}, 78 | "save_keys": save_keys, 79 | # This is for visualizations 80 | # 'offload_cross_attn_to_cpu': True 81 | } 82 | 83 | 84 | # %% 85 | def run( 86 | parsed_layout, 87 | seed, 88 | num_inference_steps=40, 89 | num_frames=16, 90 | repeat_ind=None, 91 | save_annotated_videos=False, 92 | loss_scale=5.0, 93 | loss_threshold=200.0, 94 | max_iter=5, 95 | max_index_step=10, 96 | fg_top_p=0.75, 97 | bg_top_p=0.75, 98 | fg_weight=1.0, 99 | bg_weight=4.0, 100 | attn_sync_weight=0.0, 101 | boxdiff_loss_scale=0.0, 102 | boxdiff_normed=True, 103 | com_loss_scale=0.0, 104 | use_ratio_based_loss=False, 105 | save_formats=["gif", "joblib"], 106 | ): 107 | condition = parse.parsed_layout_to_condition( 108 | parsed_layout, 109 | tokenizer=pipe.tokenizer, 110 | height=box_H, 111 | width=box_W, 112 | num_condition_frames=num_frames, 113 | verbose=True, 114 | ) 115 | prompt, bboxes, phrases, object_positions, token_map = ( 116 | condition.prompt, 117 | condition.boxes, 118 | condition.phrases, 119 | condition.object_positions, 120 | condition.token_map, 121 | ) 122 | 123 | backward_guidance_kwargs = dict( 124 | bboxes=bboxes, 125 | object_positions=object_positions, 126 | loss_scale=loss_scale, 127 | loss_threshold=loss_threshold, 128 | max_iter=max_iter, 129 | max_index_step=max_index_step, 130 | fg_top_p=fg_top_p, 131 | bg_top_p=bg_top_p, 132 | fg_weight=fg_weight, 133 | bg_weight=bg_weight, 134 | use_ratio_based_loss=use_ratio_based_loss, 135 | guidance_attn_keys=overall_guidance_attn_keys, 136 | exclude_bg_heads=False, 137 | upsample_scale=upsample_scale, 138 | upsample_mode=upsample_mode, 139 | base_attn_dim=base_attn_dim, 140 | attn_sync_weight=attn_sync_weight, 141 | boxdiff_loss_scale=boxdiff_loss_scale, 142 | boxdiff_normed=boxdiff_normed, 143 | com_loss_scale=com_loss_scale, 144 | verbose=True, 145 | ) 146 | 147 | if repeat_ind is not None: 148 | save_suffix = repeat_ind 149 | 150 | else: 151 | save_suffix = f"seed{seed}" 152 | 153 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif" 154 | if os.path.exists(save_path): 155 | print(f"Skipping {save_path}") 156 | return 157 | 158 | print("Generating") 159 | generator = torch.Generator(device="cuda").manual_seed(seed) 160 | 161 | video_frames = pipe( 162 | prompt, 163 | negative_prompt=negative_prompt, 164 | num_inference_steps=num_inference_steps, 165 | height=H, 166 | width=W, 167 | num_frames=num_frames, 168 | cross_attention_kwargs=cross_attention_kwargs, 169 | generator=generator, 170 | guidance_callback=None, 171 | backward_guidance_kwargs=backward_guidance_kwargs, 172 | return_guidance_saved_attn=return_guidance_saved_attn, 173 | guidance_type="main", 174 | ).frames 175 | video_frames = (video_frames[0] * 255.0).astype(np.uint8) 176 | 177 | # %% 178 | 179 | if save_annotated_videos: 180 | annotated_frames = [ 181 | np.array( 182 | utils.draw_box( 183 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases 184 | ) 185 | ) 186 | for i, video_frame in enumerate(video_frames) 187 | ] 188 | vis.save_frames( 189 | f"{save_path}/video_seed{seed}_with_box", 190 | frames=annotated_frames, 191 | formats="gif", 192 | ) 193 | 194 | vis.save_frames( 195 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats 196 | ) 197 | -------------------------------------------------------------------------------- /generation/lvd_gligen.py: -------------------------------------------------------------------------------- 1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline 2 | from diffusers import DPMSolverMultistepScheduler 3 | from models.unet_3d_condition import UNet3DConditionModel 4 | from utils import parse, vis 5 | from prompt import negative_prompt 6 | import utils 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | import os 11 | 12 | version = "lvd-gligen" 13 | 14 | # %% 15 | # H, W are generation H and W. box_W and box_W are for scaling the boxes to [0, 1]. 16 | pipe, H, W, box_H, box_W = None, None, None, None, None 17 | 18 | 19 | def init(base_model): 20 | global pipe, H, W, box_H, box_W 21 | if base_model == "modelscope256": 22 | model_key = "longlian/text-to-video-lvd-ms" 23 | H, W = 256, 256 24 | box_H, box_W = parse.size 25 | elif base_model == "zeroscope": 26 | model_key = "longlian/text-to-video-lvd-zs" 27 | H, W = 320, 576 28 | box_H, box_W = parse.size 29 | else: 30 | raise ValueError(f"Unknown base model: {base_model}") 31 | 32 | pipe = TextToVideoSDPipeline.from_pretrained( 33 | model_key, trust_remote_code=True, torch_dtype=torch.float16 34 | ) 35 | # The default one is DDIMScheduler 36 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 37 | pipe.to("cuda") 38 | pipe.enable_vae_slicing() 39 | 40 | # No auxiliary guidance 41 | pipe.guidance_models = None 42 | 43 | return H, W 44 | 45 | 46 | # %% 47 | upsample_scale, upsample_mode = 1, "bilinear" 48 | 49 | # %% 50 | 51 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn 52 | cross_attention_kwargs = { 53 | # This is for visualizations 54 | # 'offload_cross_attn_to_cpu': True 55 | } 56 | 57 | 58 | # %% 59 | def run( 60 | parsed_layout, 61 | seed, 62 | num_inference_steps=40, 63 | num_frames=16, 64 | gligen_scheduled_sampling_beta=1.0, 65 | repeat_ind=None, 66 | save_annotated_videos=False, 67 | save_formats=["gif", "joblib"], 68 | ): 69 | condition = parse.parsed_layout_to_condition( 70 | parsed_layout, 71 | tokenizer=pipe.tokenizer, 72 | height=box_H, 73 | width=box_W, 74 | num_condition_frames=num_frames, 75 | verbose=True, 76 | ) 77 | prompt, bboxes, phrases, object_positions, token_map = ( 78 | condition.prompt, 79 | condition.boxes, 80 | condition.phrases, 81 | condition.object_positions, 82 | condition.token_map, 83 | ) 84 | 85 | if repeat_ind is not None: 86 | save_suffix = repeat_ind 87 | 88 | else: 89 | save_suffix = f"seed{seed}" 90 | 91 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif" 92 | if os.path.exists(save_path): 93 | print(f"Skipping {save_path}") 94 | return 95 | 96 | print("Generating") 97 | generator = torch.Generator(device="cuda").manual_seed(seed) 98 | 99 | lvd_gligen_boxes = [] 100 | lvd_gligen_phrases = [] 101 | for i in range(num_frames): 102 | lvd_gligen_boxes.append( 103 | [ 104 | bboxes_item[i] 105 | for phrase, bboxes_item in zip(phrases, bboxes) 106 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0] 107 | ] 108 | ) 109 | lvd_gligen_phrases.append( 110 | [ 111 | phrase 112 | for phrase, bboxes_item in zip(phrases, bboxes) 113 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0] 114 | ] 115 | ) 116 | 117 | video_frames = pipe( 118 | prompt, 119 | negative_prompt=negative_prompt, 120 | num_inference_steps=num_inference_steps, 121 | height=H, 122 | width=W, 123 | num_frames=num_frames, 124 | cross_attention_kwargs=cross_attention_kwargs, 125 | generator=generator, 126 | lvd_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, 127 | lvd_gligen_boxes=lvd_gligen_boxes, 128 | lvd_gligen_phrases=lvd_gligen_phrases, 129 | ).frames 130 | # `diffusers` has a backward-breaking change 131 | # video_frames = (video_frames[0] * 255.).astype(np.uint8) 132 | 133 | # %% 134 | 135 | if save_annotated_videos: 136 | annotated_frames = [ 137 | np.array( 138 | utils.draw_box( 139 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases 140 | ) 141 | ) 142 | for i, video_frame in enumerate(video_frames) 143 | ] 144 | vis.save_frames( 145 | f"{save_path}/video_seed{seed}_with_box", 146 | frames=annotated_frames, 147 | formats="gif", 148 | ) 149 | 150 | vis.save_frames( 151 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats 152 | ) 153 | -------------------------------------------------------------------------------- /generation/lvd_plus.py: -------------------------------------------------------------------------------- 1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline 2 | from diffusers import DPMSolverMultistepScheduler 3 | from models.unet_3d_condition import UNet3DConditionModel 4 | from utils import parse, vis 5 | from prompt import negative_prompt 6 | import utils 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | import os 11 | 12 | version = "lvd-plus" 13 | 14 | # %% 15 | # H, W are generation H and W. box_W and box_W are for scaling the boxes to [0, 1]. 16 | pipe, base_attn_dim, H, W, box_H, box_W = None, None, None, None, None, None 17 | 18 | 19 | def init(base_model): 20 | global pipe, base_attn_dim, H, W, box_H, box_W 21 | if base_model == "modelscope256": 22 | model_key = "longlian/text-to-video-lvd-ms" 23 | base_attn_dim = (32, 32) 24 | H, W = 256, 256 25 | box_H, box_W = parse.size 26 | else: 27 | raise ValueError(f"Unknown base model: {base_model}") 28 | 29 | unet = UNet3DConditionModel.from_pretrained( 30 | model_key, subfolder="unet", revision="weights_only" 31 | ).to(torch.float16) 32 | pipe = TextToVideoSDPipeline.from_pretrained( 33 | model_key, unet=unet, torch_dtype=torch.float16, revision="weights_only" 34 | ) 35 | # The default one is DDIMScheduler 36 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 37 | pipe.to("cuda") 38 | pipe.enable_vae_slicing() 39 | 40 | # No auxiliary guidance 41 | pipe.guidance_models = None 42 | 43 | return H, W 44 | 45 | 46 | # %% 47 | upsample_scale, upsample_mode = 1, "bilinear" 48 | 49 | # %% 50 | # For visualizations: set to False to not save 51 | return_guidance_saved_attn = False 52 | # This is the main attn, not the attn for guidance. 53 | save_keys = [] 54 | 55 | # %% 56 | overall_guidance_attn_keys = [ 57 | ("down", 1, 0, 0), 58 | ("down", 2, 0, 0), 59 | ("down", 2, 1, 0), 60 | ("up", 1, 0, 0), 61 | ("up", 1, 1, 0), 62 | ("up", 2, 2, 0), 63 | ] 64 | 65 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn 66 | cross_attention_kwargs = { 67 | "save_attn_to_dict": {}, 68 | "save_keys": save_keys, 69 | # This is for visualizations 70 | # 'offload_cross_attn_to_cpu': True 71 | } 72 | 73 | 74 | # %% 75 | def run( 76 | parsed_layout, 77 | seed, 78 | num_inference_steps=40, 79 | num_frames=16, 80 | gligen_scheduled_sampling_beta=1.0, 81 | repeat_ind=None, 82 | save_annotated_videos=False, 83 | loss_scale=5.0, 84 | loss_threshold=200.0, 85 | max_iter=5, 86 | max_index_step=10, 87 | fg_top_p=0.75, 88 | bg_top_p=0.75, 89 | fg_weight=1.0, 90 | bg_weight=4.0, 91 | attn_sync_weight=0.0, 92 | boxdiff_loss_scale=0.0, 93 | boxdiff_normed=True, 94 | com_loss_scale=0.0, 95 | use_ratio_based_loss=False, 96 | save_formats=["gif", "joblib"], 97 | ): 98 | condition = parse.parsed_layout_to_condition( 99 | parsed_layout, 100 | tokenizer=pipe.tokenizer, 101 | height=box_H, 102 | width=box_W, 103 | num_condition_frames=num_frames, 104 | verbose=True, 105 | ) 106 | prompt, bboxes, phrases, object_positions, token_map = ( 107 | condition.prompt, 108 | condition.boxes, 109 | condition.phrases, 110 | condition.object_positions, 111 | condition.token_map, 112 | ) 113 | 114 | backward_guidance_kwargs = dict( 115 | bboxes=bboxes, 116 | object_positions=object_positions, 117 | loss_scale=loss_scale, 118 | loss_threshold=loss_threshold, 119 | max_iter=max_iter, 120 | max_index_step=max_index_step, 121 | fg_top_p=fg_top_p, 122 | bg_top_p=bg_top_p, 123 | fg_weight=fg_weight, 124 | bg_weight=bg_weight, 125 | use_ratio_based_loss=use_ratio_based_loss, 126 | guidance_attn_keys=overall_guidance_attn_keys, 127 | exclude_bg_heads=False, 128 | upsample_scale=upsample_scale, 129 | upsample_mode=upsample_mode, 130 | base_attn_dim=base_attn_dim, 131 | attn_sync_weight=attn_sync_weight, 132 | boxdiff_loss_scale=boxdiff_loss_scale, 133 | boxdiff_normed=boxdiff_normed, 134 | com_loss_scale=com_loss_scale, 135 | verbose=True, 136 | ) 137 | 138 | if repeat_ind is not None: 139 | save_suffix = repeat_ind 140 | 141 | else: 142 | save_suffix = f"seed{seed}" 143 | 144 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif" 145 | if os.path.exists(save_path): 146 | print(f"Skipping {save_path}") 147 | return 148 | 149 | print("Generating") 150 | generator = torch.Generator(device="cuda").manual_seed(seed) 151 | 152 | # print(bboxes, phrases) 153 | 154 | lvd_gligen_boxes = [] 155 | lvd_gligen_phrases = [] 156 | for i in range(num_frames): 157 | lvd_gligen_boxes.append( 158 | [ 159 | bboxes_item[i] 160 | for phrase, bboxes_item in zip(phrases, bboxes) 161 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0] 162 | ] 163 | ) 164 | lvd_gligen_phrases.append( 165 | [ 166 | phrase 167 | for phrase, bboxes_item in zip(phrases, bboxes) 168 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0] 169 | ] 170 | ) 171 | 172 | video_frames = pipe( 173 | prompt, 174 | negative_prompt=negative_prompt, 175 | num_inference_steps=num_inference_steps, 176 | height=H, 177 | width=W, 178 | num_frames=num_frames, 179 | cross_attention_kwargs=cross_attention_kwargs, 180 | generator=generator, 181 | gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, 182 | gligen_boxes=lvd_gligen_boxes, 183 | gligen_phrases=lvd_gligen_phrases, 184 | guidance_callback=None, 185 | backward_guidance_kwargs=backward_guidance_kwargs, 186 | return_guidance_saved_attn=return_guidance_saved_attn, 187 | guidance_type="main", 188 | ).frames 189 | video_frames = (video_frames[0] * 255.0).astype(np.uint8) 190 | 191 | # %% 192 | 193 | if save_annotated_videos: 194 | annotated_frames = [ 195 | np.array( 196 | utils.draw_box( 197 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases 198 | ) 199 | ) 200 | for i, video_frame in enumerate(video_frames) 201 | ] 202 | vis.save_frames( 203 | f"{save_path}/video_seed{seed}_with_box", 204 | frames=annotated_frames, 205 | formats="gif", 206 | ) 207 | 208 | vis.save_frames( 209 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats 210 | ) 211 | -------------------------------------------------------------------------------- /generation/modelscope_dpm.py: -------------------------------------------------------------------------------- 1 | from diffusers import TextToVideoSDPipeline 2 | from diffusers import DPMSolverMultistepScheduler 3 | from utils import parse, vis 4 | from prompt import negative_prompt 5 | import torch 6 | import numpy as np 7 | import os 8 | 9 | version = "modelscope" 10 | 11 | # %% 12 | model_key = "damo-vilab/text-to-video-ms-1.7b" 13 | 14 | pipe = TextToVideoSDPipeline.from_pretrained(model_key, torch_dtype=torch.float16) 15 | # The default one is DDIMScheduler 16 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 17 | # pipe.unet = UNet3DConditionModel.from_pretrained( 18 | # model_key, subfolder="unet").to(torch.float16) 19 | pipe.to("cuda") 20 | pipe.enable_vae_slicing() 21 | # No auxiliary guidance 22 | pipe.guidance_models = None 23 | 24 | # %% 25 | H, W = None, None 26 | 27 | 28 | def init(option): 29 | global H, W 30 | if option == "": 31 | H, W = 512, 512 32 | elif option == "256": 33 | H, W = 256, 256 34 | else: 35 | raise ValueError(f"Unknown option: {option}") 36 | 37 | return H, W 38 | 39 | 40 | def run( 41 | parsed_layout, 42 | seed, 43 | *, 44 | num_inference_steps=40, 45 | num_frames=16, 46 | repeat_ind=None, 47 | save_formats=["gif", "joblib"], 48 | ): 49 | prompt = parsed_layout["Prompt"] 50 | 51 | if repeat_ind is not None: 52 | save_suffix = repeat_ind 53 | else: 54 | save_suffix = f"seed{seed}" 55 | 56 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif" 57 | if os.path.exists(save_path): 58 | print(f"Skipping {save_path}") 59 | return 60 | 61 | print("Generating") 62 | generator = torch.Generator(device="cuda").manual_seed(seed) 63 | 64 | video_frames = pipe( 65 | prompt, 66 | negative_prompt=negative_prompt, 67 | num_inference_steps=num_inference_steps, 68 | height=H, 69 | width=W, 70 | num_frames=num_frames, 71 | cross_attention_kwargs=None, 72 | generator=generator, 73 | ).frames 74 | 75 | video_frames = (video_frames[0] * 255.0).astype(np.uint8) 76 | 77 | # %% 78 | vis.save_frames( 79 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats 80 | ) 81 | -------------------------------------------------------------------------------- /generation/zeroscope_dpm.py: -------------------------------------------------------------------------------- 1 | from diffusers import TextToVideoSDPipeline, VideoToVideoSDPipeline 2 | from diffusers import DPMSolverMultistepScheduler 3 | from utils import parse, vis 4 | from prompt import negative_prompt 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | import os 9 | 10 | version = "zeroscope" 11 | 12 | # %% 13 | H, W = None, None 14 | 15 | pipe = TextToVideoSDPipeline.from_pretrained( 16 | "cerspense/zeroscope_v2_576w", 17 | # unet = UNet3DConditionModel.from_pretrained( 18 | # "cerspense/zeroscope_v2_576w", subfolder="unet" 19 | # ).to(torch.float16), 20 | torch_dtype=torch.float16, 21 | ) 22 | # The default one is DDIMScheduler 23 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 24 | pipe.to("cuda") 25 | pipe.enable_vae_slicing() 26 | pipe_xl = None 27 | 28 | 29 | def init(option): 30 | global pipe_xl, H, W 31 | 32 | if option == "": 33 | H, W = 320, 576 34 | elif option == "xl": 35 | # the base model is still in 320, 576. The xl model outputs (576, 1024). 36 | H, W = 320, 576 37 | 38 | pipe_xl = VideoToVideoSDPipeline.from_pretrained( 39 | "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16 40 | ) 41 | pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config( 42 | pipe.scheduler.config 43 | ) 44 | # pipe_xl.enable_model_cpu_offload() 45 | pipe_xl.to("cuda") 46 | pipe_xl.enable_vae_slicing() 47 | else: 48 | raise ValueError(f"Unknown option: {option}") 49 | # WIP 50 | return H, W 51 | 52 | 53 | # %% 54 | def run( 55 | parsed_layout, 56 | seed, 57 | *, 58 | num_inference_steps=40, 59 | num_frames=24, 60 | repeat_ind=None, 61 | save_formats=["gif", "joblib"], 62 | ): 63 | prompt = parsed_layout["Prompt"] 64 | 65 | if repeat_ind is not None: 66 | save_suffix = repeat_ind 67 | else: 68 | save_suffix = f"seed{seed}" 69 | 70 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif" 71 | if os.path.exists(save_path): 72 | print(f"Skipping {save_path}") 73 | return 74 | 75 | print("Generating") 76 | generator = torch.Generator(device="cuda").manual_seed(seed) 77 | 78 | video_frames = pipe( 79 | prompt, 80 | negative_prompt=negative_prompt, 81 | num_inference_steps=num_inference_steps, 82 | height=H, 83 | width=W, 84 | num_frames=num_frames, 85 | cross_attention_kwargs=None, 86 | generator=generator, 87 | ).frames 88 | video_frames = (video_frames[0] * 255.0).astype(np.uint8) 89 | 90 | if pipe_xl is not None: 91 | print("Refining") 92 | video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] 93 | video_frames_xl = pipe_xl( 94 | prompt, 95 | negative_prompt=negative_prompt, 96 | num_inference_steps=num_inference_steps, 97 | video=video, 98 | strength=0.6, 99 | generator=generator, 100 | ).frames 101 | 102 | video_frames = (video_frames[0] * 255.0).astype(np.uint8) 103 | 104 | print("Saving") 105 | vis.save_frames( 106 | f"{parse.img_dir}/video_xl_{save_suffix}", 107 | video_frames_xl, 108 | formats=save_formats, 109 | ) 110 | else: 111 | vis.save_frames( 112 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats 113 | ) 114 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # For compatibility 4 | from utils import torch_device 5 | 6 | 7 | def encode_prompts( 8 | tokenizer, 9 | text_encoder, 10 | prompts, 11 | negative_prompt="", 12 | return_full_only=False, 13 | one_uncond_input_only=False, 14 | ): 15 | if negative_prompt == "": 16 | print("Note that negative_prompt is an empty string") 17 | 18 | text_input = tokenizer( 19 | prompts, 20 | padding="max_length", 21 | max_length=tokenizer.model_max_length, 22 | truncation=True, 23 | return_tensors="pt", 24 | ) 25 | 26 | max_length = text_input.input_ids.shape[-1] 27 | if one_uncond_input_only: 28 | num_uncond_input = 1 29 | else: 30 | num_uncond_input = len(prompts) 31 | uncond_input = tokenizer( 32 | [negative_prompt] * num_uncond_input, 33 | padding="max_length", 34 | max_length=max_length, 35 | return_tensors="pt", 36 | ) 37 | 38 | with torch.no_grad(): 39 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 40 | cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 41 | 42 | if one_uncond_input_only: 43 | return uncond_embeddings, cond_embeddings 44 | 45 | text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) 46 | 47 | if return_full_only: 48 | return text_embeddings 49 | return text_embeddings, uncond_embeddings, cond_embeddings 50 | 51 | 52 | def process_input_embeddings(input_embeddings): 53 | assert isinstance(input_embeddings, (tuple, list)) 54 | if len(input_embeddings) == 3: 55 | # input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings 56 | # Assume `uncond_embeddings` is full (has batch size the same as cond_embeddings) 57 | _, uncond_embeddings, cond_embeddings = input_embeddings 58 | assert ( 59 | uncond_embeddings.shape[0] == cond_embeddings.shape[0] 60 | ), f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}" 61 | return input_embeddings 62 | elif len(input_embeddings) == 2: 63 | # input_embeddings: uncond_embeddings, cond_embeddings 64 | # uncond_embeddings may have only one item 65 | uncond_embeddings, cond_embeddings = input_embeddings 66 | if uncond_embeddings.shape[0] == 1: 67 | uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape) 68 | # We follow the convention: negative (unconditional) prompt comes first 69 | text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0) 70 | return text_embeddings, uncond_embeddings, cond_embeddings 71 | else: 72 | raise ValueError(f"input_embeddings length: {len(input_embeddings)}") 73 | 74 | 75 | def attn_list_to_tensor(cross_attention_probs): 76 | # timestep, CrossAttnBlock, Transformer2DModel, 1xBasicTransformerBlock 77 | 78 | num_cross_attn_block = len(cross_attention_probs[0]) 79 | cross_attention_probs_all = [] 80 | 81 | for i in range(num_cross_attn_block): 82 | # cross_attention_probs_timestep[i]: Transformer2DModel 83 | # 1xBasicTransformerBlock is skipped 84 | cross_attention_probs_current = [] 85 | for cross_attention_probs_timestep in cross_attention_probs: 86 | cross_attention_probs_current.append( 87 | torch.stack([item for item in cross_attention_probs_timestep[i]], dim=0) 88 | ) 89 | 90 | cross_attention_probs_current = torch.stack( 91 | cross_attention_probs_current, dim=0 92 | ) 93 | cross_attention_probs_all.append(cross_attention_probs_current) 94 | 95 | return cross_attention_probs_all 96 | -------------------------------------------------------------------------------- /models/pipelines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | from utils import guidance, schedule 5 | import utils 6 | from PIL import Image 7 | import gc 8 | import numpy as np 9 | from .attention import GatedSelfAttentionDense 10 | from .models import process_input_embeddings, torch_device 11 | import warnings 12 | 13 | DEFAULT_GUIDANCE_ATTN_KEYS = [ 14 | ("down", 2, 0, 0), 15 | ("down", 2, 1, 0), 16 | ("up", 1, 0, 0), 17 | ("up", 1, 1, 0), 18 | ] 19 | 20 | 21 | def latent_backward_guidance( 22 | scheduler, 23 | unet, 24 | cond_embeddings, 25 | index, 26 | bboxes, 27 | object_positions, 28 | t, 29 | latents, 30 | loss, 31 | loss_scale=30, 32 | loss_threshold=0.2, 33 | max_iter=5, 34 | max_index_step=10, 35 | cross_attention_kwargs=None, 36 | guidance_attn_keys=None, 37 | verbose=False, 38 | return_saved_attn=False, 39 | clear_cache=False, 40 | **kwargs, 41 | ): 42 | """ 43 | return_saved_attn: return the saved attention for visualizations 44 | """ 45 | 46 | iteration = 0 47 | 48 | saved_attn_to_return = None 49 | 50 | if index < max_index_step: 51 | if isinstance(max_iter, list): 52 | max_iter = max_iter[index] 53 | 54 | if verbose: 55 | print( 56 | f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}" 57 | ) 58 | 59 | with torch.set_grad_enabled(True): 60 | while ( 61 | loss.item() / loss_scale > loss_threshold 62 | and iteration < max_iter 63 | and index < max_index_step 64 | ): 65 | saved_attn = {} 66 | full_cross_attention_kwargs = { 67 | "save_attn_to_dict": saved_attn, 68 | "save_keys": guidance_attn_keys, 69 | } 70 | 71 | if cross_attention_kwargs is not None: 72 | full_cross_attention_kwargs.update(cross_attention_kwargs) 73 | 74 | latents.requires_grad_(True) 75 | latent_model_input = latents 76 | latent_model_input = scheduler.scale_model_input(latent_model_input, t) 77 | 78 | unet( 79 | latent_model_input, 80 | t, 81 | encoder_hidden_states=cond_embeddings, 82 | cross_attention_kwargs=full_cross_attention_kwargs, 83 | ) 84 | 85 | if return_saved_attn == "first": 86 | if iteration == 0: 87 | saved_attn_to_return = { 88 | k: v.detach().cpu() for k, v in saved_attn.items() 89 | } 90 | elif return_saved_attn == "last": 91 | if iteration == max_iter - 1: 92 | # It will not save if the current call returns before the last iteration 93 | saved_attn_to_return = { 94 | k: v.detach().cpu() for k, v in saved_attn.items() 95 | } 96 | elif return_saved_attn: 97 | raise ValueError(return_saved_attn) 98 | 99 | # TODO: could return the attention maps for the required blocks only and not necessarily the final output 100 | # update latents with guidance 101 | loss = ( 102 | guidance.compute_ca_lossv3( 103 | saved_attn=saved_attn, 104 | bboxes=bboxes, 105 | object_positions=object_positions, 106 | guidance_attn_keys=guidance_attn_keys, 107 | index=index, 108 | verbose=verbose, 109 | **kwargs, 110 | ) 111 | * loss_scale 112 | ) 113 | 114 | if torch.isnan(loss): 115 | print("**Loss is NaN**") 116 | 117 | del full_cross_attention_kwargs, saved_attn 118 | # call gc.collect() here may release some memory 119 | 120 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0] 121 | 122 | latents.requires_grad_(False) 123 | 124 | if hasattr(scheduler, "alphas_cumprod"): 125 | warnings.warn("Using guidance scaled with alphas_cumprod") 126 | # Scaling with classifier guidance 127 | alpha_prod_t = scheduler.alphas_cumprod[t] 128 | # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf 129 | # DDIM: https://arxiv.org/pdf/2010.02502.pdf 130 | scale = (1 - alpha_prod_t) ** (0.5) 131 | 132 | latents = latents - scale * grad_cond 133 | else: 134 | # NOTE: no scaling is performed 135 | warnings.warn("No scaling in guidance is performed") 136 | latents = latents - grad_cond 137 | iteration += 1 138 | 139 | if clear_cache: 140 | gc.collect() 141 | torch.cuda.empty_cache() 142 | 143 | if verbose: 144 | print( 145 | f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}" 146 | ) 147 | 148 | if return_saved_attn: 149 | return latents, loss, saved_attn_to_return 150 | return latents, loss 151 | 152 | 153 | @torch.no_grad() 154 | def encode(model_dict, image, generator): 155 | """ 156 | image should be a PIL object or numpy array with range 0 to 255 157 | """ 158 | 159 | vae, dtype = model_dict.vae, model_dict.dtype 160 | 161 | if isinstance(image, Image.Image): 162 | w, h = image.size 163 | assert ( 164 | w % 8 == 0 and h % 8 == 0 165 | ), f"h ({h}) and w ({w}) should be a multiple of 8" 166 | # w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 167 | # image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :] 168 | image = np.array(image) 169 | 170 | if isinstance(image, np.ndarray): 171 | assert ( 172 | image.dtype == np.uint8 173 | ), f"Should have dtype uint8 (dtype: {image.dtype})" 174 | image = image.astype(np.float32) / 255.0 175 | image = image[None, ...] 176 | image = image.transpose(0, 3, 1, 2) 177 | image = 2.0 * image - 1.0 178 | image = torch.from_numpy(image) 179 | 180 | assert isinstance(image, torch.Tensor), f"type of image: {type(image)}" 181 | 182 | image = image.to(device=torch_device, dtype=dtype) 183 | latents = vae.encode(image).latent_dist.sample(generator) 184 | 185 | latents = vae.config.scaling_factor * latents 186 | 187 | return latents 188 | 189 | 190 | @torch.no_grad() 191 | def decode(vae, latents): 192 | # scale and decode the image latents with vae 193 | scaled_latents = 1 / 0.18215 * latents 194 | with torch.no_grad(): 195 | image = vae.decode(scaled_latents).sample 196 | 197 | image = (image / 2 + 0.5).clamp(0, 1) 198 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 199 | images = (image * 255).round().astype("uint8") 200 | 201 | return images 202 | 203 | 204 | def generate_semantic_guidance( 205 | model_dict, 206 | latents, 207 | input_embeddings, 208 | num_inference_steps, 209 | bboxes, 210 | phrases, 211 | object_positions, 212 | guidance_scale=7.5, 213 | semantic_guidance_kwargs=None, 214 | return_cross_attn=False, 215 | return_saved_cross_attn=False, 216 | saved_cross_attn_keys=None, 217 | return_cond_ca_only=False, 218 | return_token_ca_only=None, 219 | offload_guidance_cross_attn_to_cpu=False, 220 | offload_cross_attn_to_cpu=False, 221 | offload_latents_to_cpu=True, 222 | return_box_vis=False, 223 | show_progress=True, 224 | save_all_latents=False, 225 | dynamic_num_inference_steps=False, 226 | fast_after_steps=None, 227 | fast_rate=2, 228 | additional_guidance_cross_attention_kwargs={}, 229 | custom_latent_backward_guidance=None, 230 | ): 231 | """ 232 | object_positions: object indices in text tokens 233 | return_cross_attn: should be deprecated. Use `return_saved_cross_attn` and the new format. 234 | """ 235 | vae, tokenizer, text_encoder, unet, scheduler, dtype = ( 236 | model_dict.vae, 237 | model_dict.tokenizer, 238 | model_dict.text_encoder, 239 | model_dict.unet, 240 | model_dict.scheduler, 241 | model_dict.dtype, 242 | ) 243 | text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings 244 | 245 | # Just in case that we have in-place ops 246 | latents = latents.clone() 247 | 248 | if save_all_latents: 249 | # offload to cpu to save space 250 | if offload_latents_to_cpu: 251 | latents_all = [latents.cpu()] 252 | else: 253 | latents_all = [latents] 254 | 255 | scheduler.set_timesteps(num_inference_steps) 256 | if fast_after_steps is not None: 257 | scheduler.timesteps = schedule.get_fast_schedule( 258 | scheduler.timesteps, fast_after_steps, fast_rate 259 | ) 260 | 261 | if dynamic_num_inference_steps: 262 | original_num_inference_steps = scheduler.num_inference_steps 263 | 264 | cross_attention_probs_down = [] 265 | cross_attention_probs_mid = [] 266 | cross_attention_probs_up = [] 267 | 268 | loss = torch.tensor(10000.0) 269 | 270 | # TODO: we can also save necessary tokens only to save memory. 271 | # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep. 272 | guidance_cross_attention_kwargs = { 273 | "offload_cross_attn_to_cpu": offload_guidance_cross_attn_to_cpu, 274 | "enable_flash_attn": False, 275 | **additional_guidance_cross_attention_kwargs, 276 | } 277 | 278 | if return_saved_cross_attn: 279 | saved_attns = [] 280 | 281 | main_cross_attention_kwargs = { 282 | "offload_cross_attn_to_cpu": offload_cross_attn_to_cpu, 283 | "return_cond_ca_only": return_cond_ca_only, 284 | "return_token_ca_only": return_token_ca_only, 285 | "save_keys": saved_cross_attn_keys, 286 | } 287 | 288 | # Repeating keys leads to different weights for each key. 289 | # assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}" 290 | 291 | for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)): 292 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 293 | 294 | if bboxes: 295 | if custom_latent_backward_guidance: 296 | latents, loss = custom_latent_backward_guidance( 297 | scheduler, 298 | unet, 299 | cond_embeddings, 300 | index, 301 | bboxes, 302 | object_positions, 303 | t, 304 | latents, 305 | loss, 306 | cross_attention_kwargs=guidance_cross_attention_kwargs, 307 | **semantic_guidance_kwargs, 308 | ) 309 | else: 310 | # If encountered None in `guidance_attn_keys`, please be sure to check whether `guidance_attn_keys` is added in `semantic_guidance_kwargs`. Default value has been removed. 311 | latents, loss = latent_backward_guidance( 312 | scheduler, 313 | unet, 314 | cond_embeddings, 315 | index, 316 | bboxes, 317 | object_positions, 318 | t, 319 | latents, 320 | loss, 321 | cross_attention_kwargs=guidance_cross_attention_kwargs, 322 | **semantic_guidance_kwargs, 323 | ) 324 | 325 | # predict the noise residual 326 | with torch.no_grad(): 327 | latent_model_input = torch.cat([latents] * 2) 328 | latent_model_input = scheduler.scale_model_input( 329 | latent_model_input, timestep=t 330 | ) 331 | 332 | main_cross_attention_kwargs["save_attn_to_dict"] = {} 333 | 334 | unet_output = unet( 335 | latent_model_input, 336 | t, 337 | encoder_hidden_states=text_embeddings, 338 | return_cross_attention_probs=return_cross_attn, 339 | cross_attention_kwargs=main_cross_attention_kwargs, 340 | ) 341 | noise_pred = unet_output.sample 342 | 343 | if return_cross_attn: 344 | cross_attention_probs_down.append( 345 | unet_output.cross_attention_probs_down 346 | ) 347 | cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid) 348 | cross_attention_probs_up.append(unet_output.cross_attention_probs_up) 349 | 350 | if return_saved_cross_attn: 351 | saved_attns.append(main_cross_attention_kwargs["save_attn_to_dict"]) 352 | 353 | del main_cross_attention_kwargs["save_attn_to_dict"] 354 | 355 | # perform guidance 356 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 357 | noise_pred = noise_pred_uncond + guidance_scale * ( 358 | noise_pred_text - noise_pred_uncond 359 | ) 360 | 361 | if dynamic_num_inference_steps: 362 | schedule.dynamically_adjust_inference_steps(scheduler, index, t) 363 | 364 | # compute the previous noisy sample x_t -> x_t-1 365 | latents = scheduler.step(noise_pred, t, latents).prev_sample 366 | 367 | if save_all_latents: 368 | if offload_latents_to_cpu: 369 | latents_all.append(latents.cpu()) 370 | else: 371 | latents_all.append(latents) 372 | 373 | if dynamic_num_inference_steps: 374 | # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic 375 | scheduler.num_inference_steps = original_num_inference_steps 376 | 377 | images = decode(vae, latents) 378 | 379 | ret = [latents, images] 380 | 381 | if return_cross_attn: 382 | ret.append( 383 | ( 384 | cross_attention_probs_down, 385 | cross_attention_probs_mid, 386 | cross_attention_probs_up, 387 | ) 388 | ) 389 | if return_saved_cross_attn: 390 | ret.append(saved_attns) 391 | if return_box_vis: 392 | pil_images = [ 393 | utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images 394 | ] 395 | ret.append(pil_images) 396 | if save_all_latents: 397 | latents_all = torch.stack(latents_all, dim=0) 398 | ret.append(latents_all) 399 | return tuple(ret) 400 | -------------------------------------------------------------------------------- /models/transformer_temporal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Optional 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from diffusers.configuration_utils import ConfigMixin, register_to_config 21 | from diffusers.utils import BaseOutput 22 | from .attention import BasicTransformerBlock 23 | from diffusers.models.modeling_utils import ModelMixin 24 | 25 | 26 | @dataclass 27 | class TransformerTemporalModelOutput(BaseOutput): 28 | """ 29 | The output of [`TransformerTemporalModel`]. 30 | 31 | Args: 32 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): 33 | The hidden states output conditioned on `encoder_hidden_states` input. 34 | """ 35 | 36 | sample: torch.FloatTensor 37 | 38 | 39 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 40 | """ 41 | A Transformer model for video-like data. 42 | 43 | Parameters: 44 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 45 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 46 | in_channels (`int`, *optional*): 47 | The number of channels in the input and output (specify if the input is **continuous**). 48 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 49 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 50 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 51 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 52 | This is fixed during training since it is used to learn a number of position embeddings. 53 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 54 | attention_bias (`bool`, *optional*): 55 | Configure if the `TransformerBlock` attention should contain a bias parameter. 56 | double_self_attention (`bool`, *optional*): 57 | Configure if each `TransformerBlock` should contain two self-attention layers. 58 | """ 59 | 60 | @register_to_config 61 | def __init__( 62 | self, 63 | num_attention_heads: int = 16, 64 | attention_head_dim: int = 88, 65 | in_channels: Optional[int] = None, 66 | out_channels: Optional[int] = None, 67 | num_layers: int = 1, 68 | dropout: float = 0.0, 69 | norm_num_groups: int = 32, 70 | cross_attention_dim: Optional[int] = None, 71 | attention_bias: bool = False, 72 | sample_size: Optional[int] = None, 73 | activation_fn: str = "geglu", 74 | norm_elementwise_affine: bool = True, 75 | double_self_attention: bool = True, 76 | ): 77 | super().__init__() 78 | self.num_attention_heads = num_attention_heads 79 | self.attention_head_dim = attention_head_dim 80 | inner_dim = num_attention_heads * attention_head_dim 81 | 82 | self.in_channels = in_channels 83 | 84 | self.norm = torch.nn.GroupNorm( 85 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 86 | ) 87 | self.proj_in = nn.Linear(in_channels, inner_dim) 88 | 89 | # 3. Define transformers blocks 90 | self.transformer_blocks = nn.ModuleList( 91 | [ 92 | BasicTransformerBlock( 93 | inner_dim, 94 | num_attention_heads, 95 | attention_head_dim, 96 | dropout=dropout, 97 | cross_attention_dim=cross_attention_dim, 98 | activation_fn=activation_fn, 99 | attention_bias=attention_bias, 100 | double_self_attention=double_self_attention, 101 | norm_elementwise_affine=norm_elementwise_affine, 102 | ) 103 | for d in range(num_layers) 104 | ] 105 | ) 106 | 107 | self.proj_out = nn.Linear(inner_dim, in_channels) 108 | 109 | def forward( 110 | self, 111 | hidden_states, 112 | encoder_hidden_states=None, 113 | timestep=None, 114 | class_labels=None, 115 | num_frames=1, 116 | cross_attention_kwargs=None, 117 | return_dict: bool = True, 118 | ): 119 | """ 120 | The [`TransformerTemporal`] forward method. 121 | 122 | Args: 123 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 124 | Input hidden_states. 125 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 126 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 127 | self-attention. 128 | timestep ( `torch.long`, *optional*): 129 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 130 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 131 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 132 | `AdaLayerZeroNorm`. 133 | return_dict (`bool`, *optional*, defaults to `True`): 134 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 135 | tuple. 136 | 137 | Returns: 138 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: 139 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is 140 | returned, otherwise a `tuple` where the first element is the sample tensor. 141 | """ 142 | # 1. Input 143 | batch_frames, channel, height, width = hidden_states.shape 144 | batch_size = batch_frames // num_frames 145 | 146 | residual = hidden_states 147 | 148 | hidden_states = hidden_states[None, :].reshape( 149 | batch_size, num_frames, channel, height, width 150 | ) 151 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4) 152 | 153 | hidden_states = self.norm(hidden_states) 154 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( 155 | batch_size * height * width, num_frames, channel 156 | ) 157 | 158 | hidden_states = self.proj_in(hidden_states) 159 | 160 | base_attn_key = cross_attention_kwargs["attn_key"] 161 | 162 | # 2. Blocks 163 | for block_ind, block in enumerate(self.transformer_blocks): 164 | cross_attention_kwargs["attn_key"] = base_attn_key + [block_ind] 165 | 166 | hidden_states = block( 167 | hidden_states, 168 | encoder_hidden_states=encoder_hidden_states, 169 | timestep=timestep, 170 | cross_attention_kwargs=cross_attention_kwargs, 171 | class_labels=class_labels, 172 | ) 173 | 174 | # 3. Output 175 | hidden_states = self.proj_out(hidden_states) 176 | hidden_states = ( 177 | hidden_states[None, None, :] 178 | .reshape(batch_size, height, width, channel, num_frames) 179 | .permute(0, 3, 4, 1, 2) 180 | .contiguous() 181 | ) 182 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width) 183 | 184 | output = hidden_states + residual 185 | 186 | if not return_dict: 187 | return (output,) 188 | 189 | return TransformerTemporalModelOutput(sample=output) 190 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | # make the response more structured 2 | templatev0_1_chat = [ 3 | { 4 | "role": "system", 5 | "content": """You are an intelligent bounding box generator for videos. You don't need to generate the videos themselves but need to generate the bounding boxes. I will provide you with a caption for a video with six frames, with two frames per second. Your task is to generate a list of realistic bounding boxes for the objects mentioned in the caption for each frame as well as a background keyword. The video frames are of size 512x512. The top-left corner has coordinates [0, 0]. The bottom-right corner has coordinnates [512, 512]. The bounding boxes should not overlap or go beyond the frame boundaries. 6 | 7 | Each frame should be represented as `[{'id': unique object identifier incrementing from 0, 'name': object name, 'box': [box top-left x-coordinate, box top-left y-coordinate, box width, box height]}, ...]`. Each box should not include more than one object. Your generated frames must encapsulate the whole scenario depicted by the caption. Assume objects move and interact based on real-world physics, considering aspects such as gravity and elasticity. Assume the camera follows perspective geometry. Boxes for an object should have the same id across the frames, even if the object may disappear and reappear. If needed, you can make reasonable guesses. Provide a concise reasoning statement that is not longer than a few sentences before each generation. Refer to the examples below for the desired format. Never use markdown or other formats not in the examples. Do not start each frame with `-`. Do not include any comments in your response.""", 8 | }, 9 | { 10 | "role": "user", 11 | "content": "Caption: A woman walking from the left to the right and a man jumping on the right in a room", 12 | }, 13 | { 14 | "role": "assistant", 15 | "content": """Reasoning: A woman is walking from the left to the right so her x-coordinate should increase with her y-coordinate fixed. A man is jumping on the right so his x-coordinate should be large, and his y-coordinate should first decrease (upward movement) and then increase (downward movement due to gravity). 16 | Frame 1: [{'id': 0, 'name': 'walking woman', 'box': [0, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 290, 120, 180]}] 17 | Frame 2: [{'id': 0, 'name': 'walking woman', 'box': [50, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 205, 120, 200]}] 18 | Frame 3: [{'id': 0, 'name': 'walking woman', 'box': [100, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 175, 120, 200]}] 19 | Frame 4: [{'id': 0, 'name': 'walking woman', 'box': [150, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 175, 120, 200]}] 20 | Frame 5: [{'id': 0, 'name': 'walking woman', 'box': [200, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 205, 120, 200]}] 21 | Frame 6: [{'id': 0, 'name': 'walking woman', 'box': [250, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 290, 120, 180]}] 22 | Background keyword: room""", 23 | }, 24 | { 25 | "role": "user", 26 | "content": "Caption: A red ball is thrown from the left to the right in a garden", 27 | }, 28 | { 29 | "role": "assistant", 30 | "content": """Reasoning: A ball is thrown from the left to the right, so its x-coordinate should increase. Due to gravity, its y-coordinate should increase, and the speed should be faster in later frames until it hits the ground. Due to its elasticity, the ball bounces back when it hits the ground. 31 | Frame 1: [{'id': 0, 'name': 'red ball', 'box': [0, 206, 50, 50]}] 32 | Frame 2: [{'id': 0, 'name': 'red ball', 'box': [80, 246, 50, 50]}] 33 | Frame 3: [{'id': 0, 'name': 'red ball', 'box': [160, 326, 50, 50]}] 34 | Frame 4: [{'id': 0, 'name': 'red ball', 'box': [240, 446, 50, 50]}] 35 | Frame 5: [{'id': 0, 'name': 'red ball', 'box': [320, 366, 50, 50]}] 36 | Frame 6: [{'id': 0, 'name': 'red ball', 'box': [400, 446, 50, 50]}] 37 | Background keyword: garden""", 38 | }, 39 | {"role": "user", "content": "Caption: The camera is moving away from a painting"}, 40 | { 41 | "role": "assistant", 42 | "content": """Reasoning: Due to perspective geometry, the painting will be smaller in later timesteps as the distance between the camera and the object is larger. 43 | Frame 1: [{'id': 0, 'name': 'painting', 'box': [156, 181, 200, 150]}] 44 | Frame 2: [{'id': 0, 'name': 'painting', 'box': [166, 189, 180, 135]}] 45 | Frame 3: [{'id': 0, 'name': 'painting', 'box': [176, 196, 160, 120]}] 46 | Frame 4: [{'id': 0, 'name': 'painting', 'box': [186, 204, 140, 105]}] 47 | Frame 5: [{'id': 0, 'name': 'painting', 'box': [196, 211, 120, 90]}] 48 | Frame 6: [{'id': 0, 'name': 'painting', 'box': [206, 219, 100, 75]}] 49 | Background keyword: room""", 50 | }, 51 | ] 52 | 53 | templates = { 54 | "v0.1": templatev0_1_chat, 55 | } 56 | 57 | template_versions = list(templates.keys()) 58 | 59 | 60 | def get_num_parsed_layout_frames(template_version): 61 | return 6 62 | 63 | 64 | # 6 frames 65 | required_lines = [f"Frame {i+1}:" for i in range(6)] + ["Background keyword:"] 66 | required_lines_ast = [True] * 6 + [False] 67 | 68 | strip_before = required_lines[0] 69 | 70 | stop = "\n\n" 71 | 72 | prompts_demo = [ 73 | "A bear walks from the left to the right", 74 | ] 75 | 76 | prompt_types = ["demo", "lvd"] 77 | 78 | negative_prompt = ( 79 | "dull, gray, unrealistic, colorless, blurry, low-quality, weird, abrupt" 80 | ) 81 | 82 | 83 | def get_prompts(prompt_type, return_predicates=False): 84 | if prompt_type.startswith("lvd"): 85 | from utils.eval.lvd import get_lvd_full_prompts, get_lvd_full_prompt_predicates 86 | 87 | if return_predicates: 88 | prompts = get_lvd_full_prompt_predicates(prompt_type) 89 | else: 90 | prompts = get_lvd_full_prompts(prompt_type) 91 | elif prompt_type == "demo": 92 | assert ( 93 | not return_predicates 94 | ), "Predicates are not supported for this prompt type" 95 | prompts = prompts_demo 96 | else: 97 | raise ValueError(f"Unknown prompt type: {prompt_type}") 98 | 99 | return prompts 100 | 101 | 102 | if __name__ == "__main__": 103 | if True: 104 | prompt_type = "demo" 105 | 106 | assert prompt_type in prompt_types, f"prompt_type {prompt_type} does not exist" 107 | 108 | prompts = get_prompts(prompt_type) 109 | prompt = prompts[-1] 110 | else: 111 | prompt = input("Prompt: ") 112 | 113 | template_key = "v0.1" 114 | template = templates[template_key] 115 | 116 | if isinstance(template, list): 117 | template = ( 118 | "\n\n".join([item["content"] for item in template]) 119 | + "\n\nCaption: {prompt}\nReasoning:" 120 | ) 121 | 122 | prompt_full = template.replace("{prompt}", prompt.strip().rstrip(".")).strip() 123 | print(prompt_full) 124 | 125 | if False: 126 | import json 127 | 128 | print(json.dumps(prompt_full.strip("\n"))) 129 | -------------------------------------------------------------------------------- /prompt_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from prompt import get_prompts, template_versions 3 | from utils import parse, utils 4 | from utils.parse import show_video_boxes, size 5 | from utils.llm import get_llm_kwargs, get_full_prompt, model_names, get_parsed_layout 6 | from utils import cache 7 | import argparse 8 | import time 9 | 10 | # This only applies to visualization in this file. 11 | scale_boxes = False 12 | 13 | if scale_boxes: 14 | print("Scaling the bounding box to fit the scene") 15 | else: 16 | print("Not scaling the bounding box to fit the scene") 17 | 18 | H, W = size 19 | 20 | 21 | def visualize_layout(parsed_layout): 22 | condition = parse.parsed_layout_to_condition( 23 | parsed_layout, tokenizer=None, height=H, width=W, verbose=True 24 | ) 25 | 26 | show_video_boxes(condition, ind=ind, save=True) 27 | 28 | print(f"Visualize masks at {parse.img_dir}") 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--prompt-type", type=str, default="demo") 34 | parser.add_argument("--model", choices=model_names, required=True) 35 | parser.add_argument("--template_version", choices=template_versions, required=True) 36 | parser.add_argument( 37 | "--auto-query", action="store_true", help="Auto query using the API" 38 | ) 39 | parser.add_argument( 40 | "--always-save", 41 | action="store_true", 42 | help="Always save the layout without confirming", 43 | ) 44 | parser.add_argument("--no-visualize", action="store_true", help="No visualizations") 45 | parser.add_argument( 46 | "--visualize-cache-hit", action="store_true", help="Save boxes for cache hit" 47 | ) 48 | parser.add_argument( 49 | "--unnormalize-boxes-before-save", 50 | action="store_true", 51 | help="Unnormalize the boxes before saving. This should be enabled if the prompt asks the LLM to return normalized boxes.", 52 | ) 53 | args = parser.parse_args() 54 | 55 | visualize_cache_hit = args.visualize_cache_hit 56 | 57 | template_version = args.template_version 58 | 59 | model, llm_kwargs = get_llm_kwargs( 60 | model=args.model, template_version=template_version 61 | ) 62 | template = llm_kwargs.template 63 | # Need to parse json format for json templates 64 | json_template = "json" in template_version 65 | 66 | # This is for visualizing bounding boxes 67 | parse.img_dir = ( 68 | f"img_generations/imgs_{args.prompt_type}_template{template_version}" 69 | ) 70 | if not args.no_visualize: 71 | os.makedirs(parse.img_dir, exist_ok=True) 72 | 73 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}{"_" + template_version if args.template_version != "v5" else ""}_{model}.json' 74 | 75 | os.makedirs(os.path.dirname(cache.cache_path), exist_ok=True) 76 | cache.cache_format = "json" 77 | 78 | cache.init_cache() 79 | 80 | prompts_query = get_prompts(args.prompt_type) 81 | 82 | max_attempts = 1 83 | 84 | for ind, prompt in enumerate(prompts_query): 85 | if isinstance(prompt, list): 86 | # prompt, seed 87 | prompt = prompt[0] 88 | prompt = prompt.strip().rstrip(".") 89 | 90 | resp = cache.get_cache(prompt) 91 | if resp is None: 92 | print(f"Cache miss: {prompt}") 93 | 94 | if not args.auto_query: 95 | print("#########") 96 | prompt_full = get_full_prompt(template=template, prompt=prompt) 97 | print(prompt_full) 98 | print("#########") 99 | resp = None 100 | 101 | attempts = 0 102 | while True: 103 | attempts += 1 104 | try: 105 | # The resp from `get_parsed_layout` has already been structured 106 | if args.auto_query: 107 | parsed_layout, resp = get_parsed_layout( 108 | prompt, 109 | llm_kwargs=llm_kwargs, 110 | json_template=json_template, 111 | verbose=False, 112 | ) 113 | print("Response:", resp) 114 | else: 115 | resp = utils.multiline_input( 116 | prompt="Please enter LLM response (use an empty line to end): " 117 | ) 118 | parsed_layout, resp = get_parsed_layout( 119 | prompt, 120 | llm_kwargs=llm_kwargs, 121 | override_response=resp, 122 | max_partial_response_retries=1, 123 | json_template=json_template, 124 | verbose=False, 125 | ) 126 | 127 | except (ValueError, SyntaxError, TypeError) as e: 128 | if attempts > max_attempts: 129 | print("Retrying too many times, skipping") 130 | break 131 | print( 132 | f"Encountered invalid data with prompt {prompt} and response {resp}: {e}, retrying" 133 | ) 134 | time.sleep(1) 135 | continue 136 | 137 | if not args.no_visualize: 138 | visualize_layout(parsed_layout) 139 | if not args.always_save: 140 | save = input("Save (y/n)? ").strip() 141 | else: 142 | save = "y" 143 | if save == "y" or save == "Y": 144 | cache.add_cache(prompt, resp) 145 | else: 146 | print("Not saved. Will generate the same prompt again.") 147 | continue 148 | break 149 | else: 150 | print(f"Cache hit: {prompt}") 151 | 152 | parsed_layout, resp = get_parsed_layout( 153 | prompt, 154 | llm_kwargs=llm_kwargs, 155 | override_response=resp, 156 | max_partial_response_retries=1, 157 | json_template=json_template, 158 | verbose=False, 159 | ) 160 | 161 | if visualize_cache_hit: 162 | visualize_layout(parsed_layout) 163 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | numpy 3 | scipy 4 | torch==2.0.0 5 | diffusers==0.27.2 6 | transformers==4.36.2 7 | opencv-python==4.7.0.72 8 | opencv-contrib-python==4.7.0.72 9 | inflect==6.0.4 10 | easydict 11 | accelerate==0.21.0 12 | gradio==3.35.2 13 | pydantic==1.10.7 14 | scikit-video==1.1.11 15 | imageio==2.34.1 16 | pyjson5==1.6.6 17 | joblib==1.4.2 18 | -------------------------------------------------------------------------------- /scripts/eval_owl_vit.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) 5 | 6 | import argparse 7 | import torch 8 | from utils.llm import get_full_model_name, model_names 9 | from transformers import OwlViTProcessor, OwlViTForObjectDetection 10 | from glob import glob 11 | from utils.eval import to_gen_box_format, evaluate_with_layout, class_aware_nms, nms 12 | from tqdm import tqdm 13 | import numpy as np 14 | import json 15 | import joblib 16 | from prompt import get_prompts, prompt_types 17 | 18 | torch.set_grad_enabled(False) 19 | 20 | 21 | def keep_one_box_per_class(boxes, scores, labels): 22 | # Keep the box with highest label per class 23 | 24 | boxes_output, scores_output, labels_output = [], [], [] 25 | labels_unique = np.unique(labels) 26 | for label in labels_unique: 27 | label_mask = labels == label 28 | boxes_label = boxes[label_mask] 29 | scores_label = scores[label_mask] 30 | max_score_index = scores_label.argmax() 31 | box, score = boxes_label[max_score_index], scores_label[max_score_index] 32 | boxes_output.append(box) 33 | scores_output.append(score) 34 | labels_output.append(label) 35 | 36 | return np.array(boxes_output), np.array(scores_output), np.array(labels_output) 37 | 38 | 39 | def eval_prompt( 40 | p, 41 | predicate, 42 | path, 43 | processor, 44 | model, 45 | score_threshold=0.1, 46 | nms_threshold=0.5, 47 | use_class_aware_nms=False, 48 | num_eval_frames=6, 49 | use_cuda=True, 50 | verbose=False, 51 | ): 52 | video = joblib.load(path) 53 | texts = [predicate.texts] 54 | 55 | parsed_layout = {"Prompt": p, "Background keyword": None} 56 | 57 | eval_frame_indices = ( 58 | np.round(np.linspace(0, len(video) - 1, num_eval_frames)).astype(int).tolist() 59 | ) 60 | 61 | assert len(set(eval_frame_indices)) == len( 62 | eval_frame_indices 63 | ), f"Eval indices not unique: {eval_frame_indices}" 64 | 65 | print(f"Eval indices: {eval_frame_indices}") 66 | 67 | frame_ind = 1 68 | for eval_frame_index in eval_frame_indices: 69 | image = video[eval_frame_index] 70 | inputs = processor(text=texts, images=image, return_tensors="pt") 71 | if use_cuda: 72 | inputs = inputs.to("cuda") 73 | outputs = model(**inputs) 74 | 75 | height, width, _ = image.shape 76 | 77 | # Target image sizes (height, width) to rescale box predictions [batch_size, 2] 78 | target_sizes = torch.Tensor([[height, width]]) 79 | if use_cuda: 80 | target_sizes = target_sizes.cuda() 81 | # Convert outputs (bounding boxes and class logits) to COCO API 82 | results = processor.post_process(outputs=outputs, target_sizes=target_sizes) 83 | 84 | i = 0 # Retrieve predictions for the first image for the corresponding text queries 85 | text = texts[i] 86 | boxes, scores, labels = ( 87 | results[i]["boxes"], 88 | results[i]["scores"], 89 | results[i]["labels"], 90 | ) 91 | boxes = boxes.cpu() 92 | # xyxy ranging from 0 to 1 93 | boxes = np.array( 94 | [ 95 | [x_min / width, y_min / height, x_max / width, y_max / height] 96 | for (x_min, y_min, x_max, y_max), score in zip(boxes, scores) 97 | if score >= score_threshold 98 | ] 99 | ) 100 | labels = np.array( 101 | [ 102 | label.cpu().numpy() 103 | for label, score in zip(labels, scores) 104 | if score >= score_threshold 105 | ] 106 | ) 107 | scores = np.array( 108 | [score.cpu().numpy() for score in scores if score >= score_threshold] 109 | ) 110 | 111 | # print(f"Pre-NMS:") 112 | # for box, score, label in zip(boxes, scores, labels): 113 | # box = [round(i, 2) for i in box.tolist()] 114 | # print( 115 | # f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}") 116 | 117 | print(f"Post-NMS (frame frame_ind):") 118 | 119 | if use_class_aware_nms: 120 | boxes, scores, labels = class_aware_nms( 121 | boxes, scores, labels, nms_threshold 122 | ) 123 | else: 124 | boxes, scores, labels = nms(boxes, scores, labels, nms_threshold) 125 | 126 | for box, score, label in zip(boxes, scores, labels): 127 | box = [round(i, 2) for i in box.tolist()] 128 | print( 129 | f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}" 130 | ) 131 | 132 | if verbose: 133 | print( 134 | f"prompt: {p}, texts: {texts}, boxes: {boxes}, labels: {labels}, scores: {scores}" 135 | ) 136 | 137 | # Here we are not using a tracker so the box id could mismatch when we have multiple objects with the same label. 138 | # For numeracy, we do not need tracking (mismatch is ok). For other tasks, we only include the box with max confidence. 139 | if predicate.one_box_per_class: 140 | boxes, scores, labels = keep_one_box_per_class(boxes, scores, labels) 141 | 142 | for box, score, label in zip(boxes, scores, labels): 143 | box = [round(i, 2) for i in box.tolist()] 144 | print( 145 | f"After selection one box per class: Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}" 146 | ) 147 | 148 | det_boxes = [] 149 | label_counts = {} 150 | 151 | # This ensures boxes of different labels will not be matched to each other. 152 | for box, score, label in zip(boxes, scores, labels): 153 | if label not in label_counts: 154 | label_counts[label] = 0 155 | # Here we convert to gen box format (same as LLM output), xywh (in pixels). This is for compatibility with first stage evaluation. This will be converted to condition format (xyxy, ranging from 0 to 1). 156 | det_boxes.append( 157 | { 158 | "id": label * 100 + label_counts[label], 159 | "name": text[label], 160 | "box": to_gen_box_format(box, width, height, rounding=True), 161 | "score": score, 162 | } 163 | ) 164 | label_counts[label] += 1 165 | 166 | parsed_layout[f"Frame {frame_ind}"] = det_boxes 167 | 168 | frame_ind += 1 169 | 170 | print(f"parsed_layout: {parsed_layout}") 171 | 172 | eval_type, eval_success = evaluate_with_layout( 173 | parsed_layout, 174 | predicate, 175 | num_parsed_layout_frames=num_eval_frames, 176 | height=height, 177 | width=width, 178 | verbose=verbose, 179 | ) 180 | 181 | return eval_type, eval_success 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--prompt-type", type=str, default="lvd") 187 | parser.add_argument("--run_base_path", type=str) 188 | parser.add_argument("--run_start_ind", default=0, type=int) 189 | parser.add_argument("--num_prompts", default=None, type=int) 190 | parser.add_argument("--num_eval_frames", default=6, type=int) 191 | parser.add_argument("--skip_first_prompts", default=0, type=int) 192 | parser.add_argument("--detection_score_threshold", default=0.05, type=float) 193 | parser.add_argument("--nms_threshold", default=0.5, type=float) 194 | parser.add_argument("--class-aware-nms", action="store_true") 195 | parser.add_argument("--save-eval", action="store_true") 196 | parser.add_argument("--verbose", action="store_true") 197 | parser.add_argument("--no-cuda", action="store_true") 198 | args = parser.parse_args() 199 | 200 | np.set_printoptions(precision=2) 201 | 202 | prompt_predicates = get_prompts(args.prompt_type, return_predicates=True) 203 | num_eval_frames = args.num_eval_frames 204 | 205 | print(f"Number of prompts (predicates): {len(prompt_predicates)}") 206 | print(f"Number of evaluating frames: {num_eval_frames}") 207 | 208 | processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") 209 | owl_vit_model = OwlViTForObjectDetection.from_pretrained( 210 | "google/owlvit-base-patch32" 211 | ) 212 | owl_vit_model.eval() 213 | 214 | use_cuda = not args.no_cuda 215 | 216 | if use_cuda: 217 | owl_vit_model.cuda() 218 | 219 | eval_success_counts = {} 220 | eval_all_counts = {} 221 | 222 | eval_successes = {} 223 | 224 | for ind, (prompt, predicate) in enumerate(tqdm(prompt_predicates)): 225 | if isinstance(prompt, list): 226 | # prompt and kwargs 227 | prompt = prompt[0] 228 | prompt = prompt.strip().rstrip(".") 229 | if ind < args.skip_first_prompts: 230 | continue 231 | if args.num_prompts is not None and ind >= ( 232 | args.skip_first_prompts + args.num_prompts 233 | ): 234 | continue 235 | 236 | search_path = f"{args.run_base_path}/{ind+args.run_start_ind}/video_*.joblib" 237 | 238 | # NOTE: sorted with string type 239 | path = sorted(glob(search_path)) 240 | if len(path) == 0: 241 | print(f"***No image matching {search_path}, skipping***") 242 | continue 243 | elif len(path) > 1: 244 | print(f"***More than one images match {search_path}: {path}, skipping***") 245 | continue 246 | path = path[0] 247 | print(f"Video path: {path} ({path.replace('.joblib', '.gif')})") 248 | 249 | eval_type, eval_success = eval_prompt( 250 | prompt, 251 | predicate, 252 | path, 253 | processor, 254 | owl_vit_model, 255 | score_threshold=args.detection_score_threshold, 256 | nms_threshold=args.nms_threshold, 257 | use_class_aware_nms=args.class_aware_nms, 258 | num_eval_frames=num_eval_frames, 259 | use_cuda=use_cuda, 260 | verbose=args.verbose, 261 | ) 262 | 263 | print(f"Eval success (eval_type):", eval_success) 264 | 265 | if eval_type not in eval_all_counts: 266 | eval_success_counts[eval_type] = 0 267 | eval_all_counts[eval_type] = 0 268 | eval_successes[eval_type] = [] 269 | 270 | eval_success_counts[eval_type] += int(eval_success) 271 | eval_all_counts[eval_type] += 1 272 | eval_successes[eval_type].append(bool(eval_success)) 273 | 274 | summary = [] 275 | eval_success_conut, eval_all_count = 0, 0 276 | for k, v in eval_all_counts.items(): 277 | rate = eval_success_counts[k] / eval_all_counts[k] 278 | print( 279 | f"Eval type: {k}, success: {eval_success_counts[k]}/{eval_all_counts[k]}, rate: {round(rate, 2):.2f}" 280 | ) 281 | eval_success_conut += eval_success_counts[k] 282 | eval_all_count += eval_all_counts[k] 283 | summary.append(rate) 284 | 285 | rate = eval_success_conut / eval_all_count 286 | print(f"Overall: success: {eval_success_conut}/{eval_all_count}, rate: {rate:.2f}") 287 | summary.append(rate) 288 | 289 | summary_str = "/".join([f"{round(rate, 2):.2f}" for rate in summary]) 290 | print(f"Summary: {summary_str}") 291 | 292 | if args.save_eval: 293 | save_eval = {} 294 | save_eval["success_counts"] = eval_success_counts 295 | save_eval["sample_counts"] = eval_all_counts 296 | save_eval["successes"] = eval_successes 297 | save_eval["success_counts_overall"] = eval_success_conut 298 | save_eval["sample_counts_overall"] = eval_all_count 299 | 300 | # Reference: https://stackoverflow.com/questions/58408054/typeerror-object-of-type-bool-is-not-json-serializable 301 | 302 | with open(f"{args.run_base_path}/eval.json", "w") as f: 303 | json.dump(save_eval, f, indent=4) 304 | -------------------------------------------------------------------------------- /scripts/eval_stage_one.py: -------------------------------------------------------------------------------- 1 | # This script allows evaluating stage one and saving the generated prompts to cache 2 | 3 | import sys 4 | import os 5 | 6 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 7 | 8 | import json 9 | import argparse 10 | from prompt import get_prompts, get_num_parsed_layout_frames, template_versions 11 | from utils.llm import get_llm_kwargs, get_parsed_layout_with_cache, model_names 12 | from utils.eval import evaluate_with_layout 13 | from utils import parse, cache 14 | import numpy as np 15 | from tqdm import tqdm 16 | 17 | eval_success_counts = {} 18 | eval_all_counts = {} 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--prompt-type", type=str, default="lvd") 23 | parser.add_argument("--model", choices=model_names, required=True) 24 | parser.add_argument("--template_version", choices=template_versions, required=True) 25 | parser.add_argument("--skip_first_prompts", default=0, type=int) 26 | parser.add_argument("--num_prompts", default=None, type=int) 27 | parser.add_argument("--show-cache-access", action="store_true") 28 | parser.add_argument("--verbose", action="store_true") 29 | args = parser.parse_args() 30 | 31 | np.set_printoptions(precision=2) 32 | 33 | template_version = args.template_version 34 | 35 | json_template = "json" in template_version 36 | 37 | model, llm_kwargs = get_llm_kwargs( 38 | model=args.model, template_version=template_version 39 | ) 40 | 41 | cache.cache_format = "json" 42 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}_{template_version}_{model}.json' 43 | cache.init_cache() 44 | 45 | prompt_predicates = get_prompts(args.prompt_type, return_predicates=True) 46 | print(f"Number of prompts (predicates): {len(prompt_predicates)}") 47 | 48 | height, width = parse.size_h, parse.size_w 49 | 50 | for ind, (prompt, predicate) in enumerate(tqdm(prompt_predicates)): 51 | if isinstance(prompt, list): 52 | # prompt and kwargs 53 | prompt = prompt[0] 54 | prompt = prompt.strip().rstrip(".") 55 | if ind < args.skip_first_prompts: 56 | continue 57 | if args.num_prompts is not None and ind >= ( 58 | args.skip_first_prompts + args.num_prompts 59 | ): 60 | continue 61 | 62 | parsed_layout = get_parsed_layout_with_cache( 63 | prompt, llm_kwargs, json_template=json_template, verbose=args.verbose 64 | ) 65 | num_parsed_layout_frames = get_num_parsed_layout_frames(template_version) 66 | eval_type, eval_success = evaluate_with_layout( 67 | parsed_layout, 68 | predicate, 69 | num_parsed_layout_frames, 70 | height=height, 71 | width=width, 72 | verbose=args.verbose, 73 | ) 74 | 75 | print(f"Eval success (eval_type):", eval_success) 76 | 77 | if eval_type not in eval_all_counts: 78 | eval_success_counts[eval_type] = 0 79 | eval_all_counts[eval_type] = 0 80 | eval_success_counts[eval_type] += int(eval_success) 81 | eval_all_counts[eval_type] += 1 82 | 83 | eval_success_conut, eval_all_count = 0, 0 84 | for k, v in eval_all_counts.items(): 85 | print( 86 | f"Eval type: {k}, success: {eval_success_counts[k]}/{eval_all_counts[k]}, rate: {eval_success_counts[k]/eval_all_counts[k]:.2f}" 87 | ) 88 | eval_success_conut += eval_success_counts[k] 89 | eval_all_count += eval_all_counts[k] 90 | 91 | print( 92 | f"Overall: success: {eval_success_conut}/{eval_all_count}, rate: {eval_success_conut/eval_all_count:.2f}" 93 | ) 94 | 95 | if args.show_cache_access: 96 | # Print what are accessed in the cache (may have multiple values in each key) 97 | # Not including the newly added items 98 | print(json.dumps(cache.cache_queries)) 99 | print("Number of accessed keys:", len(cache.cache_queries)) 100 | -------------------------------------------------------------------------------- /scripts/upsample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import joblib 3 | import torch 4 | import imageio 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | from tqdm import tqdm 9 | import cv2 10 | from diffusers import StableDiffusionXLImg2ImgPipeline, VideoToVideoSDPipeline 11 | from diffusers import DPMSolverMultistepScheduler 12 | 13 | 14 | def prepare_init_upsampled(video_path, horizontal): 15 | video = joblib.load(video_path) 16 | 17 | if horizontal: 18 | video = [ 19 | Image.fromarray(frame).resize((1024, 576), Image.LANCZOS) for frame in video 20 | ] 21 | else: 22 | video = [ 23 | Image.fromarray(frame).resize((1024, 1024), Image.LANCZOS) 24 | for frame in video 25 | ] 26 | 27 | return video 28 | 29 | 30 | def save_images_to_video(images, output_video_path, frame_rate=8.0): 31 | """Save a list of images to a video file.""" 32 | if not len(images): 33 | print("No images to process.") 34 | return 35 | 36 | # Assuming all images are the same size, get dimensions from the first image 37 | height, width, layers = images[0].shape 38 | 39 | # Define the codec and create VideoWriter object 40 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec definition 41 | video = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height)) 42 | 43 | for img in images: 44 | bgr_image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert RGB to BGR 45 | video.write(bgr_image) 46 | 47 | video.release() # Release the video writer 48 | print(f"Video saved to {output_video_path}") 49 | 50 | 51 | def upsample_zsxl( 52 | video_path, 53 | prompt, 54 | horizontal, 55 | negative_prompt, 56 | seed, 57 | strength, 58 | use_zssdxl, 59 | output_mp4, 60 | fps=8, 61 | ): 62 | save_path = video_path.replace( 63 | ".joblib", "_zsxl" if strength == 0.35 else f"_zsxl_s{strength}" 64 | ) 65 | if not os.path.exists(save_path + ".joblib"): 66 | video = prepare_init_upsampled(video_path, horizontal) 67 | g = torch.manual_seed(seed) 68 | video_frames_xl = pipe_xl( 69 | prompt, 70 | negative_prompt=negative_prompt, 71 | video=video, 72 | strength=strength, 73 | generator=g, 74 | ).frames[0] 75 | assert not os.path.exists(save_path + ".joblib"), save_path + ".joblib" 76 | video_frames_xl = (video_frames_xl * 255.0).astype(np.uint8) 77 | if output_mp4: 78 | save_images_to_video( 79 | video_frames_xl, save_path + ".mp4", frame_rate=fps + 0.01 80 | ) 81 | imageio.mimsave( 82 | save_path + ".gif", 83 | video_frames_xl, 84 | format="gif", 85 | loop=0, 86 | duration=1000 * 1 / fps, 87 | ) 88 | joblib.dump(video_frames_xl, save_path + ".joblib", compress=("bz2", 3)) 89 | print(f"Zeroscope XL upsampled image saved at: {save_path + '.gif'}") 90 | else: 91 | print(f"{save_path + '.joblib'} exists, skipping") 92 | 93 | if use_zssdxl: 94 | upsample_sdxl( 95 | save_path + ".joblib", 96 | prompt, 97 | horizontal, 98 | negative_prompt, 99 | seed, 100 | strength=0.1, 101 | ) 102 | 103 | 104 | def upsample_sdxl(video_path, prompt, horizontal, negative_prompt, seed, strength): 105 | save_path = video_path.replace( 106 | ".joblib", "_sdxl" if strength == 0.35 else f"_sdxl_s{strength}" 107 | ) 108 | if not os.path.exists(save_path + ".joblib"): 109 | video = prepare_init_upsampled(video_path, horizontal) 110 | pipe_sdxl.set_progress_bar_config(disable=True) 111 | video_frames_sdxl = [] 112 | for video_frame in tqdm(video): 113 | g = torch.manual_seed(seed) 114 | image = pipe_sdxl( 115 | prompt, 116 | image=video_frame, 117 | negative_prompt=negative_prompt, 118 | strength=strength, 119 | generator=g, 120 | ).images[0] 121 | video_frames_sdxl.append(np.asarray(image)) 122 | assert not os.path.exists(save_path + ".joblib"), save_path + ".joblib" 123 | imageio.mimsave(save_path + ".gif", video_frames_sdxl, format="gif", loop=0) 124 | joblib.dump(video_frames_sdxl, save_path + ".joblib", compress=("bz2", 3)) 125 | 126 | print(f"SDXL upsampled image saved at: {save_path + '.gif'}") 127 | else: 128 | print(f"{save_path + '.joblib'} exists, skipping") 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument( 134 | "--videos", 135 | nargs="+", 136 | required=True, 137 | type=str, 138 | help="path to videos in joblib format", 139 | ) 140 | parser.add_argument("--prompts", nargs="+", required=True, type=str, help="prompts") 141 | parser.add_argument("--seed", type=int, default=1) 142 | parser.add_argument("--strength", type=float, default=0.35) 143 | parser.add_argument( 144 | "--negative_prompt", 145 | type=str, 146 | default="dull, gray, unrealistic, colorless, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", 147 | ) 148 | parser.add_argument("--use_zsxl", action="store_true") 149 | parser.add_argument("--use_sdxl", action="store_true") 150 | parser.add_argument("--use_zssdxl", action="store_true") 151 | parser.add_argument( 152 | "--horizontal", 153 | action="store_true", 154 | help="If True, the video is assumed to be horizontal (576x320 to 1024x576). If False, squared (512x512 to 1024x1024).", 155 | ) 156 | parser.add_argument("--output-mp4", action="store_true", help="Store mp4 videos.") 157 | 158 | args = parser.parse_args() 159 | 160 | if args.use_zsxl: 161 | pipe_xl = VideoToVideoSDPipeline.from_pretrained( 162 | "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16 163 | ) 164 | pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config( 165 | pipe_xl.scheduler.config 166 | ) 167 | pipe_xl.enable_model_cpu_offload() 168 | pipe_xl.enable_vae_slicing() 169 | 170 | if args.use_sdxl or args.use_zssdxl: 171 | pipe_sdxl = StableDiffusionXLImg2ImgPipeline.from_pretrained( 172 | "stabilityai/stable-diffusion-xl-refiner-1.0", 173 | torch_dtype=torch.float16, 174 | variant="fp16", 175 | use_safetensors=True, 176 | ) 177 | pipe_sdxl = pipe_sdxl.to("cuda") 178 | 179 | if len(args.prompts) == 1 and len(args.videos) > 1: 180 | args.prompts = args.prompts * len(args.videos) 181 | 182 | for video_path, prompt in tqdm(zip(args.videos, args.prompts)): 183 | video_path = video_path.replace(".gif", ".joblib") 184 | print(f"Video path: {video_path}, prompt: {prompt}") 185 | 186 | if args.use_zsxl: 187 | upsample_zsxl( 188 | video_path=video_path, 189 | prompt=prompt, 190 | horizontal=args.horizontal, 191 | negative_prompt=args.negative_prompt, 192 | seed=args.seed, 193 | strength=args.strength, 194 | use_zssdxl=args.use_zssdxl, 195 | output_mp4=args.output_mp4, 196 | ) 197 | 198 | if args.use_sdxl: 199 | upsample_sdxl( 200 | video_path=video_path, 201 | prompt=prompt, 202 | horizontal=args.horizontal, 203 | negative_prompt=args.negative_prompt, 204 | seed=args.seed, 205 | strength=args.strength, 206 | ) 207 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /utils/attn.py: -------------------------------------------------------------------------------- 1 | # visualization-related functions are in vis 2 | import numbers 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | import utils 8 | 9 | 10 | def get_token_attnv2( 11 | token_id, 12 | saved_attns, 13 | attn_key, 14 | visualize_step_start=10, 15 | input_ca_has_condition_only=False, 16 | return_np=False, 17 | ): 18 | """ 19 | saved_attns: a list of saved_attn (list is across timesteps) 20 | 21 | moves to cpu by default 22 | """ 23 | saved_attns = saved_attns[visualize_step_start:] 24 | 25 | saved_attns = [saved_attn[attn_key].cpu() for saved_attn in saved_attns] 26 | 27 | attn = torch.stack(saved_attns, dim=0).mean(dim=0) 28 | 29 | # print("attn shape", attn.shape) 30 | 31 | # attn: (batch, head, spatial, text) 32 | 33 | if not input_ca_has_condition_only: 34 | assert ( 35 | attn.shape[0] == 2 36 | ), f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items" 37 | attn = attn[1] 38 | else: 39 | assert ( 40 | attn.shape[0] == 1 41 | ), f"Expect to have 1 item (cond only), but found {attn.shape[0]} items" 42 | attn = attn[0] 43 | attn = attn.mean(dim=0)[:, token_id] 44 | H, W = utils.get_hw_from_attn_dim(attn_dim=attn.shape[0]) 45 | attn = attn.reshape((H, W)) 46 | 47 | if return_np: 48 | return attn.numpy() 49 | 50 | return attn 51 | 52 | 53 | def shift_saved_attns_item( 54 | saved_attns_item, offset, guidance_attn_keys, horizontal_shift_only=False 55 | ): 56 | """ 57 | `horizontal_shift_only`: only shift horizontally. If you use `offset` from `compose_latents_with_alignment` with `horizontal_shift_only=True`, the `offset` already has y_offset = 0 and this option is not needed. 58 | """ 59 | x_offset, y_offset = offset 60 | if horizontal_shift_only: 61 | y_offset = 0.0 62 | 63 | new_saved_attns_item = {} 64 | for k in guidance_attn_keys: 65 | attn_map = saved_attns_item[k] 66 | 67 | attn_size = attn_map.shape[-2] 68 | attn_h, attn_w = utils.get_hw_from_attn_dim(attn_dim=attn_size) 69 | # Example dimensions: [batch_size, num_heads, 8, 8, num_tokens] 70 | attn_map = attn_map.unflatten(2, (attn_h, attn_w)) 71 | attn_map = utils.shift_tensor( 72 | attn_map, x_offset, y_offset, offset_normalized=True, ignore_last_dim=True 73 | ) 74 | attn_map = attn_map.flatten(2, 3) 75 | 76 | new_saved_attns_item[k] = attn_map 77 | 78 | return new_saved_attns_item 79 | 80 | 81 | def shift_saved_attns(saved_attns, offset, guidance_attn_keys, **kwargs): 82 | # Iterate over timesteps 83 | shifted_saved_attns = [ 84 | shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, **kwargs) 85 | for saved_attns_item in saved_attns 86 | ] 87 | 88 | return shifted_saved_attns 89 | 90 | 91 | class GaussianSmoothing(nn.Module): 92 | """ 93 | Apply gaussian smoothing on a 94 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 95 | in the input using a depthwise convolution. 96 | Arguments: 97 | channels (int, sequence): Number of channels of the input tensors. Output will 98 | have this number of channels as well. 99 | kernel_size (int, sequence): Size of the gaussian kernel. 100 | sigma (float, sequence): Standard deviation of the gaussian kernel. 101 | dim (int, optional): The number of dimensions of the data. 102 | Default value is 2 (spatial). 103 | 104 | Credit: https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 105 | """ 106 | 107 | def __init__(self, channels, kernel_size, sigma, dim=2): 108 | super(GaussianSmoothing, self).__init__() 109 | if isinstance(kernel_size, numbers.Number): 110 | kernel_size = [kernel_size] * dim 111 | if isinstance(sigma, numbers.Number): 112 | sigma = [sigma] * dim 113 | 114 | # The gaussian kernel is the product of the 115 | # gaussian function of each dimension. 116 | kernel = 1 117 | meshgrids = torch.meshgrid( 118 | [torch.arange(size, dtype=torch.float32) for size in kernel_size] 119 | ) 120 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 121 | mean = (size - 1) / 2 122 | kernel *= ( 123 | 1 124 | / (std * math.sqrt(2 * math.pi)) 125 | * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) 126 | ) 127 | 128 | # Make sure sum of values in gaussian kernel equals 1. 129 | kernel = kernel / torch.sum(kernel) 130 | 131 | # Reshape to depthwise convolutional weight 132 | kernel = kernel.view(1, 1, *kernel.size()) 133 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 134 | 135 | self.register_buffer("weight", kernel) 136 | self.groups = channels 137 | 138 | if dim == 1: 139 | self.conv = F.conv1d 140 | elif dim == 2: 141 | self.conv = F.conv2d 142 | elif dim == 3: 143 | self.conv = F.conv3d 144 | else: 145 | raise RuntimeError( 146 | "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) 147 | ) 148 | 149 | def forward(self, input): 150 | """ 151 | Apply gaussian filter to input. 152 | Arguments: 153 | input (torch.Tensor): Input to apply gaussian filter on. 154 | Returns: 155 | filtered (torch.Tensor): Filtered output. 156 | """ 157 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 158 | -------------------------------------------------------------------------------- /utils/cache.py: -------------------------------------------------------------------------------- 1 | # If for a prompt we query fewer or equal to the times we have cache, we return from the cache sequentially. Otherwise we store into cache. 2 | # Need to set up a new cache if the hyperparam or the template changes. 3 | 4 | import os 5 | import pickle, json 6 | 7 | cache_path = "" 8 | cache_format = "json" 9 | 10 | global_cache = {} 11 | 12 | # Always obtain the first item (should be enabled for notebook debugging only) 13 | force_first_item = False 14 | 15 | # The cache records the access times to load more than one value in the cache when the keys repeat. 16 | global_cache_index = {} 17 | # This is for export and debugging the queries 18 | cache_queries = {} 19 | 20 | 21 | def reset_cache_access(): 22 | global global_cache_index, cache_queries 23 | global_cache_index = {} 24 | cache_queries = {} 25 | 26 | 27 | def values_accessed(): 28 | return sum(global_cache_index.values()) 29 | 30 | 31 | def init_cache(allow_nonexist=True): 32 | global global_cache 33 | assert cache_path, "Need to set cache path" 34 | 35 | print(f"Cache path: {cache_path}") 36 | 37 | if not allow_nonexist: 38 | assert os.path.exists(cache_path), f"{cache_path} does not exist" 39 | 40 | if os.path.exists(cache_path): 41 | if cache_format == "pickle": 42 | with open(cache_path, "rb") as f: 43 | global_cache = pickle.load(f) 44 | elif cache_format == "json": 45 | with open(cache_path, "r") as f: 46 | global_cache = json.load(f) 47 | 48 | 49 | def get_cache(key): 50 | if key not in global_cache: 51 | global_cache[key] = [] 52 | 53 | if key not in global_cache_index: 54 | global_cache_index[key] = 0 55 | 56 | current_items = global_cache[key] 57 | current_index = global_cache_index[key] 58 | if len(current_items) > current_index: 59 | if not force_first_item: 60 | global_cache_index[key] += 1 61 | if key not in cache_queries: 62 | cache_queries[key] = [] 63 | cache_queries[key].append(current_items[current_index]) 64 | return current_items[current_index] 65 | 66 | return None 67 | 68 | 69 | def add_cache(key, value): 70 | global_cache_index[key] += 1 71 | global_cache[key].append(value) 72 | 73 | if cache_format == "pickle": 74 | with open(cache_path, "wb") as f: 75 | pickle.dump(global_cache, f) 76 | elif cache_format == "json": 77 | with open(cache_path, "w") as f: 78 | json.dump(global_cache, f, indent=4) 79 | 80 | return value 81 | 82 | 83 | def pkl_to_json(filename): 84 | assert "pkl" in filename, filename 85 | with open(filename, "rb") as f: 86 | cache = pickle.load(f) 87 | del f 88 | filename = filename.replace("pkl", "json") 89 | assert not os.path.exists(filename) 90 | print(cache) 91 | with open(filename, "w") as f: 92 | json.dump(cache, f, indent=4) 93 | -------------------------------------------------------------------------------- /utils/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import * 2 | -------------------------------------------------------------------------------- /utils/eval/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import parse 3 | 4 | 5 | def nms( 6 | bounding_boxes, 7 | confidence_score, 8 | labels, 9 | threshold, 10 | input_in_pixels=False, 11 | return_array=True, 12 | ): 13 | """ 14 | This NMS processes boxes of all labels. It not only removes the box with the same label. 15 | 16 | Adapted from https://github.com/amusi/Non-Maximum-Suppression/blob/master/nms.py 17 | """ 18 | # If no bounding boxes, return empty list 19 | if len(bounding_boxes) == 0: 20 | return np.array([]), np.array([]), np.array([]) 21 | 22 | # Bounding boxes 23 | boxes = np.array(bounding_boxes) 24 | 25 | # coordinates of bounding boxes 26 | start_x = boxes[:, 0] 27 | start_y = boxes[:, 1] 28 | end_x = boxes[:, 2] 29 | end_y = boxes[:, 3] 30 | 31 | # Confidence scores of bounding boxes 32 | score = np.array(confidence_score) 33 | 34 | # Picked bounding boxes 35 | picked_boxes = [] 36 | picked_score = [] 37 | picked_labels = [] 38 | 39 | # Compute areas of bounding boxes 40 | if input_in_pixels: 41 | areas = (end_x - start_x + 1) * (end_y - start_y + 1) 42 | else: 43 | areas = (end_x - start_x) * (end_y - start_y) 44 | 45 | # Sort by confidence score of bounding boxes 46 | order = np.argsort(score) 47 | 48 | # Iterate bounding boxes 49 | while order.size > 0: 50 | # The index of largest confidence score 51 | index = order[-1] 52 | 53 | # Pick the bounding box with largest confidence score 54 | picked_boxes.append(bounding_boxes[index]) 55 | picked_score.append(confidence_score[index]) 56 | picked_labels.append(labels[index]) 57 | 58 | # Compute ordinates of intersection-over-union(IOU) 59 | x1 = np.maximum(start_x[index], start_x[order[:-1]]) 60 | x2 = np.minimum(end_x[index], end_x[order[:-1]]) 61 | y1 = np.maximum(start_y[index], start_y[order[:-1]]) 62 | y2 = np.minimum(end_y[index], end_y[order[:-1]]) 63 | 64 | # Compute areas of intersection-over-union 65 | if input_in_pixels: 66 | w = np.maximum(0.0, x2 - x1 + 1) 67 | h = np.maximum(0.0, y2 - y1 + 1) 68 | else: 69 | w = np.maximum(0.0, x2 - x1) 70 | h = np.maximum(0.0, y2 - y1) 71 | intersection = w * h 72 | 73 | # Compute the ratio between intersection and union 74 | ratio = intersection / (areas[index] + areas[order[:-1]] - intersection) 75 | 76 | left = np.where(ratio < threshold) 77 | order = order[left] 78 | 79 | if return_array: 80 | picked_boxes, picked_score, picked_labels = ( 81 | np.array(picked_boxes), 82 | np.array(picked_score), 83 | np.array(picked_labels), 84 | ) 85 | 86 | return picked_boxes, picked_score, picked_labels 87 | 88 | 89 | def class_aware_nms( 90 | bounding_boxes, confidence_score, labels, threshold, input_in_pixels=False 91 | ): 92 | """ 93 | This NMS processes boxes of each label individually. 94 | """ 95 | # If no bounding boxes, return empty list 96 | if len(bounding_boxes) == 0: 97 | return np.array([]), np.array([]), np.array([]) 98 | 99 | picked_boxes, picked_score, picked_labels = [], [], [] 100 | 101 | labels_unique = np.unique(labels) 102 | for label in labels_unique: 103 | bounding_boxes_label = [ 104 | bounding_box 105 | for i, bounding_box in enumerate(bounding_boxes) 106 | if labels[i] == label 107 | ] 108 | confidence_score_label = [ 109 | confidence_score_item 110 | for i, confidence_score_item in enumerate(confidence_score) 111 | if labels[i] == label 112 | ] 113 | labels_label = [label] * len(bounding_boxes_label) 114 | picked_boxes_label, picked_score_label, picked_labels_label = nms( 115 | bounding_boxes_label, 116 | confidence_score_label, 117 | labels_label, 118 | threshold=threshold, 119 | input_in_pixels=input_in_pixels, 120 | return_array=False, 121 | ) 122 | picked_boxes += picked_boxes_label 123 | picked_score += picked_score_label 124 | picked_labels += picked_labels_label 125 | 126 | picked_boxes, picked_score, picked_labels = ( 127 | np.array(picked_boxes), 128 | np.array(picked_score), 129 | np.array(picked_labels), 130 | ) 131 | 132 | return picked_boxes, picked_score, picked_labels 133 | 134 | 135 | def evaluate_with_layout( 136 | parsed_layout, predicate, num_parsed_layout_frames, height, width, verbose=False 137 | ): 138 | condition = parse.parsed_layout_to_condition( 139 | parsed_layout, 140 | tokenizer=None, 141 | height=height, 142 | width=width, 143 | num_parsed_layout_frames=num_parsed_layout_frames, 144 | num_condition_frames=num_parsed_layout_frames, 145 | strip_phrases=True, 146 | verbose=True, 147 | ) 148 | 149 | print("condition:", condition) 150 | 151 | prompt_type = predicate.type 152 | success = predicate(condition, verbose=verbose) 153 | 154 | return prompt_type, success 155 | 156 | 157 | def to_gen_box_format(box, width, height, rounding): 158 | # Input: xyxy, ranging from 0 to 1 159 | # Output: xywh, unnormalized (in pixels) 160 | x_min, y_min, x_max, y_max = box 161 | if rounding: 162 | return [ 163 | round(x_min * width), 164 | round(y_min * height), 165 | round((x_max - x_min) * width), 166 | round((y_max - y_min) * height), 167 | ] 168 | return [ 169 | x_min * width, 170 | y_min * height, 171 | (x_max - x_min) * width, 172 | (y_max - y_min) * height, 173 | ] 174 | -------------------------------------------------------------------------------- /utils/eval/lvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | from .utils import ( 4 | p, 5 | predicate_numeracy, 6 | predicate_attribution, 7 | predicate_visibility, 8 | predicate_1obj_dynamic_spatial, 9 | predicate_2obj_dynamic_spatial, 10 | predicate_sequentialv2, 11 | ) 12 | 13 | prompt_prefix = "A realistic lively video of a scene" 14 | prompt_top_down_prefix = "A realistic lively video of a top-down viewed scene" 15 | 16 | evaluate_classes = [ 17 | ("moving car", "car"), 18 | ("lively cat", "cat"), 19 | ("flying bird", "bird"), 20 | ("moving ball", "ball"), 21 | ("walking dog", "dog"), 22 | ] 23 | evaluate_classes_no_attribute = [ 24 | evaluate_class_no_attribute 25 | for evaluate_class, evaluate_class_no_attribute in evaluate_classes 26 | ] 27 | 28 | 29 | def get_prompt_predicates_numeracy(min_num=1, max_num=5, repeat=2): 30 | modifier = "" 31 | 32 | prompt_predicates = [] 33 | 34 | for number in range(min_num, max_num + 1): 35 | for object_name, object_name_no_attribute in evaluate_classes: 36 | if prompt_prefix: 37 | prompt = f"{prompt_prefix} with {p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}" 38 | else: 39 | prompt = f"{p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}" 40 | prompt = prompt.strip() 41 | 42 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 43 | query_names = (object_name_no_attribute,) 44 | predicate = partial(predicate_numeracy, query_names, number) 45 | predicate.type = "numeracy" 46 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"] 47 | # We don't have tracking, but mismatch does not matter for numeracy. 48 | predicate.one_box_per_class = False 49 | prompt_predicate = prompt, predicate 50 | 51 | prompt_predicates += [prompt_predicate] * repeat 52 | 53 | return prompt_predicates 54 | 55 | 56 | def process_object_name(object_name): 57 | if isinstance(object_name, tuple): 58 | query_names = object_name 59 | object_name = object_name[0] 60 | else: 61 | query_names = (object_name,) 62 | 63 | return object_name, query_names 64 | 65 | 66 | def get_prompt_predicates_attribution(num_prompts=100, repeat=1): 67 | prompt_predicates = [] 68 | 69 | intended_count1, intended_count2 = 1, 1 70 | 71 | modifiers = [ 72 | "red", 73 | "orange", 74 | "yellow", 75 | "green", 76 | "blue", 77 | "purple", 78 | "pink", 79 | "brown", 80 | "black", 81 | "white", 82 | "gray", 83 | ] 84 | 85 | for ind in range(num_prompts): 86 | np.random.seed(ind) 87 | modifier1, modifier2 = np.random.choice(modifiers, 2, replace=False) 88 | object_name1, object_name2 = np.random.choice( 89 | evaluate_classes_no_attribute, 2, replace=False 90 | ) 91 | 92 | object_name1, query_names1 = process_object_name(object_name1) 93 | object_name2, query_names2 = process_object_name(object_name2) 94 | 95 | if prompt_prefix: 96 | prompt = f"{prompt_prefix} with {p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}" 97 | else: 98 | prompt = ( 99 | f"{p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}" 100 | ) 101 | prompt = prompt.strip() 102 | 103 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 104 | predicate = partial( 105 | predicate_attribution, 106 | query_names1, 107 | query_names2, 108 | modifier1, 109 | modifier2, 110 | intended_count1, 111 | intended_count2, 112 | ) 113 | 114 | prompt_predicate = prompt, predicate 115 | 116 | predicate.type = "attribution" 117 | predicate.texts = [ 118 | f"a photo of {p.a(modifier1)} {object_name1}", 119 | f"a photo of {p.a(modifier2)} {object_name2}", 120 | ] 121 | # Limit to one box per class. 122 | predicate.one_box_per_class = True 123 | 124 | prompt_predicates += [prompt_predicate] * repeat 125 | 126 | return prompt_predicates 127 | 128 | 129 | def get_prompt_predicates_visibility(repeat=2): 130 | prompt_predicates = [] 131 | 132 | for object_name, object_name_no_attribute in evaluate_classes: 133 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 134 | query_names = (object_name_no_attribute,) 135 | 136 | for i in range(2): 137 | # i == 0: appeared 138 | # i == 1: disappeared 139 | 140 | if i == 0: 141 | prompt = f"{prompt_prefix} in which {p.a(object_name)} appears only in the second half of the video" 142 | # Shouldn't use lambda here since query_names (and number) might change. 143 | predicate = partial(predicate_visibility, query_names, True) 144 | prompt_predicate = prompt, predicate 145 | else: 146 | prompt = f"{prompt_prefix} in which {p.a(object_name)} appears only in the first half of the video" 147 | # Shouldn't use lambda here since query_names (and number) might change. 148 | predicate = partial(predicate_visibility, query_names, False) 149 | prompt_predicate = prompt, predicate 150 | 151 | predicate.type = "visibility" 152 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"] 153 | # Limit to one box per class. 154 | predicate.one_box_per_class = True 155 | 156 | prompt_predicates += [prompt_predicate] * repeat 157 | 158 | return prompt_predicates 159 | 160 | 161 | def get_prompt_predicates_1obj_dynamic_spatial(repeat=1, left_right_only=True): 162 | prompt_predicates = [] 163 | 164 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. This is because LVD uses `condition` rather than `gen_boxes` format as the input. `condition` has processed the coordinates. 165 | locations = [ 166 | ( 167 | "left", 168 | "right", 169 | lambda box1, box2: (box1[0] + box1[2]) / 2 < (box2[0] + box2[2]) / 2, 170 | ), 171 | ( 172 | "right", 173 | "left", 174 | lambda box1, box2: (box1[0] + box1[2]) / 2 > (box2[0] + box2[2]) / 2, 175 | ), 176 | ] 177 | if not left_right_only: 178 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. 179 | locations += [ 180 | ( 181 | "top", 182 | "bottom", 183 | lambda box1, box2: (box1[1] + box1[3]) / 2 < (box2[1] + box2[3]) / 2, 184 | ), 185 | ( 186 | "bottom", 187 | "top", 188 | lambda box1, box2: (box1[1] + box1[3]) / 2 > (box2[1] + box2[3]) / 2, 189 | ), 190 | ] 191 | 192 | # We use object names without motion attributes for spatial since the attribute words may interfere with the intended motion. 193 | for object_name_no_attribute in evaluate_classes_no_attribute: 194 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 195 | query_names = (object_name_no_attribute,) 196 | 197 | for location1, location2, verify_fn in locations: 198 | prompt = f"{prompt_prefix} with {p.a(object_name_no_attribute)} moving from the {location1} to the {location2}" 199 | prompt = prompt.strip() 200 | 201 | predicate = partial(predicate_1obj_dynamic_spatial, query_names, verify_fn) 202 | prompt_predicate = prompt, predicate 203 | predicate.type = "dynamic_spatial" 204 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"] 205 | # Limit to one box per class. 206 | predicate.one_box_per_class = True 207 | 208 | prompt_predicates += [prompt_predicate] * repeat 209 | 210 | return prompt_predicates 211 | 212 | 213 | def get_prompt_predicates_2obj_dynamic_spatial( 214 | num_prompts=10, repeat=1, left_right_only=True 215 | ): 216 | prompt_predicates = [] 217 | 218 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. This is because LVD uses `condition` rather than `gen_boxes` format as the input. `condition` has processed the coordinates. 219 | locations = [ 220 | ( 221 | "left", 222 | "right", 223 | lambda box1, box2: (box1[0] + box1[2]) / 2 < (box2[0] + box2[2]) / 2, 224 | ), 225 | ( 226 | "right", 227 | "left", 228 | lambda box1, box2: (box1[0] + box1[2]) / 2 > (box2[0] + box2[2]) / 2, 229 | ), 230 | ] 231 | if not left_right_only: 232 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. 233 | locations += [ 234 | ( 235 | "top", 236 | "bottom", 237 | lambda box1, box2: (box1[1] + box1[3]) / 2 < (box2[1] + box2[3]) / 2, 238 | ), 239 | ( 240 | "bottom", 241 | "top", 242 | lambda box1, box2: (box1[1] + box1[3]) / 2 > (box2[1] + box2[3]) / 2, 243 | ), 244 | ] 245 | 246 | # We use object names without motion attributes for spatial since the attribute words may interfere with the intended motion. 247 | for ind in range(num_prompts): 248 | np.random.seed(ind) 249 | for location1, location2, verify_fn in locations: 250 | object_name1, object_name2 = np.random.choice( 251 | evaluate_classes_no_attribute, 2, replace=False 252 | ) 253 | 254 | object_name1, query_names1 = process_object_name(object_name1) 255 | object_name2, query_names2 = process_object_name(object_name2) 256 | 257 | prompt = f"{prompt_prefix} with {p.a(object_name1)} moving from the {location1} of {p.a(object_name2)} to its {location2}" 258 | prompt = prompt.strip() 259 | 260 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 261 | predicate = partial( 262 | predicate_2obj_dynamic_spatial, query_names1, query_names2, verify_fn 263 | ) 264 | prompt_predicate = prompt, predicate 265 | predicate.type = "dynamic_spatial" 266 | predicate.texts = [ 267 | f"a photo of {p.a(object_name1)}", 268 | f"a photo of {p.a(object_name2)}", 269 | ] 270 | # Limit to one box per class. 271 | predicate.one_box_per_class = True 272 | prompt_predicates += [prompt_predicate] * repeat 273 | 274 | return prompt_predicates 275 | 276 | 277 | def get_prompt_predicates_sequential(repeat=1): 278 | prompt_predicates = [] 279 | 280 | locations = [ 281 | ("lower left", "lower right", "upper right"), 282 | ("lower left", "upper left", "upper right"), 283 | ("lower right", "lower left", "upper left"), 284 | ("lower right", "upper right", "upper left"), 285 | ] 286 | verify_fns = { 287 | # lower: y is large 288 | "lower left": lambda box: (box[1] + box[3]) / 2 > 0.5 289 | and (box[0] + box[2]) / 2 < 0.5, 290 | "lower right": lambda box: (box[1] + box[3]) / 2 > 0.5 291 | and (box[0] + box[2]) / 2 > 0.5, 292 | "upper left": lambda box: (box[1] + box[3]) / 2 < 0.5 293 | and (box[0] + box[2]) / 2 < 0.5, 294 | "upper right": lambda box: (box[1] + box[3]) / 2 < 0.5 295 | and (box[0] + box[2]) / 2 > 0.5, 296 | } 297 | 298 | for object_name_no_attribute in evaluate_classes_no_attribute: 299 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts` 300 | query_names = (object_name_no_attribute,) 301 | 302 | for location1, location2, location3 in locations: 303 | # We check the appearance/disappearance in addition to whether the object is on the right side in the last frame compared to the initial frame. 304 | prompt = f"{prompt_top_down_prefix} in which {p.a(object_name_no_attribute)} initially on the {location1} of the scene. It first moves to the {location2} of the scene and then moves to the {location3} of the scene." 305 | 306 | # Shouldn't use lambda here since query_names (and number) might change. 307 | predicate = partial( 308 | predicate_sequentialv2, 309 | query_names, 310 | verify_fns[location1], 311 | verify_fns[location2], 312 | verify_fns[location3], 313 | ) 314 | # predicate = partial(predicate_sequentialv2, query_names, verify_fn1, verify_fn2) 315 | prompt_predicate = prompt, predicate 316 | predicate.type = "sequential" 317 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"] 318 | # Limit to one box per class. 319 | predicate.one_box_per_class = True 320 | prompt_predicates += [prompt_predicate] * repeat 321 | 322 | return prompt_predicates 323 | 324 | 325 | def get_lvd_full_prompt_predicates(prompt_type=None): 326 | # numeracy: 100 prompts, number 1 to 4, 5 classes, repeat 5 times 327 | prompt_predicates_numeracy = get_prompt_predicates_numeracy(max_num=4, repeat=5) 328 | # attribution: 100 prompts, two objects in each prompt, each with attributes (randomly sampled) 329 | prompt_predicates_attribution = get_prompt_predicates_attribution(num_prompts=100) 330 | # visibility: 100 prompts, 5 classes, appear/disappear, repeat 10 times 331 | prompt_predicates_visibility = get_prompt_predicates_visibility(repeat=10) 332 | # dynamic spatial: 100 prompts 333 | # 1 object: 50 prompts, 5 classes, left/right, repeat 5 times 334 | # 2 objects: 50 prompts, randomly sample two objects 25 times, left/right 335 | prompt_predicates_1obj_dynamic_spatial = get_prompt_predicates_1obj_dynamic_spatial( 336 | repeat=5 337 | ) 338 | prompt_predicates_2obj_dynamic_spatial = get_prompt_predicates_2obj_dynamic_spatial( 339 | num_prompts=25 340 | ) 341 | prompt_predicates_dynamic_spatial = ( 342 | prompt_predicates_1obj_dynamic_spatial + prompt_predicates_2obj_dynamic_spatial 343 | ) 344 | # sequential: 100 prompts, 5 classes, 4 location triplets, repeat 5 times 345 | prompt_predicates_sequential = get_prompt_predicates_sequential(repeat=5) 346 | 347 | prompt_predicates_static_all = ( 348 | prompt_predicates_numeracy + prompt_predicates_attribution 349 | ) 350 | prompts_predicates_dynamic_all = ( 351 | prompt_predicates_visibility 352 | + prompt_predicates_dynamic_spatial 353 | + prompt_predicates_sequential 354 | ) 355 | 356 | # Each one has 100 prompts 357 | prompt_predicates_all = ( 358 | prompt_predicates_numeracy 359 | + prompt_predicates_attribution 360 | + prompt_predicates_visibility 361 | + prompt_predicates_dynamic_spatial 362 | + prompt_predicates_sequential 363 | ) 364 | 365 | prompt_predicates = { 366 | "lvd": prompt_predicates_all, 367 | "lvd_static": prompt_predicates_static_all, 368 | "lvd_numeracy": prompt_predicates_numeracy, 369 | "lvd_attribution": prompt_predicates_attribution, 370 | "lvd_dynamic": prompts_predicates_dynamic_all, 371 | "lvd_dynamic_spatial": prompt_predicates_dynamic_spatial, 372 | "lvd_visibility": prompt_predicates_visibility, 373 | "lvd_sequential": prompt_predicates_sequential, 374 | } 375 | 376 | if prompt_type is not None: 377 | return prompt_predicates[prompt_type] 378 | else: 379 | return prompt_predicates 380 | 381 | 382 | def get_lvd_full_prompts(prompt_type): 383 | prompt_predicates = get_lvd_full_prompt_predicates(prompt_type) 384 | if prompt_type is not None: 385 | return [item[0] for item in prompt_predicates] 386 | else: 387 | return {k: [item[0] for item in v] for k, v in prompt_predicates.items()} 388 | 389 | 390 | if __name__ == "__main__": 391 | prompt_predicates = get_lvd_full_prompt_predicates("lvdv1.1") 392 | 393 | print( 394 | np.unique( 395 | [predicate.type for prompt, predicate in prompt_predicates], 396 | return_counts=True, 397 | ) 398 | ) 399 | print(len(prompt_predicates)) 400 | -------------------------------------------------------------------------------- /utils/eval/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import inflect 3 | import re 4 | from utils.parse import Condition 5 | 6 | p = inflect.engine() 7 | 8 | def find_word_after(text, word): 9 | pattern = r"\b" + re.escape(word) + r"\s+(.+)" 10 | match = re.search(pattern, text) 11 | if match: 12 | return match.group(1) 13 | else: 14 | return None 15 | 16 | 17 | word_to_num_mapping = {p.number_to_words(i): i for i in range(1, 21)} 18 | 19 | locations_xyxy = { 20 | ("left", "right"): (lambda box1, box2: (box1[0] + box1[2]) < (box2[0] + box2[2])), 21 | ("right", "left"): (lambda box1, box2: (box1[0] + box1[2]) > (box2[0] + box2[2])), 22 | ("top", "bottom"): (lambda box1, box2: (box1[1] + box1[3]) < (box2[1] + box2[3])), 23 | ("bottom", "top"): (lambda box1, box2: (box1[1] + box1[3]) > (box2[1] + box2[3])), 24 | } 25 | 26 | locations_xywh = { 27 | ("left", "right"): ( 28 | lambda box1, box2: box1[0] + box1[2] / 2 < box2[0] + box2[2] / 2 29 | ), 30 | ("right", "left"): ( 31 | lambda box1, box2: box1[0] + box1[2] / 2 > box2[0] + box2[2] / 2 32 | ), 33 | ("top", "bottom"): ( 34 | lambda box1, box2: box1[1] + box1[3] / 2 < box2[1] + box2[3] / 2 35 | ), 36 | ("bottom", "top"): ( 37 | lambda box1, box2: box1[1] + box1[3] / 2 > box2[1] + box2[3] / 2 38 | ), 39 | } 40 | 41 | 42 | def singular(noun): 43 | singular_noun = p.singular_noun(noun) 44 | if singular_noun is False: 45 | return noun 46 | return singular_noun 47 | 48 | 49 | def get_box(condition: Condition, name_include): 50 | # This prevents substring match on non-word boundaries: carrot vs car 51 | box_match = [ 52 | any( 53 | [ 54 | ( 55 | (name_include_item + " ") in phrase 56 | or phrase.endswith(name_include_item) 57 | ) 58 | for name_include_item in name_include 59 | ] 60 | ) 61 | for phrase in condition.phrases 62 | ] 63 | 64 | if not any(box_match): 65 | return None 66 | 67 | boxes = condition.boxes 68 | box_ind = np.min(np.where(box_match)[0]) 69 | 70 | return boxes[box_ind] 71 | 72 | 73 | def get_box_counts(condition): 74 | if len(condition.boxes) == 0: 75 | # No boxes 76 | return None 77 | 78 | box_counts = None 79 | 80 | for i, box in enumerate(condition.boxes): 81 | if i == 0: 82 | num_frames = len(box) 83 | box_counts = [0 for _ in range(num_frames)] 84 | else: 85 | assert num_frames == len(box), f"{num_frames} != {len(box)}" 86 | valid_frames = box_to_valid_frames(box) 87 | 88 | for frame_index, valid in enumerate(valid_frames): 89 | if valid: 90 | box_counts[frame_index] += 1 91 | 92 | return box_counts 93 | 94 | 95 | def predicate_numeracy(query_names, intended_count, condition, verbose=False): 96 | assert len(query_names) == 1 97 | name_include = query_names 98 | box_match = [ 99 | any( 100 | [ 101 | ( 102 | (name_include_item + " ") in phrase 103 | or phrase.endswith(name_include_item) 104 | ) 105 | for name_include_item in name_include 106 | ] 107 | ) 108 | for phrase in condition.phrases 109 | ] 110 | 111 | # We do not have tracking in stage 2 evaluation, so let's put this assertion to be safe. 112 | # This could only be a problem for stage 1 where additional non-relevant boxes are generated, but so far we did not see additional boxes in stage 1. 113 | assert len(box_match) == len( 114 | condition.boxes 115 | ), "Currently do not support the case where other boxes are also generated" 116 | 117 | box_counts = get_box_counts(condition) 118 | 119 | if box_counts is None: 120 | majority_box_counts = 0 121 | else: 122 | majority_box_counts = np.bincount(box_counts).argmax() 123 | 124 | object_count = majority_box_counts 125 | if verbose: 126 | print( 127 | f"box_counts: {box_counts}, object_count: {object_count}, intended_count: {intended_count} (condition: {condition}, query_names: {query_names})" 128 | ) 129 | 130 | success = object_count == intended_count 131 | 132 | return success 133 | 134 | 135 | def box_to_valid_frames(object_box): 136 | object_box = np.array(object_box) 137 | x, y, w, h = object_box[:, 0], object_box[:, 1], object_box[:, 2], object_box[:, 3] 138 | # If the box has 0 width or height, it is not valid. 139 | valid_frames = (w != 0) & (h != 0) 140 | 141 | return valid_frames 142 | 143 | 144 | def predicate_visibility(query_names, test_appearance, condition, verbose=False): 145 | # condition: dict with keys 'name' and 'bounding_box' 146 | 147 | object_box = get_box(condition, query_names) 148 | if not object_box: 149 | return False 150 | 151 | valid_frames = box_to_valid_frames(object_box) 152 | 153 | num_frames = len(valid_frames) 154 | first_half_index = num_frames // 2 155 | 156 | # Ignore the two frames in the middle since there may be discrepancies between the LLM's understanding of the middle frame and the middle frame after interpolation (in generation) and sampling (in evaluation). 157 | valid_frames_first_half, valid_frames_second_half = ( 158 | valid_frames[: first_half_index - 1], 159 | valid_frames[first_half_index + 1 :], 160 | ) 161 | present_in_first_half, present_in_second_half = ( 162 | any(valid_frames_first_half), 163 | any(valid_frames_second_half), 164 | ) 165 | 166 | if test_appearance: 167 | # Test appearing: we ensure the object is not in the first half but needs to be present in the second half. 168 | success = (not present_in_first_half) and present_in_second_half 169 | else: 170 | # Test disappearing: we ensure the object is in the first half but needs to be absent in the second half. 171 | success = present_in_first_half and (not present_in_second_half) 172 | 173 | if verbose: 174 | print( 175 | f"Test appearance: {test_appearance}, valid_frames: {valid_frames}, appeared at first half: {present_in_first_half}, appeared at second half: {present_in_second_half}" 176 | ) 177 | 178 | return success 179 | 180 | 181 | def predicate_attribution( 182 | query_names1, 183 | query_names2, 184 | modifier1, 185 | modifier2, 186 | intended_count1, 187 | intended_count2, 188 | condition, 189 | verbose=False, 190 | ): 191 | # Attribution does not use count now 192 | assert intended_count1 == 1 and intended_count2 == 1 193 | 194 | if modifier1: 195 | query_names1 = [f"{modifier1} {item}" for item in query_names1] 196 | object_box1 = get_box(condition, name_include=query_names1) 197 | 198 | if object_box1 is None: 199 | return False 200 | 201 | valid_frames1 = box_to_valid_frames(object_box1) 202 | if valid_frames1.mean() < 0.5: 203 | # Not detected at more than half of frames 204 | return False 205 | 206 | if query_names2 is None: 207 | # Only one object 208 | return True 209 | 210 | if modifier2: 211 | query_names2 = [f"{modifier2} {item}" for item in query_names2] 212 | object_box2 = get_box(condition, name_include=query_names2) 213 | 214 | if object_box2 is None: 215 | return False 216 | 217 | valid_frames2 = box_to_valid_frames(object_box2) 218 | if valid_frames2.mean() < 0.5: 219 | # Not detected at more than half of frames 220 | return False 221 | 222 | if verbose: 223 | print(f"Object box 1: {object_box1}, Object box 2: {object_box2}") 224 | 225 | return True 226 | 227 | 228 | def predicate_1obj_dynamic_spatial(query_names, verify_fn, condition, verbose=False): 229 | object_box = get_box(condition, query_names) 230 | if not object_box: 231 | return False 232 | 233 | valid_frames = box_to_valid_frames(object_box) 234 | if not valid_frames[0] or not valid_frames[-1]: 235 | return False 236 | 237 | # For example, from the left to the right: object in the first frame is on the left compared to the object in the last frame 238 | success = verify_fn(object_box[0], object_box[-1]) 239 | 240 | return success 241 | 242 | 243 | def predicate_2obj_dynamic_spatial( 244 | query_names1, query_names2, verify_fn, condition, verbose=False 245 | ): 246 | object_box1 = get_box(condition, query_names1) 247 | object_box2 = get_box(condition, query_names2) 248 | 249 | if verbose: 250 | print(f"object_box1: {object_box1}, object_box2: {object_box2}") 251 | 252 | if not object_box1 or not object_box2: 253 | return False 254 | 255 | valid_frames1 = box_to_valid_frames(object_box1) 256 | valid_frames2 = box_to_valid_frames(object_box2) 257 | if ( 258 | not valid_frames1[0] 259 | or not valid_frames2[0] 260 | or not valid_frames1[-1] 261 | or not valid_frames2[-1] 262 | ): 263 | return False 264 | 265 | # For example, `object 1` moving from the left of `object 2` to the right: object 1 in the first frame is on the left compared to object 1 in the first frame; object 1 in the last frame is on the right compared to object 2 in the last frame 266 | success1 = verify_fn(object_box1[0], object_box2[0]) 267 | success2 = verify_fn(object_box2[-1], object_box1[-1]) 268 | success = success1 and success2 269 | 270 | return success 271 | 272 | 273 | def predicate_sequentialv2( 274 | query_names, verify_fn1, verify_fn2, verify_fn3, condition, verbose=False 275 | ): 276 | # condition: dict with keys 'name' and 'bounding_box' 277 | 278 | object_box = get_box(condition, query_names) 279 | if verbose: 280 | print(f"object_box: {object_box}") 281 | 282 | if not object_box: 283 | return False 284 | 285 | valid_frames = box_to_valid_frames(object_box) 286 | if verbose: 287 | print(f"valid_frames: {valid_frames}") 288 | 289 | num_frames = len(valid_frames) 290 | middle_frame_index = num_frames // 2 291 | 292 | # Need to be present in the first, the middle, and the last frame 293 | if ( 294 | not valid_frames[0] 295 | or not valid_frames[middle_frame_index] 296 | or not valid_frames[-1] 297 | ): 298 | return False 299 | 300 | # Need to be on the right place in the first, middle, and last frame 301 | success1 = verify_fn1(object_box[0]) 302 | success2 = verify_fn2(object_box[middle_frame_index]) 303 | success3 = verify_fn3(object_box[-1]) 304 | 305 | if verbose: 306 | print( 307 | f"success1: {success1} ({object_box[0]}), success2: {success2} ({object_box[middle_frame_index]}), success3: {success3} ({object_box[-1]})" 308 | ) 309 | 310 | success = success1 and success2 and success3 311 | return success 312 | -------------------------------------------------------------------------------- /utils/latents.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from . import utils 4 | from utils import torch_device 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype): 9 | """ 10 | in_channels: often obtained with `unet.config.in_channels` 11 | """ 12 | # Obtain with torch.float32 and cast to float16 if needed 13 | # Directly obtaining latents in float16 will lead to different latents 14 | latents_base = torch.randn( 15 | (batch_size, in_channels, height // 8, width // 8), 16 | generator=generator, 17 | dtype=dtype, 18 | ).to(torch_device, dtype=dtype) 19 | 20 | return latents_base 21 | 22 | 23 | def get_scaled_latents( 24 | batch_size, in_channels, height, width, generator, dtype, scheduler 25 | ): 26 | latents_base = get_unscaled_latents( 27 | batch_size, in_channels, height, width, generator, dtype 28 | ) 29 | latents_base = latents_base * scheduler.init_noise_sigma 30 | return latents_base 31 | 32 | 33 | def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01): 34 | """ 35 | in_channels: often obtained with `unet.config.in_channels` 36 | """ 37 | assert not torch.allclose( 38 | latents_bg, latents_fg 39 | ), "latents_bg should be independent with latents_fg" 40 | 41 | dtype = latents_bg.dtype 42 | latents = ( 43 | latents_bg * (1.0 - fg_mask) 44 | + ( 45 | latents_bg * np.sqrt(1.0 - fg_blending_ratio) 46 | + latents_fg * np.sqrt(fg_blending_ratio) 47 | ) 48 | * fg_mask 49 | ) 50 | latents = latents.to(dtype=dtype) 51 | 52 | return latents 53 | 54 | 55 | @torch.no_grad() 56 | def compose_latents( 57 | model_dict, 58 | latents_all_list, 59 | mask_tensor_list, 60 | num_inference_steps, 61 | overall_batch_size, 62 | height, 63 | width, 64 | latents_bg=None, 65 | bg_seed=None, 66 | compose_box_to_bg=True, 67 | use_fast_schedule=False, 68 | fast_after_steps=None, 69 | ): 70 | unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype 71 | 72 | if latents_bg is None: 73 | generator = torch.manual_seed( 74 | bg_seed 75 | ) # Seed generator to create the inital latent noise 76 | latents_bg = get_scaled_latents( 77 | overall_batch_size, 78 | unet.config.in_channels, 79 | height, 80 | width, 81 | generator, 82 | dtype, 83 | scheduler, 84 | ) 85 | 86 | # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents. 87 | if use_fast_schedule: 88 | # If we use fast schedule, we only compose the frozen steps because the later steps do not match. 89 | composed_latents = torch.zeros( 90 | (fast_after_steps + 1, *latents_bg.shape), dtype=dtype 91 | ) 92 | else: 93 | # Otherwise we compose all steps so that we don't need to compose again if we change the frozen steps. 94 | composed_latents = torch.zeros( 95 | (num_inference_steps + 1, *latents_bg.shape), dtype=dtype 96 | ) 97 | composed_latents[0] = latents_bg 98 | 99 | foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long) 100 | 101 | mask_size = np.array([mask_tensor.sum().item() for mask_tensor in mask_tensor_list]) 102 | # Compose the largest mask first 103 | mask_order = np.argsort(-mask_size) 104 | 105 | if compose_box_to_bg: 106 | # This has two functionalities: 107 | # 1. copies the right initial latents from the right place (for centered so generation), 2. copies the right initial latents (since we have foreground blending) for centered/original so generation. 108 | for mask_idx in mask_order: 109 | latents_all, mask_tensor = ( 110 | latents_all_list[mask_idx], 111 | mask_tensor_list[mask_idx], 112 | ) 113 | 114 | # Note: need to be careful to not copy from zeros due to shifting. 115 | mask_tensor = utils.binary_mask_to_box_mask(mask_tensor, to_device=False) 116 | 117 | mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype) 118 | composed_latents[0] = ( 119 | composed_latents[0] * (1.0 - mask_tensor_expanded) 120 | + latents_all[0] * mask_tensor_expanded 121 | ) 122 | 123 | # This is still needed with `compose_box_to_bg` to ensure the foreground latent is still visible and to compute foreground indices. 124 | for mask_idx in mask_order: 125 | latents_all, mask_tensor = ( 126 | latents_all_list[mask_idx], 127 | mask_tensor_list[mask_idx], 128 | ) 129 | foreground_indices = ( 130 | foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor 131 | ) 132 | mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype) 133 | if use_fast_schedule: 134 | composed_latents = ( 135 | composed_latents * (1.0 - mask_tensor_expanded) 136 | + latents_all[: fast_after_steps + 1] * mask_tensor_expanded 137 | ) 138 | else: 139 | composed_latents = ( 140 | composed_latents * (1.0 - mask_tensor_expanded) 141 | + latents_all * mask_tensor_expanded 142 | ) 143 | 144 | composed_latents, foreground_indices = ( 145 | composed_latents.to(torch_device), 146 | foreground_indices.to(torch_device), 147 | ) 148 | return composed_latents, foreground_indices 149 | 150 | 151 | def align_with_bboxes( 152 | latents_all_list, mask_tensor_list, bboxes, horizontal_shift_only=False 153 | ): 154 | """ 155 | Each offset in `offset_list` is `(x_offset, y_offset)` (normalized). 156 | """ 157 | new_latents_all_list, new_mask_tensor_list, offset_list = [], [], [] 158 | for latents_all, mask_tensor, bbox in zip( 159 | latents_all_list, mask_tensor_list, bboxes 160 | ): 161 | x_src_center, y_src_center = utils.binary_mask_to_center( 162 | mask_tensor, normalize=True 163 | ) 164 | x_min_dest, y_min_dest, x_max_dest, y_max_dest = bbox 165 | x_dest_center, y_dest_center = ( 166 | (x_min_dest + x_max_dest) / 2, 167 | (y_min_dest + y_max_dest) / 2, 168 | ) 169 | # print("src (x,y):", x_src_center, y_src_center, "dest (x,y):", x_dest_center, y_dest_center) 170 | x_offset, y_offset = x_dest_center - x_src_center, y_dest_center - y_src_center 171 | if horizontal_shift_only: 172 | y_offset = 0.0 173 | offset = x_offset, y_offset 174 | latents_all = utils.shift_tensor( 175 | latents_all, x_offset, y_offset, offset_normalized=True 176 | ) 177 | mask_tensor = utils.shift_tensor( 178 | mask_tensor, x_offset, y_offset, offset_normalized=True 179 | ) 180 | new_latents_all_list.append(latents_all) 181 | new_mask_tensor_list.append(mask_tensor) 182 | offset_list.append(offset) 183 | 184 | return new_latents_all_list, new_mask_tensor_list, offset_list 185 | 186 | 187 | @torch.no_grad() 188 | def compose_latents_with_alignment( 189 | model_dict, 190 | latents_all_list, 191 | mask_tensor_list, 192 | num_inference_steps, 193 | overall_batch_size, 194 | height, 195 | width, 196 | align_with_overall_bboxes=True, 197 | overall_bboxes=None, 198 | horizontal_shift_only=False, 199 | **kwargs, 200 | ): 201 | if align_with_overall_bboxes and len(latents_all_list): 202 | expanded_overall_bboxes = utils.expand_overall_bboxes(overall_bboxes) 203 | latents_all_list, mask_tensor_list, offset_list = align_with_bboxes( 204 | latents_all_list, 205 | mask_tensor_list, 206 | bboxes=expanded_overall_bboxes, 207 | horizontal_shift_only=horizontal_shift_only, 208 | ) 209 | else: 210 | offset_list = [(0.0, 0.0) for _ in range(len(latents_all_list))] 211 | composed_latents, foreground_indices = compose_latents( 212 | model_dict, 213 | latents_all_list, 214 | mask_tensor_list, 215 | num_inference_steps, 216 | overall_batch_size, 217 | height, 218 | width, 219 | **kwargs, 220 | ) 221 | return composed_latents, foreground_indices, offset_list 222 | 223 | 224 | def get_input_latents_list( 225 | model_dict, 226 | bg_seed, 227 | fg_seed_start, 228 | fg_blending_ratio, 229 | height, 230 | width, 231 | so_prompt_phrase_box_list=None, 232 | so_boxes=None, 233 | verbose=False, 234 | ): 235 | """ 236 | Note: the returned input latents are scaled by `scheduler.init_noise_sigma` 237 | 238 | fg_seed_start: int or list. If int, `fg_seed = fg_seed_start + idx`. If list, `fg_seed = fg_seed_start[idx]`. 239 | """ 240 | unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype 241 | 242 | generator_bg = torch.manual_seed( 243 | bg_seed 244 | ) # Seed generator to create the inital latent noise 245 | latents_bg = get_unscaled_latents( 246 | batch_size=1, 247 | in_channels=unet.config.in_channels, 248 | height=height, 249 | width=width, 250 | generator=generator_bg, 251 | dtype=dtype, 252 | ) 253 | 254 | input_latents_list = [] 255 | 256 | if so_boxes is None: 257 | # For compatibility 258 | so_boxes = [item[-1] for item in so_prompt_phrase_box_list] 259 | 260 | # change this changes the foreground initial noise 261 | for idx, obj_box in enumerate(so_boxes): 262 | H, W = height // 8, width // 8 263 | fg_mask = utils.proportion_to_mask(obj_box, H, W) 264 | 265 | if verbose >= 2: 266 | plt.imshow(fg_mask.cpu().numpy()) 267 | plt.show() 268 | 269 | if isinstance(fg_seed_start, list): 270 | fg_seed = fg_seed_start[idx] 271 | else: 272 | fg_seed = fg_seed_start + idx 273 | assert ( 274 | bg_seed != fg_seed 275 | ), f"Need to use different seeds for bg and fg: fg_seed ({fg_seed}) and bg_seed ({bg_seed})" 276 | 277 | generator_fg = torch.manual_seed(fg_seed) 278 | latents_fg = get_unscaled_latents( 279 | batch_size=1, 280 | in_channels=unet.config.in_channels, 281 | height=height, 282 | width=width, 283 | generator=generator_fg, 284 | dtype=dtype, 285 | ) 286 | 287 | fg_blending_ratio_item = ( 288 | fg_blending_ratio[idx] 289 | if isinstance(fg_blending_ratio, list) 290 | else fg_blending_ratio 291 | ) 292 | input_latents = blend_latents( 293 | latents_bg, latents_fg, fg_mask, fg_blending_ratio=fg_blending_ratio_item 294 | ) 295 | 296 | input_latents = input_latents * scheduler.init_noise_sigma 297 | 298 | input_latents_list.append(input_latents) 299 | 300 | latents_bg = latents_bg * scheduler.init_noise_sigma 301 | 302 | return input_latents_list, latents_bg 303 | -------------------------------------------------------------------------------- /utils/llm.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from prompt import templates, stop, required_lines, required_lines_ast 3 | from easydict import EasyDict 4 | from utils.cache import get_cache, add_cache 5 | import ast 6 | import traceback 7 | import time 8 | import pyjson5 9 | 10 | model_names = [ 11 | "vicuna", 12 | "vicuna-13b", 13 | "vicuna-13b-v1.3", 14 | "vicuna-33b-v1.3", 15 | "Llama-2-7b-hf", 16 | "Llama-2-13b-hf", 17 | "Llama-2-70b-hf", 18 | "FreeWilly2", 19 | "gpt-3.5-turbo", 20 | "gpt-3.5", 21 | "gpt-4", 22 | "gpt-4-1106-preview", 23 | ] 24 | 25 | 26 | def get_full_chat_prompt(template, prompt, suffix=None, query_prefix="Caption: "): 27 | # query_prefix should be "Caption: " to match with the prompt. This was fixed when template v1.7.2 is in use. 28 | if isinstance(template, str): 29 | full_prompt = [ 30 | {"role": "system", "content": "You are a helpful assistant."}, 31 | { 32 | "role": "user", 33 | "content": get_full_prompt(template, prompt, suffix).strip(), 34 | }, 35 | ] 36 | else: 37 | print("**Using chat prompt**") 38 | assert suffix is None 39 | full_prompt = [*template, {"role": "user", "content": query_prefix + prompt}] 40 | return full_prompt 41 | 42 | 43 | def get_full_prompt(template, prompt, suffix=None): 44 | assert isinstance(template, str), "Chat template requires `get_full_chat_prompt`" 45 | full_prompt = template.replace("{prompt}", prompt) 46 | if suffix: 47 | full_prompt = full_prompt.strip() + suffix 48 | return full_prompt 49 | 50 | 51 | def get_full_model_name(model): 52 | if model == "gpt-3.5": 53 | model = "gpt-3.5-turbo" 54 | elif model == "vicuna": 55 | model = "vicuna-13b" 56 | elif model == "gpt-4": 57 | model = "gpt-4" 58 | 59 | return model 60 | 61 | 62 | def get_llm_kwargs(model, template_version): 63 | model = get_full_model_name(model) 64 | 65 | print(f"Using template: {template_version}") 66 | 67 | template = templates[template_version] 68 | 69 | if ( 70 | "vicuna" in model.lower() 71 | or "llama" in model.lower() 72 | or "freewilly" in model.lower() 73 | ): 74 | api_base = "http://localhost:8000/v1" 75 | max_tokens = 900 76 | temperature = 0.25 77 | headers = {} 78 | else: 79 | from utils.api_key import api_key 80 | 81 | api_base = "https://api.openai.com/v1" 82 | max_tokens = 900 83 | temperature = 0.25 84 | headers = {"Authorization": f"Bearer {api_key}"} 85 | 86 | llm_kwargs = EasyDict( 87 | model=model, 88 | template=template, 89 | api_base=api_base, 90 | max_tokens=max_tokens, 91 | temperature=temperature, 92 | headers=headers, 93 | stop=stop, 94 | ) 95 | 96 | return model, llm_kwargs 97 | 98 | 99 | def get_layout(prompt, llm_kwargs, suffix="", query_prefix="Caption: ", verbose=False): 100 | # No cache in this function 101 | model, template, api_base, max_tokens, temperature, stop, headers = ( 102 | llm_kwargs.model, 103 | llm_kwargs.template, 104 | llm_kwargs.api_base, 105 | llm_kwargs.max_tokens, 106 | llm_kwargs.temperature, 107 | llm_kwargs.stop, 108 | llm_kwargs.headers, 109 | ) 110 | 111 | if verbose: 112 | print("prompt:", prompt, "with suffix", suffix) 113 | 114 | done = False 115 | attempts = 0 116 | while not done: 117 | if "gpt" in model: 118 | r = requests.post( 119 | f"{api_base}/chat/completions", 120 | json={ 121 | "model": model, 122 | "messages": get_full_chat_prompt( 123 | template, prompt, suffix, query_prefix=query_prefix 124 | ), 125 | "max_tokens": max_tokens, 126 | "temperature": temperature, 127 | "stop": stop if isinstance(template, str) else None, 128 | }, 129 | headers=headers, 130 | ) 131 | else: 132 | r = requests.post( 133 | f"{api_base}/completions", 134 | json={ 135 | "model": model, 136 | "prompt": get_full_prompt(template, prompt, suffix).strip(), 137 | "max_tokens": max_tokens, 138 | "temperature": temperature, 139 | "stop": stop, 140 | }, 141 | headers=headers, 142 | ) 143 | 144 | done = r.status_code == 200 145 | 146 | if not done: 147 | print(r.json()) 148 | attempts += 1 149 | if attempts >= 3 and "gpt" in model: 150 | print("Retrying after 1 minute") 151 | time.sleep(60) 152 | if attempts >= 5 and "gpt" in model: 153 | print("Exiting due to many non-successful attempts") 154 | exit() 155 | 156 | if "gpt" in model: 157 | if verbose > 1: 158 | print(f"***{r.json()}***") 159 | response = r.json()["choices"][0]["message"]["content"] 160 | else: 161 | response = r.json()["choices"][0]["text"] 162 | 163 | if verbose: 164 | print("resp", response) 165 | 166 | return response 167 | 168 | 169 | def get_parsed_layout(*args, json_template=False, **kwargs): 170 | if json_template: 171 | return get_parsed_layout_json_resp(*args, **kwargs) 172 | else: 173 | return get_parsed_layout_text_resp(*args, **kwargs) 174 | 175 | 176 | def get_parsed_layout_text_resp( 177 | prompt, 178 | llm_kwargs=None, 179 | max_partial_response_retries=1, 180 | override_response=None, 181 | strip_chars=" \t\n`", 182 | save_leading_text=True, 183 | **kwargs, 184 | ): 185 | """ 186 | override_response: override the LLM response (will not query the LLM), useful for parsing existing response 187 | """ 188 | if override_response is not None: 189 | assert ( 190 | max_partial_response_retries == 1 191 | ), "override_response is specified so no partial queries are allowed" 192 | 193 | process_index = 0 194 | retries = 0 195 | suffix = None 196 | parsed_layout = {} 197 | reconstructed_response = "" 198 | while process_index < len(required_lines): 199 | retries += 1 200 | if retries > max_partial_response_retries: 201 | raise ValueError( 202 | f"Erroring due to many non-successful attempts on prompt: {prompt} with response {response}" 203 | ) 204 | if override_response is not None: 205 | response = override_response 206 | else: 207 | response = get_layout( 208 | prompt, llm_kwargs=llm_kwargs, suffix=suffix, **kwargs 209 | ) 210 | # print(f"Current response: {response}") 211 | if required_lines[process_index] in response: 212 | response_split = response.split(required_lines[process_index]) 213 | 214 | # print(f"Unused leading text: {response_split[0]}") 215 | 216 | if save_leading_text: 217 | reconstructed_response += ( 218 | response_split[0] + required_lines[process_index] 219 | ) 220 | response = response_split[1] 221 | 222 | while process_index < len(required_lines): 223 | required_line = required_lines[process_index] 224 | next_required_line = ( 225 | required_lines[process_index + 1] 226 | if process_index + 1 < len(required_lines) 227 | else "" 228 | ) 229 | if next_required_line in response: 230 | if next_required_line != "": 231 | required_line_idx = response.find(next_required_line) 232 | line_content = response[:required_line_idx].strip(strip_chars) 233 | else: 234 | line_content = response.strip(strip_chars) 235 | if required_lines_ast[process_index]: 236 | # LLMs sometimes give comments starting with " - " 237 | line_content = line_content.split(" - ")[0].strip() 238 | 239 | # LLMs sometimes give a list: `- name: content` 240 | if line_content.startswith("-"): 241 | line_content = line_content[ 242 | line_content.find("-") + 1 : 243 | ].strip() 244 | 245 | try: 246 | line_content = ast.literal_eval(line_content) 247 | except SyntaxError as e: 248 | print( 249 | f"Encountered SyntaxError with content {line_content}: {e}" 250 | ) 251 | raise e 252 | parsed_layout[required_line.rstrip(":")] = line_content 253 | reconstructed_response += response[ 254 | : required_line_idx + len(next_required_line) 255 | ] 256 | response = response[required_line_idx + len(next_required_line) :] 257 | process_index += 1 258 | else: 259 | break 260 | if process_index == 0: 261 | # Nothing matches, retry without adding suffix 262 | continue 263 | elif process_index < len(required_lines): 264 | # Partial matches, add the match to the suffix and retry 265 | suffix = ( 266 | "\n" 267 | + response.rstrip(strip_chars) 268 | + "\n" 269 | + required_lines[process_index] 270 | ) 271 | 272 | parsed_layout["Prompt"] = prompt 273 | 274 | return parsed_layout, reconstructed_response 275 | 276 | 277 | def get_parsed_layout_json_resp( 278 | prompt, 279 | llm_kwargs=None, 280 | max_partial_response_retries=1, 281 | override_response=None, 282 | strip_chars=" \t\n`", 283 | save_leading_text=True, 284 | **kwargs, 285 | ): 286 | """ 287 | override_response: override the LLM response (will not query the LLM), useful for parsing existing response 288 | save_leading_text: ignored since we do not allow leading text in JSON 289 | max_partial_response_retries: ignored since we do not allow partial response in JSON 290 | """ 291 | assert ( 292 | max_partial_response_retries == 1 293 | ), "no partial queries are allowed in with JSON format templates" 294 | if override_response is not None: 295 | response = override_response 296 | else: 297 | response = get_layout(prompt, llm_kwargs=llm_kwargs, suffix=None, **kwargs) 298 | 299 | response = response.strip(strip_chars) 300 | 301 | # Alternatively we can use `removeprefix` in `str`. 302 | response = ( 303 | response[len("Response:") :] if response.startswith("Response:") else response 304 | ) 305 | 306 | response = response.strip(strip_chars) 307 | 308 | # print("Response:", response) 309 | 310 | try: 311 | parsed_layout = pyjson5.loads(response) 312 | except ( 313 | ValueError, 314 | pyjson5.Json5Exception, 315 | pyjson5.Json5EOF, 316 | pyjson5.Json5DecoderException, 317 | pyjson5.Json5IllegalCharacter, 318 | ) as e: 319 | print( 320 | f"Encountered exception in parsing the response with content {response}: {e}" 321 | ) 322 | raise e 323 | 324 | reconstructed_response = response 325 | 326 | parsed_layout["Prompt"] = prompt 327 | 328 | return parsed_layout, reconstructed_response 329 | 330 | 331 | def get_parsed_layout_with_cache( 332 | prompt, 333 | llm_kwargs, 334 | verbose=False, 335 | max_retries=3, 336 | cache_miss_allowed=True, 337 | json_template=False, 338 | **kwargs, 339 | ): 340 | """ 341 | Get parsed_layout with cache support. This function will only add to cache after all lines are obtained and parsed. If partial response is obtained, the model will query with the previous partial response. 342 | """ 343 | # Note that cache path needs to be set correctly, as get_cache does not check whether the cache is generated with the given model in the given setting. 344 | 345 | response = get_cache(prompt) 346 | 347 | if response is not None: 348 | print(f"Cache hit: {prompt}") 349 | parsed_layout, _ = get_parsed_layout( 350 | prompt, 351 | llm_kwargs=llm_kwargs, 352 | max_partial_response_retries=1, 353 | override_response=response, 354 | json_template=json_template, 355 | ) 356 | return parsed_layout 357 | 358 | print(f"Cache miss: {prompt}") 359 | 360 | assert cache_miss_allowed, "Cache miss is not allowed" 361 | 362 | done = False 363 | retries = 0 364 | while not done: 365 | retries += 1 366 | if retries >= max_retries: 367 | raise ValueError( 368 | f"Erroring due to many non-successful attempts on prompt: {prompt}" 369 | ) 370 | try: 371 | parsed_layout, reconstructed_response = get_parsed_layout( 372 | prompt, llm_kwargs=llm_kwargs, json_template=json_template, **kwargs 373 | ) 374 | except Exception as e: 375 | print(f"Error: {e}, retrying") 376 | traceback.print_exc() 377 | continue 378 | 379 | done = True 380 | 381 | add_cache(prompt, reconstructed_response) 382 | 383 | if verbose: 384 | print(f"parsed_layout = {parsed_layout}") 385 | 386 | return parsed_layout 387 | -------------------------------------------------------------------------------- /utils/parse.py: -------------------------------------------------------------------------------- 1 | from matplotlib.patches import Polygon 2 | from matplotlib.collections import PatchCollection 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | from . import guidance 7 | from collections import namedtuple 8 | import imageio 9 | import shutil 10 | 11 | Condition = namedtuple( 12 | "Condition", ["prompt", "boxes", "phrases", "object_positions", "token_map"] 13 | ) 14 | 15 | img_dir = "imgs" 16 | 17 | # h, w used in the layouts 18 | size = (512, 512) 19 | size_h, size_w = size 20 | # print(f"Using box scale: {size}") 21 | 22 | 23 | def draw_boxes(condition, frame_index=None): 24 | boxes, phrases = condition.boxes, condition.phrases 25 | 26 | ax = plt.gca() 27 | ax.set_autoscale_on(False) 28 | polygons = [] 29 | color = [] 30 | for box_ind, (box, name) in enumerate(zip(boxes, phrases)): 31 | if isinstance(box, dict): 32 | if frame_index not in box: 33 | continue 34 | else: 35 | if frame_index >= len(box): 36 | continue 37 | 38 | box = box[frame_index] if frame_index is not None else box 39 | # Each phrase may be a list to allow different phrase per timestep 40 | name = ( 41 | name[frame_index] 42 | if frame_index is not None and isinstance(name, (dict, list, tuple)) 43 | else name 44 | ) 45 | 46 | # This ensures different frames have the same box color. 47 | rng = np.random.default_rng(box_ind) 48 | c = rng.random((1, 3)) * 0.6 + 0.4 49 | [bbox_x, bbox_y, bbox_x_max, bbox_y_max] = box 50 | if bbox_x_max <= bbox_x or bbox_y_max <= bbox_y: 51 | # Filters out the box in the frames without this box 52 | continue 53 | bbox_x, bbox_y, bbox_x_max, bbox_y_max = ( 54 | bbox_x * size_w, 55 | bbox_y * size_h, 56 | bbox_x_max * size_w, 57 | bbox_y_max * size_h, 58 | ) 59 | poly = [ 60 | [bbox_x, bbox_y], 61 | [bbox_x, bbox_y_max], 62 | [bbox_x_max, bbox_y_max], 63 | [bbox_x_max, bbox_y], 64 | ] 65 | np_poly = np.array(poly).reshape((4, 2)) 66 | polygons.append(Polygon(np_poly)) 67 | color.append(c) 68 | 69 | # print(ann) 70 | ax.text( 71 | bbox_x, 72 | bbox_y, 73 | name, 74 | style="italic", 75 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, 76 | ) 77 | 78 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) 79 | ax.add_collection(p) 80 | 81 | 82 | def show_boxes( 83 | condition, frame_index=None, ind=None, show=True, show_prompt=True, save=False 84 | ): 85 | """ 86 | This draws the boxes in `frame_index`. 87 | """ 88 | boxes, phrases = condition.boxes, condition.phrases 89 | 90 | if len(boxes) == 0: 91 | return 92 | 93 | # White background (to allow line to show on the edge) 94 | I = np.ones((size[0] + 4, size[1] + 4, 3), dtype=np.uint8) * 255 95 | 96 | plt.imshow(I) 97 | plt.axis("off") 98 | 99 | bg_prompt = getattr(condition, "prompt", None) 100 | neg_prompt = getattr(condition, "neg_prompt", None) 101 | 102 | ax = plt.gca() 103 | if show_prompt and bg_prompt is not None: 104 | ax.text( 105 | 0, 106 | 0, 107 | bg_prompt + f"(Neg: {neg_prompt})" if neg_prompt else bg_prompt, 108 | style="italic", 109 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, 110 | ) 111 | c = np.zeros((1, 3)) 112 | [bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0]) 113 | poly = [ 114 | [bbox_x, bbox_y], 115 | [bbox_x, bbox_y + bbox_h], 116 | [bbox_x + bbox_w, bbox_y + bbox_h], 117 | [bbox_x + bbox_w, bbox_y], 118 | ] 119 | np_poly = np.array(poly).reshape((4, 2)) 120 | polygons = [Polygon(np_poly)] 121 | color = [c] 122 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) 123 | ax.add_collection(p) 124 | 125 | draw_boxes(condition, frame_index=frame_index) 126 | if show: 127 | plt.show() 128 | 129 | if save: 130 | print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}") 131 | plt.savefig(f"{img_dir}/boxes.png") 132 | if ind is not None: 133 | shutil.copy(f"{img_dir}/boxes.png", f"{img_dir}/boxes_{ind}.png") 134 | 135 | 136 | def show_video_boxes( 137 | condition, 138 | figsize=(4, 4), 139 | ind=None, 140 | show=False, 141 | save=False, 142 | save_each_frame=False, 143 | fps=8, 144 | save_name="boxes", 145 | **kwargs, 146 | ): 147 | boxes, phrases = condition.boxes, condition.phrases 148 | 149 | assert len(boxes) == len(phrases), f"{len(boxes)} != {len(phrases)}" 150 | 151 | if len(boxes) == 0: 152 | return 153 | 154 | num_frames = len(boxes[0]) 155 | 156 | boxes_frames = [] 157 | 158 | for frame_index in range(num_frames): 159 | fig = plt.figure(figsize=figsize) 160 | # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array 161 | show_boxes(condition, frame_index=frame_index, show=False, save=False, **kwargs) 162 | # If we haven't already shown or saved the plot, then we need to 163 | # draw the figure first... 164 | fig.canvas.draw() 165 | 166 | # Now we can save it to a numpy array. 167 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 168 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 169 | plt.close() 170 | 171 | boxes_frames.append(data) 172 | 173 | if show: 174 | video = imageio.mimsave( 175 | imageio.RETURN_BYTES, 176 | boxes_frames, 177 | format="gif", 178 | loop=0, 179 | duration=1000 * 1 / fps, 180 | ) 181 | from IPython.display import display, Image as IPyImage 182 | 183 | display(IPyImage(data=video, format="gif")) 184 | 185 | if save: 186 | imageio.mimsave( 187 | f"{img_dir}/{save_name}.gif", 188 | boxes_frames, 189 | format="gif", 190 | loop=0, 191 | duration=1000 * 1 / fps, 192 | ) 193 | if ind is not None: 194 | shutil.copy( 195 | f"{img_dir}/{save_name}.gif", f"{img_dir}/{save_name}_{ind}.gif" 196 | ) 197 | print(f'Saved to "{img_dir}/{save_name}.gif"', f"ind: {ind}") 198 | 199 | if save_each_frame: 200 | os.makedirs(f"{img_dir}/{save_name}", exist_ok=True) 201 | for frame_ind, frame in enumerate(boxes_frames): 202 | imageio.imsave( 203 | f"{img_dir}/{save_name}/{frame_ind}.png", frame, format="png" 204 | ) 205 | print(f'Saved frames to "{img_dir}/{save_name}"', f"ind: {ind}") 206 | 207 | 208 | def show_masks(masks): 209 | masks_to_show = np.zeros((*size, 3), dtype=np.float32) 210 | for mask in masks: 211 | c = np.random.random((3,)) * 0.6 + 0.4 212 | 213 | masks_to_show += mask[..., None] * c[None, None, :] 214 | plt.imshow(masks_to_show) 215 | plt.savefig(f"{img_dir}/masks.png") 216 | plt.show() 217 | plt.close() 218 | 219 | 220 | def convert_box(box, height, width): 221 | # box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max 222 | x_min, y_min = box[0] / width, box[1] / height 223 | w_box, h_box = box[2] / width, box[3] / height 224 | 225 | x_max, y_max = x_min + w_box, y_min + h_box 226 | 227 | return x_min, y_min, x_max, y_max 228 | 229 | 230 | def interpolate_box(box, num_input_frames=6, num_output_frames=24, repeat=1): 231 | output_boxes = np.zeros((num_output_frames, 4)) 232 | box_time_indices = np.sort(list(box.keys())) 233 | xs = np.concatenate( 234 | [box_time_indices / (num_input_frames - 1) + i for i in range(repeat)] 235 | ) 236 | # The subtraction is to prevent the boundary effect with modulus. 237 | xs_query = np.linspace(0, repeat - 1e-5, num_output_frames) 238 | mask = np.isin(np.floor((xs_query % 1.0) * num_input_frames), box_time_indices) 239 | 240 | # 4: x_min, y_min, x_max, y_max 241 | for i in range(4): 242 | ys = np.array( 243 | [box[box_time_index][i] for box_time_index in box_time_indices] * repeat 244 | ) 245 | # If the mask is False (the object does not exist in this timestep, the box has all items 0) 246 | output_boxes[:, i] = np.interp(xs_query, xs, ys) * mask 247 | 248 | return output_boxes.tolist() 249 | 250 | 251 | def parsed_layout_to_condition( 252 | parsed_layout, 253 | height, 254 | width, 255 | num_parsed_layout_frames=6, 256 | num_condition_frames=24, 257 | interpolate_boxes=True, 258 | tokenizer=None, 259 | output_phrase_per_timestep=False, 260 | add_background_to_prompt=True, 261 | strip_phrases=False, 262 | verbose=False, 263 | ): 264 | """ 265 | Infer condition from parsed layout. 266 | Boxes can appear or disappear. 267 | """ 268 | 269 | prompt = parsed_layout["Prompt"] 270 | 271 | if add_background_to_prompt and parsed_layout["Background keyword"]: 272 | prompt += f", {parsed_layout['Background keyword']} background" 273 | 274 | id_to_phrase, id_to_box = {}, {} 275 | 276 | box_ids = [] 277 | 278 | for frame_ind in range(num_parsed_layout_frames): 279 | object_dicts = parsed_layout[f"Frame {frame_ind + 1}"] 280 | for object_dict in object_dicts: 281 | current_box_id = object_dict["id"] 282 | if current_box_id not in id_to_phrase: 283 | if output_phrase_per_timestep: 284 | # Only the phrase at the first occurrence is used if `output_phrase_per_timestep` is False 285 | id_to_phrase[current_box_id] = {} 286 | else: 287 | id_to_phrase[current_box_id] = ( 288 | object_dict["name"] 289 | if "name" in object_dict 290 | else object_dict["keyword"] 291 | ) 292 | 293 | # Use `dict` to handle appearance and disappearance of objects 294 | id_to_box[current_box_id] = {} 295 | 296 | box_ids.append(current_box_id) 297 | 298 | box = object_dict["box"] 299 | converted_box = convert_box(box, height=height, width=width) 300 | id_to_box[current_box_id][frame_ind] = converted_box 301 | 302 | if output_phrase_per_timestep: 303 | id_to_phrase[current_box_id][frame_ind] = ( 304 | object_dict["name"] 305 | if "name" in object_dict 306 | else object_dict["keyword"] 307 | ) 308 | 309 | boxes = [id_to_box[box_id] for box_id in box_ids] 310 | phrases = [id_to_phrase[box_id] for box_id in box_ids] 311 | 312 | if verbose: 313 | boxes_before_interpolation = boxes 314 | 315 | # Frames in interpolated boxes are consecutive, but some boxes may have all coordinates as 0 to indicate disappearance 316 | if interpolate_boxes: 317 | assert ( 318 | not output_phrase_per_timestep 319 | ), "box interpolation with phrase per timestep is not implemented" 320 | boxes = [ 321 | interpolate_box( 322 | box, 323 | num_parsed_layout_frames, 324 | num_condition_frames, 325 | repeat=parsed_layout.get("Repeat", 1), 326 | ) 327 | for box in boxes 328 | ] 329 | 330 | if tokenizer is not None: 331 | for phrase in phrases: 332 | found, _ = guidance.refine_phrase(prompt, phrase, verbose=True) 333 | 334 | if not found: 335 | # Suffix the prompt with object name (before the refinement) for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix 336 | prompt += "| " + phrase 337 | 338 | print(f'**Adding {phrase} to the prompt. Using prompt: "{prompt}"') 339 | 340 | # `phrases` might not correspond to the first occurrence in the prompt, which is not handled now. 341 | token_map = guidance.get_token_map( 342 | tokenizer, prompt=prompt, verbose=verbose, padding="do_not_pad" 343 | ) 344 | object_positions = guidance.get_phrase_indices( 345 | tokenizer, prompt, phrases, token_map=token_map, verbose=verbose 346 | ) 347 | else: 348 | token_map = None 349 | object_positions = None 350 | 351 | if verbose: 352 | print("prompt:", prompt) 353 | print("boxes (before interpolation):", boxes_before_interpolation) 354 | if verbose >= 2: 355 | print("boxes (after interpolation):", np.round(np.array(boxes), 2)) 356 | print("phrases:", phrases) 357 | if object_positions is not None: 358 | print("object_positions:", object_positions) 359 | 360 | if strip_phrases: 361 | phrases = [phrase.strip("1234567890 ") for phrase in phrases] 362 | 363 | return Condition(prompt, boxes, phrases, object_positions, token_map) 364 | -------------------------------------------------------------------------------- /utils/schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | 5 | def get_fast_schedule(origial_timesteps, fast_after_steps, fast_rate): 6 | if fast_after_steps >= len(origial_timesteps) - 1: 7 | return origial_timesteps 8 | new_timesteps = torch.cat( 9 | ( 10 | origial_timesteps[:fast_after_steps], 11 | origial_timesteps[fast_after_steps + 1 :: fast_rate], 12 | ), 13 | dim=0, 14 | ) 15 | return new_timesteps 16 | 17 | 18 | def dynamically_adjust_inference_steps(scheduler, index, t): 19 | prev_t = ( 20 | scheduler.timesteps[index + 1] if index + 1 < len(scheduler.timesteps) else -1 21 | ) 22 | scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t) 23 | if index + 1 < len(scheduler.timesteps): 24 | if ( 25 | scheduler.config.num_train_timesteps // scheduler.num_inference_steps 26 | != t - prev_t 27 | ): 28 | warnings.warn( 29 | f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate" 30 | ) 31 | else: 32 | # as long as we hit final cumprob, it should be fine. 33 | if ( 34 | scheduler.config.num_train_timesteps // scheduler.num_inference_steps 35 | > t - prev_t 36 | ): 37 | warnings.warn( 38 | f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate" 39 | ) 40 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import ImageDraw 3 | import numpy as np 4 | import os 5 | import gc 6 | import math 7 | from typing import List 8 | import cv2 9 | import skvideo.io 10 | 11 | torch_device = "cuda" 12 | 13 | 14 | def draw_box(pil_img, bboxes, phrases, ignore_all_zeros=True): 15 | W, H = pil_img.size 16 | draw = ImageDraw.Draw(pil_img) 17 | 18 | for obj_bbox, phrase in zip(bboxes, phrases): 19 | x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3] 20 | if ignore_all_zeros and x_0 == 0 and y_0 == 0 and x_1 == 0 and y_1 == 0: 21 | continue 22 | draw.rectangle( 23 | [int(x_0 * W), int(y_0 * H), int(x_1 * W), int(y_1 * H)], 24 | outline="red", 25 | width=5, 26 | ) 27 | draw.text( 28 | (int(x_0 * W) + 5, int(y_0 * H) + 5), phrase, font=None, fill=(255, 0, 0) 29 | ) 30 | 31 | return pil_img 32 | 33 | 34 | def get_centered_box( 35 | box, 36 | horizontal_center_only=True, 37 | vertical_placement="centered", 38 | vertical_center=0.5, 39 | floor_padding=None, 40 | ): 41 | x_min, y_min, x_max, y_max = box 42 | w = x_max - x_min 43 | 44 | x_min_new = 0.5 - w / 2 45 | x_max_new = 0.5 + w / 2 46 | 47 | if horizontal_center_only: 48 | return [x_min_new, y_min, x_max_new, y_max] 49 | 50 | h = y_max - y_min 51 | 52 | if vertical_placement == "centered": 53 | assert ( 54 | floor_padding is None 55 | ), "Set vertical_placement to floor_padding to use floor padding" 56 | 57 | y_min_new = vertical_center - h / 2 58 | y_max_new = vertical_center + h / 2 59 | elif vertical_placement == "floor_padding": 60 | # Ignores `vertical_center` 61 | 62 | y_max_new = 1 - floor_padding 63 | y_min_new = y_max_new - h 64 | else: 65 | raise ValueError(f"Unknown vertical placement: {vertical_placement}") 66 | 67 | return [x_min_new, y_min_new, x_max_new, y_max_new] 68 | 69 | 70 | # NOTE: this changes the behavior of the function 71 | def proportion_to_mask(obj_box, H, W, use_legacy=False, return_np=False): 72 | x_min, y_min, x_max, y_max = scale_proportion(obj_box, H, W, use_legacy) 73 | if return_np: 74 | mask = np.zeros((H, W)) 75 | else: 76 | mask = torch.zeros(H, W).to(torch_device) 77 | mask[y_min:y_max, x_min:x_max] = 1.0 78 | 79 | return mask 80 | 81 | 82 | def scale_proportion(obj_box, H, W, use_legacy=False): 83 | if use_legacy: 84 | # Bias towards the top-left corner 85 | x_min, y_min, x_max, y_max = ( 86 | int(obj_box[0] * W), 87 | int(obj_box[1] * H), 88 | int(obj_box[2] * W), 89 | int(obj_box[3] * H), 90 | ) 91 | else: 92 | # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5". 93 | x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H) 94 | box_w, box_h = ( 95 | round((obj_box[2] - obj_box[0]) * W), 96 | round((obj_box[3] - obj_box[1]) * H), 97 | ) 98 | x_max, y_max = x_min + box_w, y_min + box_h 99 | 100 | x_min, y_min = max(x_min, 0), max(y_min, 0) 101 | x_max, y_max = min(x_max, W), min(y_max, H) 102 | 103 | return x_min, y_min, x_max, y_max 104 | 105 | 106 | def binary_mask_to_box(mask, enlarge_box_by_one=True, w_scale=1, h_scale=1): 107 | if isinstance(mask, torch.Tensor): 108 | mask_loc = torch.where(mask) 109 | else: 110 | mask_loc = np.where(mask) 111 | height, width = mask.shape 112 | if len(mask_loc) == 0: 113 | raise ValueError("The mask is empty") 114 | if enlarge_box_by_one: 115 | ymin, ymax = max(min(mask_loc[0]) - 1, 0), min(max(mask_loc[0]) + 1, height) 116 | xmin, xmax = max(min(mask_loc[1]) - 1, 0), min(max(mask_loc[1]) + 1, width) 117 | else: 118 | ymin, ymax = min(mask_loc[0]), max(mask_loc[0]) 119 | xmin, xmax = min(mask_loc[1]), max(mask_loc[1]) 120 | box = [xmin * w_scale, ymin * h_scale, xmax * w_scale, ymax * h_scale] 121 | 122 | return box 123 | 124 | 125 | def binary_mask_to_box_mask(mask, to_device=True): 126 | box = binary_mask_to_box(mask) 127 | x_min, y_min, x_max, y_max = box 128 | 129 | H, W = mask.shape 130 | mask = torch.zeros(H, W) 131 | if to_device: 132 | mask = mask.to(torch_device) 133 | mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0 134 | 135 | return mask 136 | 137 | 138 | def binary_mask_to_center(mask, normalize=False): 139 | """ 140 | This computes the mass center of the mask. 141 | normalize: the coords range from 0 to 1 142 | 143 | Reference: https://stackoverflow.com/a/66184125 144 | """ 145 | h, w = mask.shape 146 | 147 | total = mask.sum() 148 | if isinstance(mask, torch.Tensor): 149 | x_coord = ((mask.sum(dim=0) @ torch.arange(w)) / total).item() 150 | y_coord = ((mask.sum(dim=1) @ torch.arange(h)) / total).item() 151 | else: 152 | x_coord = (mask.sum(axis=0) @ np.arange(w)) / total 153 | y_coord = (mask.sum(axis=1) @ np.arange(h)) / total 154 | 155 | if normalize: 156 | x_coord, y_coord = x_coord / w, y_coord / h 157 | return x_coord, y_coord 158 | 159 | 160 | def iou(mask, masks, eps=1e-6): 161 | # mask: [h, w], masks: [n, h, w] 162 | mask = mask[None].astype(bool) 163 | masks = masks.astype(bool) 164 | i = (mask & masks).sum(axis=(1, 2)) 165 | u = (mask | masks).sum(axis=(1, 2)) 166 | 167 | return i / (u + eps) 168 | 169 | 170 | def free_memory(): 171 | gc.collect() 172 | torch.cuda.empty_cache() 173 | 174 | 175 | def expand_overall_bboxes(overall_bboxes): 176 | """ 177 | Expand overall bboxes from a 3d list to 2d list: 178 | Input: [[box 1 for phrase 1, box 2 for phrase 1], ...] 179 | Output: [box 1, box 2, ...] 180 | """ 181 | return sum(overall_bboxes, start=[]) 182 | 183 | 184 | def shift_tensor( 185 | tensor, 186 | x_offset, 187 | y_offset, 188 | base_w=8, 189 | base_h=8, 190 | offset_normalized=False, 191 | ignore_last_dim=False, 192 | ): 193 | """base_w and base_h: make sure the shift is aligned in the latent and multiple levels of cross attention""" 194 | if ignore_last_dim: 195 | tensor_h, tensor_w = tensor.shape[-3:-1] 196 | else: 197 | tensor_h, tensor_w = tensor.shape[-2:] 198 | if offset_normalized: 199 | assert ( 200 | tensor_h % base_h == 0 and tensor_w % base_w == 0 201 | ), f"{tensor_h, tensor_w} is not a multiple of {base_h, base_w}" 202 | scale_from_base_h, scale_from_base_w = tensor_h // base_h, tensor_w // base_w 203 | x_offset, y_offset = ( 204 | round(x_offset * base_w) * scale_from_base_w, 205 | round(y_offset * base_h) * scale_from_base_h, 206 | ) 207 | new_tensor = torch.zeros_like(tensor) 208 | 209 | overlap_w = tensor_w - abs(x_offset) 210 | overlap_h = tensor_h - abs(y_offset) 211 | 212 | if y_offset >= 0: 213 | y_src_start = 0 214 | y_dest_start = y_offset 215 | else: 216 | y_src_start = -y_offset 217 | y_dest_start = 0 218 | 219 | if x_offset >= 0: 220 | x_src_start = 0 221 | x_dest_start = x_offset 222 | else: 223 | x_src_start = -x_offset 224 | x_dest_start = 0 225 | 226 | if ignore_last_dim: 227 | # For cross attention maps, the third to last and the second to last are the 2D dimensions after unflatten. 228 | new_tensor[ 229 | ..., 230 | y_dest_start : y_dest_start + overlap_h, 231 | x_dest_start : x_dest_start + overlap_w, 232 | :, 233 | ] = tensor[ 234 | ..., 235 | y_src_start : y_src_start + overlap_h, 236 | x_src_start : x_src_start + overlap_w, 237 | :, 238 | ] 239 | else: 240 | new_tensor[ 241 | ..., 242 | y_dest_start : y_dest_start + overlap_h, 243 | x_dest_start : x_dest_start + overlap_w, 244 | ] = tensor[ 245 | ..., 246 | y_src_start : y_src_start + overlap_h, 247 | x_src_start : x_src_start + overlap_w, 248 | ] 249 | 250 | return new_tensor 251 | 252 | 253 | def get_hw_from_attn_dim(attn_dim, base_attn_dim): 254 | # base_attn_dim: (40, 72) for zeroscope (width 576, height 320) 255 | scale = int(math.sqrt((base_attn_dim[0] * base_attn_dim[1]) / attn_dim)) 256 | return base_attn_dim[0] // scale, base_attn_dim[1] // scale 257 | 258 | 259 | # Reference: https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/utils/testing_utils.py#L400 260 | def export_to_video( 261 | video_frames: List[np.ndarray], 262 | output_video_path: str, 263 | fps: int = 8, 264 | fourcc: str = "mp4v", 265 | use_opencv=False, 266 | crf=17, 267 | ) -> str: 268 | if use_opencv: 269 | # This requires a cv2 installation that has video encoder support. 270 | 271 | fourcc = cv2.VideoWriter_fourcc(*fourcc) 272 | h, w, c = video_frames[0].shape 273 | video_writer = cv2.VideoWriter( 274 | output_video_path, fourcc, fps=fps, frameSize=(w, h) 275 | ) 276 | for i in range(len(video_frames)): 277 | img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) 278 | video_writer.write(img) 279 | else: 280 | skvideo.io.vwrite( 281 | output_video_path, 282 | video_frames, 283 | inputdict={"-framerate": str(fps)}, 284 | outputdict={"-vcodec": "libx264", "-pix_fmt": "yuv420p", "-crf": str(crf)}, 285 | ) 286 | return output_video_path 287 | 288 | 289 | def multiline_input(prompt, return_on_empty_lines=True): 290 | # Adapted from https://stackoverflow.com/questions/30239092/how-to-get-multiline-input-from-the-user 291 | 292 | print(prompt, end="", flush=True) 293 | contents = "" 294 | while True: 295 | try: 296 | line = input() 297 | except EOFError: 298 | break 299 | if line == "" and return_on_empty_lines: 300 | break 301 | contents += line + "\n" 302 | 303 | return contents 304 | 305 | 306 | def find_gen_dir(gen_name, create_dir=True): 307 | base_save_dir = f"img_generations/{gen_name}" 308 | run_ind = 0 309 | 310 | while True: 311 | gen_dir = f"{base_save_dir}/run{run_ind}" 312 | if not os.path.exists(gen_dir): 313 | break 314 | run_ind += 1 315 | 316 | print(f"Save results at {gen_dir}") 317 | if create_dir: 318 | os.makedirs(gen_dir, exist_ok=False) 319 | 320 | return gen_dir 321 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import utils 4 | from . import parse 5 | import imageio 6 | import joblib 7 | 8 | 9 | def visualize(image, title, colorbar=False, show_plot=True, **kwargs): 10 | plt.title(title) 11 | plt.imshow(image, **kwargs) 12 | if colorbar: 13 | plt.colorbar() 14 | if show_plot: 15 | plt.show() 16 | 17 | 18 | def visualize_arrays( 19 | image_title_pairs, 20 | colorbar_index=-1, 21 | show_plot=True, 22 | figsize=None, 23 | no_axis=False, 24 | **kwargs, 25 | ): 26 | if figsize is not None: 27 | plt.figure(figsize=figsize) 28 | num_subplots = len(image_title_pairs) 29 | for idx, image_title_pair in enumerate(image_title_pairs): 30 | plt.subplot(1, num_subplots, idx + 1) 31 | if isinstance(image_title_pair, (list, tuple)): 32 | image, title = image_title_pair 33 | else: 34 | image, title = image_title_pair, None 35 | 36 | if title is not None: 37 | plt.title(title) 38 | 39 | plt.imshow(image, **kwargs) 40 | if no_axis: 41 | plt.axis("off") 42 | if idx == colorbar_index: 43 | plt.colorbar() 44 | 45 | if show_plot: 46 | plt.show() 47 | 48 | 49 | def visualize_masked_latents( 50 | latents_all, masked_latents, timestep_T=False, timestep_0=True 51 | ): 52 | if timestep_T: 53 | # from T to 0 54 | latent_idx = 0 55 | 56 | plt.subplot(1, 2, 1) 57 | plt.title("latents_all (t=T)") 58 | plt.imshow( 59 | ( 60 | latents_all[latent_idx, 0, :3] 61 | .cpu() 62 | .permute(1, 2, 0) 63 | .numpy() 64 | .astype(float) 65 | / 1.5 66 | ).clip(0.0, 1.0), 67 | cmap="gray", 68 | ) 69 | 70 | plt.subplot(1, 2, 2) 71 | plt.title("mask latents (t=T)") 72 | plt.imshow( 73 | ( 74 | masked_latents[latent_idx, 0, :3] 75 | .cpu() 76 | .permute(1, 2, 0) 77 | .numpy() 78 | .astype(float) 79 | / 1.5 80 | ).clip(0.0, 1.0), 81 | cmap="gray", 82 | ) 83 | 84 | plt.show() 85 | 86 | if timestep_0: 87 | latent_idx = -1 88 | plt.subplot(1, 2, 1) 89 | plt.title("latents_all (t=0)") 90 | plt.imshow( 91 | ( 92 | latents_all[latent_idx, 0, :3] 93 | .cpu() 94 | .permute(1, 2, 0) 95 | .numpy() 96 | .astype(float) 97 | / 1.5 98 | ).clip(0.0, 1.0), 99 | cmap="gray", 100 | ) 101 | 102 | plt.subplot(1, 2, 2) 103 | plt.title("mask latents (t=0)") 104 | plt.imshow( 105 | ( 106 | masked_latents[latent_idx, 0, :3] 107 | .cpu() 108 | .permute(1, 2, 0) 109 | .numpy() 110 | .astype(float) 111 | / 1.5 112 | ).clip(0.0, 1.0), 113 | cmap="gray", 114 | ) 115 | 116 | plt.show() 117 | 118 | 119 | def visualize_bboxes(bboxes, H, W): 120 | num_boxes = len(bboxes) 121 | for ind, bbox in enumerate(bboxes): 122 | plt.subplot(1, num_boxes, ind + 1) 123 | fg_mask = utils.proportion_to_mask(bbox, H, W) 124 | plt.title(f"transformed bbox ({ind})") 125 | plt.imshow(fg_mask.cpu().numpy()) 126 | plt.show() 127 | 128 | 129 | def save_image(image, save_prefix="", ind=None): 130 | global save_ind 131 | if save_prefix != "": 132 | save_prefix = save_prefix + "_" 133 | ind = f"{ind}_" if ind is not None else "" 134 | path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png" 135 | 136 | print(f"Saved to {path}") 137 | 138 | image.save(path) 139 | save_ind = save_ind + 1 140 | 141 | 142 | def save_frames(path, frames, formats="gif", fps=8): 143 | if isinstance(formats, (list, tuple)): 144 | for format in formats: 145 | save_frames(path, frames, format, fps) 146 | return 147 | 148 | if formats == "gif": 149 | imageio.mimsave( 150 | f"{path}.gif", frames, format="gif", loop=0, duration=1000 * 1 / fps 151 | ) 152 | elif formats == "mp4": 153 | utils.export_to_video( 154 | video_frames=frames, output_video_path=f"{path}.mp4", fps=fps 155 | ) 156 | elif formats == "npz": 157 | np.savez_compressed(f"{path}.npz", frames) 158 | elif formats == "joblib": 159 | joblib.dump(frames, f"{path}.joblib", compress=("bz2", 3)) 160 | else: 161 | raise ValueError(f"Unknown format: {formats}") 162 | --------------------------------------------------------------------------------