├── .gitignore
├── README.md
├── assets
│   ├── example_video_lvd_gligen_zeroscope.gif
│   ├── example_video_lvd_zeroscope.gif
│   └── example_video_zeroscope_baseline.gif
├── cache
│   ├── cache_demo_v0.1_gpt-4-1106-preview.json
│   ├── cache_lvd_v0.1_gpt-3.5-turbo.json
│   └── cache_lvd_v0.1_gpt-4-1106-preview.json
├── generate.py
├── generation
│   ├── lvd.py
│   ├── lvd_gligen.py
│   ├── lvd_plus.py
│   ├── modelscope_dpm.py
│   └── zeroscope_dpm.py
├── models
│   ├── __init__.py
│   ├── attention.py
│   ├── attention_processor.py
│   ├── controllable_pipeline_text_to_video_synth.py
│   ├── models.py
│   ├── pipelines.py
│   ├── transformer_2d.py
│   ├── transformer_temporal.py
│   ├── unet_2d_blocks.py
│   ├── unet_2d_condition.py
│   ├── unet_3d_blocks.py
│   └── unet_3d_condition.py
├── prompt.py
├── prompt_batch.py
├── requirements.txt
├── scripts
│   ├── eval_owl_vit.py
│   ├── eval_stage_one.py
│   └── upsample.py
└── utils
    ├── __init__.py
    ├── attn.py
    ├── cache.py
    ├── eval
    │   ├── __init__.py
    │   ├── eval.py
    │   ├── lvd.py
    │   └── utils.py
    ├── guidance.py
    ├── latents.py
    ├── llm.py
    ├── parse.py
    ├── schedule.py
    ├── utils.py
    └── vis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | api_key.py
163 | *_ignored*
164 |
165 | img_generations
166 |
167 | .ruff_cache
168 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM-grounded Video Diffusion Models
2 | [Long Lian](https://tonylian.com/), [Baifeng Shi](https://bfshi.github.io/), [Adam Yala](https://www.adamyala.org/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Boyi Li](https://sites.google.com/site/boyilics/home) at UC Berkeley/UCSF.
3 |
4 | ***International Conference on Learning Representations (ICLR) 2024***
5 |
6 | [Paper PDF](https://openreview.net/pdf?id=exKHibougU) | [Arxiv](https://arxiv.org/abs/2309.17444) | [Project Page](https://llm-grounded-video-diffusion.github.io/) | [Related Project: LMD](https://llm-grounded-diffusion.github.io/) | [Citation](#citation)
7 |
8 | **TL;DR**: Text Prompt -> LLM Spatiotemporal Planner -> Dynamic Scene Layouts (DSLs) -> DSL-grounded Video Generator -> Video. This gives you much better **prompt understanding capabilities** for text-to-video generation.
9 |
10 | 
11 |
12 | ## Updates
13 | **[2024.4]** Several GLIGEN adapters for image and video generation have been trained with the IGLIGEN project, which offers a modern GLIGEN training codebase! These adapters can be directly plugged into our DSL-grounded video generator and are compatible with the cross-attention control introduced in our paper. We offer a Hugging Face diffusers integration that works off-the-shelf (**no repo clone or additional packages needed**)! Check out [this demo colab notebook](https://colab.research.google.com/drive/17He4bFAF8lXmT9Nfv-Sg29iKtPelDUNZ) for box-conditioned Modelscope with IGLIGEN adapters. See [here](#example-videos) for more details and comparisons with this repo.
14 |
15 | # LLM-grounded Video Diffusion Models (LVD)
16 | The codebase is based on the [LLM-grounded Diffusion (LMD)](https://github.com/TonyLianLong/LLM-groundedDiffusion) repo, so please also check out the instructions and FAQs there if your issue is not covered in this repo.
17 |
18 | ## Installation
19 | Install the dependencies:
20 | ```
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## Stage 1: Generating Dynamic Scene Layouts (DSLs) from text
25 | **Note that we have uploaded the layout caches for the benchmark to this repo, so you can skip this step if you don't need layouts for new prompts (i.e., if you just want to use LLM-generated layouts to benchmark stage 2).**
26 |
27 | Since we have cached the layout generation (the cache is downloaded when you clone the repo), **you need to remove the corresponding file in the `cache` directory if you want to re-generate the layout for the same prompts**.
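
For example, to re-generate layouts for the demo prompts with GPT-4, you would remove the corresponding cache file shipped in this repo:
```shell
rm cache/cache_demo_v0.1_gpt-4-1106-preview.json
```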
28 |
29 | **Our layout generation format:** The LLM takes in a text prompt describing the video and outputs dynamic scene layouts that consist of three elements: **1.** a reasoning statement for analysis, **2.** a captioned box for each object in each frame, and **3.** a background prompt. The template and example prompts are in [prompt.py](prompt.py). You can edit the template, the example prompts, and the parsing function to ask the LLM to generate additional things or even perform chain-of-thought for better generation.
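
For reference, the cached GPT-4 response for the demo prompt "A bear walks from the left to the right" (from `cache/cache_demo_v0.1_gpt-4-1106-preview.json`) follows this format (abridged here):
```
Reasoning: A bear walking from left to right will have its x-coordinate increasing with each frame, ...
Frame 1: [{'id': 0, 'name': 'bear', 'box': [0, 256, 100, 100]}]
Frame 2: [{'id': 0, 'name': 'bear', 'box': [85, 256, 100, 100]}]
...
Frame 6: [{'id': 0, 'name': 'bear', 'box': [425, 256, 100, 100]}]
Background keyword: forest
```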
30 |
31 | ### Automated query from OpenAI API
32 | Again, if you just want to evaluate stage 2 (the layout-to-video stage), you can skip stage 1, as we have uploaded the layout caches to this repo. **You don't need an OpenAI API key in stage 2.**
33 |
34 | If you have an [OpenAI API key](https://openai.com/blog/openai-api), you can put the API key in `utils/api_key.py` or set the `OPENAI_API_KEY` environment variable.
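For example (the key value below is a placeholder):
```shell
export OPENAI_API_KEY="sk-..."
```

Then you can use OpenAI's API for batch text-to-layout generation by querying an LLM, with GPT-4 as an example: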
35 | ```shell
36 | python prompt_batch.py --prompt-type demo --model gpt-4-1106-preview --auto-query --always-save --template_version v0.1
37 | ```
38 | `--prompt-type demo` includes a few prompts for demonstration. You can change them in [prompt.py](prompt.py). The layout generation is cached, so the LLM is not queried again for the same prompt (which lowers the cost).
39 |
40 | You can visualize the dynamic scene layouts in the form of bounding boxes in `img_generations/imgs_demo_templatev0.1`. They are saved as `gif` files. For horizontal video generation with zeroscope, the square layout will be scaled according to the video aspect ratio.
41 |
42 | ### Run our benchmark on text-to-layout generation evaluation
43 | We provide a benchmark that applies to both stage 1 and stage 2. This benchmark includes a set of prompts with five tasks (numeracy, attribution, visibility, dynamic spatial, and sequential) as well as unified benchmarking code for all implemented methods and both stages.
44 |
45 | This will generate layouts from the prompts in the benchmark (with `--prompt-type lvd`) and evaluate the results:
46 | ```shell
47 | python prompt_batch.py --prompt-type lvd --model gpt-4-1106-preview --auto-query --always-save --template_version v0.1
48 | python scripts/eval_stage_one.py --prompt-type lvd --model gpt-4-1106-preview --template_version v0.1
49 | ```
50 |
51 | Our reference benchmark results (stage 1, evaluating the generated layouts only)
52 |
53 | | Method | Numeracy | Attribution | Visibility | Dynamics | Sequential | Overall |
54 | | -------- | -------- | ----------- | ---------- | -------- | ---------- | ---------- |
55 | | GPT-3.5 | 100 | 100 | 100 | 71 | 16 | 77% |
56 | | GPT-3.5* | 100 | 100 | 100 | 73 | 15 | 78% |
57 | | GPT-4 | 100 | 100 | 100 | 100 | 88 | **98%** |
58 |
59 | \* The generated cache in this repo comes from this rerun. It differs marginally (1%) from the run that we reported.
60 |
61 |
62 | ## Stage 2: Generating Videos from Dynamic Scene Layouts
63 | Note that since we provide caches for stage 1, you don't need to run stage 1 yourself for the cached prompts (i.e., you don't need an OpenAI API key or to query an LLM).
64 |
65 | Run layout-to-video generation using the gpt-4 cache and **LVD with Zeroscope** (resolution 576x320, 24 frames\*):
66 | ```shell
67 | # Zeroscope (horizontal videos)
68 | python generate.py --model gpt-4-1106-preview --run-model lvd_zeroscope --prompt-type demo --save-suffix weak_guidance --template_version v0.1 --seed_offset 0 --repeats 10 --loss_scale 2.5 --loss_threshold 350. --max_iter 1 --max_index_step 10 --fg_top_p 0.25 --bg_top_p 0.25 --fg_weight 1.0 --bg_weight 2.0 --num_frames 24
69 | ```
70 |
71 | Run video generation with **LVD with Modelscope** (resolution 256x256, 16 frames\*):
72 | ```shell
73 | # Modelscope (square videos)
74 | python generate.py --model gpt-4-1106-preview --run-model lvd_modelscope256 --prompt-type demo --save-suffix weak_guidance --template_version v0.1 --seed_offset 0 --repeats 10 --loss_scale 2.5 --loss_threshold 250. --max_iter 1 --max_index_step 10 --fg_top_p 0.25 --bg_top_p 0.25 --fg_weight 1.0 --bg_weight 2.0 --num_frames 16
75 | ```
76 |
77 | \* For context, [Zeroscope](https://huggingface.co/cerspense/zeroscope_v2_576w) is a model fine-tuned from [Modelscope](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) that generates horizontal videos without a watermark. In contrast, Modelscope generates square videos, and its generated videos often come with a watermark.
78 |
79 | ### Update: LVD with GLIGEN adapters for lower inference costs
80 | Similar to LMD+, you can also integrate GLIGEN adapters trained in the [IGLIGEN project](https://github.com/TonyLianLong/igligen) with Modelscope in stage 2. Using GLIGEN adapters on Modelscope **requires less memory and is faster than training-free cross-attention control**. It can run on a T4 GPU on Colab. To use the IGLIGEN Modelscope/Zeroscope adapters as our stage 2, you can use these commands:
81 | ```shell
82 | # Zeroscope (horizontal videos)
83 | python generate.py --model gpt-4-1106-preview --run-model lvd-gligen_zeroscope --prompt-type demo --save-suffix lvd_gligen --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 24 --gligen_scheduled_sampling_beta 0.4
84 |
85 | # Modelscope (square videos)
86 | python generate.py --model gpt-4-1106-preview --run-model lvd-gligen_modelscope256 --prompt-type demo --save-suffix lvd_gligen --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 16 --gligen_scheduled_sampling_beta 0.4
87 | ```
88 |
89 | Training-based methods such as GLIGEN typically have better spatial control, but they can sometimes interpret words differently from the base diffusion model or offer limited diversity for rare objects. **We recommend trying out both this variant and the original cross-attention-based control variant to get the best of both.**
90 |
91 | ### Baselines
92 |
93 | Run the baselines to compare the generated videos:
94 |
95 | Run Zeroscope baseline:
96 | ```shell
97 | python generate.py --model gpt-4-1106-preview --run-model zeroscope --prompt-type demo --save-suffix baseline --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 24
98 | ```
99 |
100 | Run Modelscope baseline:
101 | ```shell
102 | python generate.py --model gpt-4-1106-preview --run-model modelscope_256 --prompt-type demo --save-suffix baseline --template_version v0.1 --seed_offset 0 --repeats 10 --num_frames 16
103 | ```
104 |
105 |
106 | You can use `--save-suffix` to specify the suffix added to the name of the run. `--run-model` specifies the method to run; you can set it to LVD/LVD-GLIGEN or to one of the implemented baselines.
107 |
108 | **Note:** with cross-attention control, there is a tradeoff between the strength of the control and the overall visual quality. The commands above with suffix `weak_guidance` trade control strength for a lower level of artifacts. See the note in the benchmark section of this README for details.
109 |
110 | ### Saving formats
111 | The generated videos are saved in `img_generations/imgs_demo_templatev0.1_gpt-4-1106-preview_lvd_zeroscope_weak_guidance` (or other directories in `img_generations` if you run other commands). Note that the script saves the video in two formats: **1.** a `gif` file is saved for each video for quick and easy visualizations and debugging, **2.** a `joblib` file that has the original uncompressed numpy array is saved. **Note that `gif` files only support up to 256 colors, so `joblib` is recommended for future processing/export.** You can easily load with `joblib` and export to `mp4` or do other analysis/visualizations.
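
For example, here is a minimal sketch of exporting a saved video to `mp4`. It assumes that the `joblib` file stores the raw `uint8` frame array of shape `(num_frames, H, W, 3)` (as produced in `generation/lvd.py`) and that `imageio` with ffmpeg support is installed; the file paths and `fps` value are placeholders:
```python
import joblib
import imageio  # mp4 export requires the imageio-ffmpeg backend

# Placeholder path: point this to a saved video in your run directory
frames = joblib.load("path/to/video_0.joblib")

# Write the frames to an mp4 file; the fps value is a choice, not prescribed by the repo
imageio.mimsave("video_0.mp4", list(frames), fps=8)
```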
112 |
113 | ### Upsampling to get high-resolution videos
114 | This repo supports both [Zeroscope v2 XL](https://huggingface.co/cerspense/zeroscope_v2_XL) and [Stable Diffusion XL refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) upsampling.
115 |
116 | Note that SDXL performs upsampling per frame, so the result is likely to have some jittering. We recommend using the Zeroscope v2 XL upsampler to upsample videos generated by Zeroscope and LVD with Zeroscope.
117 |
118 | ```shell
119 | python scripts/upsample.py --videos path_to_joblib_video --prompts "prompt for the video" --use_zsxl --horizontal --output-mp4
120 | ```
121 |
122 | You can use `--output-mp4` to also output `mp4` video. Otherwise only `gif` and `joblib` files will be saved.
123 |
124 | ### Example videos
125 | Example videos generated with the scripts above with prompt `A bear walking from the left to the right`:
126 | | LVD (Ours, using cross-attention control) | LVD-GLIGEN (using IGLIGEN adapters) | Baseline |
127 | | ---- | -------- | ---- |
128 | |  |  |  |
129 |
130 | As you can see, our method leads to improved prompt understanding, while the baseline does not understand "walking from the left to the right" correctly.
131 |
132 | ### Run our benchmark on layout-to-video generation evaluation
133 | We use the same unified evaluation metric in stage 2 as in stage 1 (`--prompt-type lvd`). Since stage 1 produces layout boxes but stage 2 only produces videos, we use OWL-ViT to detect the objects and ensure they are generated (or not generated, in the case of negation) in the right number, with the right attributes, and in the right place.
134 |
135 | This runs generation with LVD and evaluates the generation (resolution 256x256):
136 | ```shell
137 | # Use GPT-4 layouts as an example
138 |
139 | ## Generation 1
140 | python generate.py --model "gpt-4-1106-preview" --force_run_ind 0 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 0 --com_loss_scale 0.03
141 | python scripts/eval_owl_vit.py --run_base_path ./img_generations/imgs_lvd_templatev0.1_gpt-4-1106-preview_lvd_modelscope256/run0 --skip_first_prompts 0 --verbose --detection_score_threshold 0.1 --nms_threshold 0.1 --class-aware-nms --num_eval_frames 6 --prompt-type lvd
142 |
143 | ## Generation 2
144 | python generate.py --model "gpt-4-1106-preview" --force_run_ind 1 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 500 --com_loss_scale 0.03
145 | python scripts/eval_owl_vit.py --run_base_path ./img_generations/imgs_lvd_templatev0.1_gpt-4-1106-preview_lvd_modelscope256/run1 --skip_first_prompts 0 --verbose --detection_score_threshold 0.1 --nms_threshold 0.1 --class-aware-nms --num_eval_frames 6 --prompt-type lvd
146 | ```
147 |
148 | Each generation run will lead to 500 videos. Then, please take the average of the two generation runs, **taking 1000 videos into account** (i.e., two videos per generated layout). **The script supports resuming**: you can simply rerun the same command with `--force_run_ind` set to resume generation (otherwise it will save into a new directory). You can also run these two runs in parallel. Check out `--skip_first_prompts` and `--num_prompts` in [generate.py](generate.py) to parallelize each run across GPUs, as sketched below.
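
For example, here is a sketch of splitting one generation run across two GPUs (the GPU indices and the 250/250 prompt split are illustrative, assuming the 500 benchmark prompts mentioned above):
```shell
# GPU 0: first 250 prompts
CUDA_VISIBLE_DEVICES=0 python generate.py --model "gpt-4-1106-preview" --force_run_ind 0 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 0 --com_loss_scale 0.03 --skip_first_prompts 0 --num_prompts 250 &
# GPU 1: remaining 250 prompts
CUDA_VISIBLE_DEVICES=1 python generate.py --model "gpt-4-1106-preview" --force_run_ind 0 --run-model lvd_modelscope256 --prompt-type lvd --template_version v0.1 --seed_offset 0 --com_loss_scale 0.03 --skip_first_prompts 250 --num_prompts 250 &
```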
149 |
150 | **Note:** with cross-attention control, there is a tradeoff between the strength of the control and the overall visual quality. The hyperparams used in benchmark generation are tuned for the benchmarks, and we observe quality degradation for some of the generated videos, even though the correct objects are generated. We recommend using lower guidance strength for non-benchmarking purposes: even if the objects do not always align closely to the layout, the composition often turns out to be more natural. Similarly, for LVD with IGLIGEN adapters, `gligen_scheduled_sampling_beta` controls the strength of the conditioning.
151 |
152 | ##### Our reference benchmark results
153 | | Method | Numeracy | Attribution | Visibility | Dynamics | Sequential | Overall |
154 | | ------------- | -------- | ----------- | ---------- | -------- | ---------- | ---------- |
155 | | ModelScope | 32 | 54 | 8 | 21 | 0 | 23.0% |
156 | | LVD (GPT-3.5) | 52 | 79 | 64 | 37 | 2 | 46.4% |
157 | | LVD (GPT-4) | 41 | 64 | 55 | 51 | 38 | **49.4%** |
158 |
159 | ## Contact us
160 | Please contact Long (Tony) Lian if you have any questions: `longlian@berkeley.edu`.
161 |
162 | ## Acknowledgements
163 | This repo is based on [LMD repo](https://github.com/TonyLianLong/LLM-groundedDiffusion), which is based on [diffusers](https://huggingface.co/docs/diffusers/index) and references [GLIGEN](https://github.com/gligen/GLIGEN), [layout-guidance](https://github.com/silent-chen/layout-guidance). This repo uses the same license as LMD.
164 |
165 | ## Citation
166 | If you use our work or our implementation in this repo, or find them helpful, please consider giving a citation.
167 | ```
168 | @inproceedings{lian2023llmgroundedvideo,
169 | title={LLM-grounded Video Diffusion Models},
170 | author={Lian, Long and Shi, Baifeng and Yala, Adam and Darrell, Trevor and Li, Boyi},
171 | booktitle={The Twelfth International Conference on Learning Representations},
172 | year={2023}
173 | }
174 |
175 | @article{lian2023llmgrounded,
176 | title={LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models},
177 | author={Lian, Long and Li, Boyi and Yala, Adam and Darrell, Trevor},
178 | journal={arXiv preprint arXiv:2305.13655},
179 | year={2023}
180 | }
181 | ```
182 |
--------------------------------------------------------------------------------
/assets/example_video_lvd_gligen_zeroscope.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_lvd_gligen_zeroscope.gif
--------------------------------------------------------------------------------
/assets/example_video_lvd_zeroscope.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_lvd_zeroscope.gif
--------------------------------------------------------------------------------
/assets/example_video_zeroscope_baseline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyLianLong/LLM-groundedVideoDiffusion/22d90d2994ac3a938367383a11a5608f7fe73147/assets/example_video_zeroscope_baseline.gif
--------------------------------------------------------------------------------
/cache/cache_demo_v0.1_gpt-4-1106-preview.json:
--------------------------------------------------------------------------------
1 | {
2 | "A bear walks from the left to the right": [
3 | "Reasoning: A bear walking from left to right will have its x-coordinate increasing with each frame, while its y-coordinate should remain relatively constant as it is walking on a flat surface.\n\nFrame 1: [{'id': 0, 'name': 'bear', 'box': [0, 256, 100, 100]}]\nFrame 2: [{'id': 0, 'name': 'bear', 'box': [85, 256, 100, 100]}]\nFrame 3: [{'id': 0, 'name': 'bear', 'box': [170, 256, 100, 100]}]\nFrame 4: [{'id': 0, 'name': 'bear', 'box': [255, 256, 100, 100]}]\nFrame 5: [{'id': 0, 'name': 'bear', 'box': [340, 256, 100, 100]}]\nFrame 6: [{'id': 0, 'name': 'bear', 'box': [425, 256, 100, 100]}]\nBackground keyword: forest"
4 | ]
5 | }
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | from utils import parse, vis, cache
2 | from utils.llm import get_full_model_name, model_names, get_parsed_layout
3 | from utils.parse import show_video_boxes, size
4 | from tqdm import tqdm
5 | import os
6 | from prompt import get_prompts, template_versions
7 | import matplotlib.pyplot as plt
8 | import traceback
9 | import bdb
10 | import time
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--save-suffix", default=None, type=str)
15 | parser.add_argument(
16 | "--model",
17 | choices=model_names,
18 | required=True,
19 | help="LLM model to load the cache from",
20 | )
21 | parser.add_argument(
22 | "--repeats", default=1, type=int, help="Number of samples for each prompt"
23 | )
24 | parser.add_argument(
25 | "--regenerate",
26 | default=1,
27 | type=int,
28 | help="Number of regenerations. Different from repeats, regeneration happens after everything is generated",
29 | )
30 | parser.add_argument(
31 | "--force_run_ind",
32 | default=None,
33 | type=int,
34 | help="If this is enabled, we use this run_ind and skips generated images. If this is not enabled, we create a new run after existing runs.",
35 | )
36 | parser.add_argument(
37 | "--skip_first_prompts",
38 | default=0,
39 | type=int,
40 | help="Skip the first prompts in generation (useful for parallel generation)",
41 | )
42 | parser.add_argument(
43 | "--seed_offset",
44 | default=0,
45 | type=int,
46 | help="Offset to the seed (seed starts from this number)",
47 | )
48 | parser.add_argument(
49 | "--num_prompts",
50 | default=None,
51 | type=int,
52 | help="The number of prompts to generate (useful for parallel generation)",
53 | )
54 | parser.add_argument(
55 | "--run-model",
56 | default="lvd",
57 | choices=[
58 | "lvd",
59 | "lvd_zeroscope",
60 | "lvd_modelscope256",
61 | "lvd-gligen_modelscope256",
62 | "lvd-gligen_zeroscope",
63 | "lvd-plus_modelscope256",
64 | "lvd_modelscope512",
65 | "modelscope",
66 | "modelscope_256",
67 | "zeroscope",
68 | "zeroscope_xl",
69 | ],
70 | help="The model to use (modelscope has the option to generate with resolution 256x256)",
71 | )
72 | parser.add_argument("--visualize", action="store_true")
73 | parser.add_argument("--no-continue-on-error", action="store_true")
74 | parser.add_argument("--prompt-type", type=str, default="demo")
75 | parser.add_argument("--template_version", choices=template_versions, required=True)
76 | parser.add_argument("--dry-run", action="store_true", help="skip the generation")
77 |
78 | float_args = [
79 | "fg_top_p",
80 | "bg_top_p",
81 | "fg_weight",
82 | "bg_weight",
83 | "loss_threshold",
84 | "loss_scale",
85 | "boxdiff_loss_scale",
86 | "com_loss_scale",
87 | "gligen_scheduled_sampling_beta",
88 | ]
89 | for float_arg in float_args:
90 | parser.add_argument("--" + float_arg, default=None, type=float)
91 |
92 | # `use_ratio_based_loss` should be 0 or 1 (as it is a bool)
93 | int_args = [
94 | "num_inference_steps",
95 | "max_iter",
96 | "max_index_step",
97 | "num_frames",
98 | "use_ratio_based_loss",
99 | "boxdiff_normed",
100 | ]
101 | for int_arg in int_args:
102 | parser.add_argument("--" + int_arg, default=None, type=int)
103 |
104 | str_args = []
105 | for str_arg in str_args:
106 | parser.add_argument("--" + str_arg, default=None, type=str)
107 |
108 | args = parser.parse_args()
109 |
110 |
111 | if not args.dry_run:
112 | run_model = args.run_model
113 | baseline = run_model in [
114 | "modelscope",
115 | "zeroscope",
116 | "modelscope_256",
117 | "zeroscope_xl",
118 | ]
119 |
120 | if "_" in run_model:
121 | option = run_model.split("_")[1]
122 | else:
123 | option = ""
124 |
125 | if run_model.startswith("lvd-plus"):
126 | import generation.lvd_plus as generation
127 |
128 | base_model = option if option else "modelscope"
129 | H, W = generation.init(base_model=base_model)
130 | elif run_model.startswith("lvd-gligen"):
131 | import generation.lvd_gligen as generation
132 |
133 | base_model = option if option else "modelscope"
134 | H, W = generation.init(base_model=base_model)
135 | elif (run_model == "lvd") or (run_model.startswith("lvd_")):
136 | import generation.lvd as generation
137 |
138 | # Use modelscope as the default model
139 | base_model = option if option else "modelscope"
140 | H, W = generation.init(base_model=base_model)
141 | elif run_model == "modelscope" or run_model == "modelscope_256":
142 | import generation.modelscope_dpm as generation
143 |
144 | H, W = generation.init(option=option)
145 | elif run_model == "zeroscope" or run_model == "zeroscope_xl":
146 | import generation.zeroscope_dpm as generation
147 |
148 | H, W = generation.init(option=option)
149 | else:
150 | raise ValueError(f"Unknown model: {run_model}")
151 |
152 | if "zeroscope" in run_model and (
153 | (args.num_frames is not None and args.num_frames < 24)
154 | or ((not baseline) and args.num_frames is None)
155 | ):
156 | # num_frames is 16 by default in non-baseline models (as it uses modelscope by default)
157 | raise ValueError(
158 | "Running zeroscope with fewer than 24 frames. This may lead to suboptimal results. Comment this out if you still want to run."
159 | )
160 |
161 | version = generation.version
162 | assert (
163 | version == args.run_model.split("_")[0]
164 | ), f"{version} != {args.run_model.split('_')[0]}"
165 | run = generation.run
166 | else:
167 | version = "dry_run"
168 | run = None
169 | generation = argparse.Namespace()
170 |
171 | # set visualizations to no-op in batch generation
172 | for k in vis.__dict__.keys():
173 | if k.startswith("visualize"):
174 | vis.__dict__[k] = lambda *args, **kwargs: None
175 |
176 |
177 | ## Visualize
178 | def visualize_layout(parsed_layout):
179 | H, W = size
180 | condition = parse.parsed_layout_to_condition(
181 | parsed_layout, tokenizer=None, height=H, width=W, verbose=True
182 | )
183 |
184 | show_video_boxes(condition, ind=ind, save=True)
185 |
186 | print(f"Visualize masks at {parse.img_dir}")
187 |
188 |
189 | # close the figure when plt.show is called
190 | plt.show = plt.close
191 |
192 | prompt_type = args.prompt_type
193 | template_version = args.template_version
194 | json_template = "json" in template_version
195 |
196 | # Use cache
197 | model = get_full_model_name(model=args.model)
198 |
199 | if not baseline:
200 | cache.cache_format = "json"
201 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}_{template_version}_{model}.json'
202 | print(f"Loading LLM responses from cache {cache.cache_path}")
203 | cache.init_cache(allow_nonexist=False)
204 |
205 | prompts = get_prompts(prompt_type)
206 |
207 | save_suffix = ("_" + args.save_suffix) if args.save_suffix else ""
208 | repeats = args.repeats
209 | seed_offset = args.seed_offset
210 |
211 | model_in_base_save_dir = "" if model == "gpt-4" else f"_{model}"
212 | base_save_dir = f"img_generations/imgs_{prompt_type}_template{args.template_version}{model_in_base_save_dir}_{run_model}{save_suffix}"
213 |
214 | run_kwargs = {}
215 |
216 | argnames = float_args + int_args + str_args
217 |
218 | for argname in argnames:
219 | argvalue = getattr(args, argname)
220 | if argvalue is not None:
221 | run_kwargs[argname] = argvalue
222 |
223 | is_notebook = False
224 |
225 | if args.force_run_ind is not None:
226 | run_ind = args.force_run_ind
227 | save_dir = f"{base_save_dir}/run{run_ind}"
228 | else:
229 | run_ind = 0
230 | while True:
231 | save_dir = f"{base_save_dir}/run{run_ind}"
232 | if not os.path.exists(save_dir):
233 | break
234 | run_ind += 1
235 |
236 | if hasattr(generation, "use_autocast") and generation.use_autocast:
237 | save_dir += "_amp"
238 |
239 | print(f"Save dir: {save_dir}")
240 |
241 | LARGE_CONSTANT = 123456789
242 | LARGE_CONSTANT2 = 56789
243 | LARGE_CONSTANT3 = 6789
244 |
245 | ind = 0
246 | if args.regenerate > 1:
247 | # Need to fix the ind
248 | assert args.skip_first_prompts == 0
249 |
250 | for regenerate_ind in range(args.regenerate):
251 | print("regenerate_ind:", regenerate_ind)
252 | if not baseline:
253 | cache.reset_cache_access()
254 | for prompt_ind, prompt in enumerate(tqdm(prompts, desc=f"Run: {save_dir}")):
255 | if prompt_ind < args.skip_first_prompts:
256 | ind += 1
257 | continue
258 | if args.num_prompts is not None and prompt_ind >= (
259 | args.skip_first_prompts + args.num_prompts
260 | ):
261 | ind += 1
262 | continue
263 |
264 | # get prompt from prompts, if prompt is a list, then prompt includes both the prompt and kwargs
265 | if isinstance(prompt, list):
266 | prompt, kwargs = prompt
267 | else:
268 | kwargs = {}
269 |
270 | prompt = prompt.strip().rstrip(".")
271 |
272 | ind_override = kwargs.get("seed", None)
273 |
274 | # Load from cache
275 | if baseline:
276 | resp = None
277 | else:
278 | resp = cache.get_cache(prompt)
279 |
280 | if resp is None:
281 | print(f"Cache miss, skipping prompt: {prompt}")
282 | ind += 1
283 | continue
284 |
285 | print(f"***run: {run_ind}***")
286 | print(f"prompt: {prompt}, resp: {resp}")
287 | parse.img_dir = f"{save_dir}/{ind}"
288 | # Skip if the video is already generated
289 | if not (
290 | os.path.exists(parse.img_dir)
291 | and len(
292 | [
293 | img
294 | for img in os.listdir(parse.img_dir)
295 | if img.startswith("video") and img.endswith("joblib")
296 | ]
297 | )
298 | >= args.repeats
299 | ):
300 | os.makedirs(parse.img_dir, exist_ok=True)
301 | try:
302 | if baseline:
303 | parsed_layout = {"Prompt": prompt}
304 | else:
305 | parsed_layout, _ = get_parsed_layout(
306 | prompt,
307 | max_partial_response_retries=1,
308 | override_response=resp,
309 | json_template=json_template,
310 | )
311 |
312 | print("parsed_layout:", parsed_layout)
313 |
314 | if args.dry_run:
315 | # Skip generation
316 | ind += 1
317 | continue
318 |
319 | if args.visualize:
320 | assert (
321 | not baseline
322 | ), "baseline methods do not have layouts from the LLM to visualize"
323 | visualize_layout(parsed_layout)
324 |
325 | original_ind_base = (
326 | ind_override + regenerate_ind * LARGE_CONSTANT2
327 | if ind_override is not None
328 | else ind
329 | )
330 |
331 | for repeat_ind in range(repeats):
332 | ind_offset = repeat_ind * LARGE_CONSTANT3 + seed_offset
333 | run(
334 | parsed_layout,
335 | seed=original_ind_base + ind_offset,
336 | repeat_ind=repeat_ind,
337 | **run_kwargs,
338 | )
339 |
340 | except (KeyboardInterrupt, bdb.BdbQuit) as e:
341 | print(e)
342 | exit()
343 | except RuntimeError:
344 | print(
345 | "***RuntimeError: might run out of memory, skipping the current one***"
346 | )
347 | print(traceback.format_exc())
348 | time.sleep(10)
349 | except Exception as e:
350 | print(f"***Error: {e}***")
351 | print(traceback.format_exc())
352 | if args.no_continue_on_error:
353 | raise e
354 | else:
355 | print(f"Image exists at {parse.img_dir}, skipping")
356 | ind += 1
357 |
358 | if not baseline and cache.values_accessed() != len(prompts):
359 | print(
360 | f"**Cache is hit {cache.values_accessed()} time(s) but we have {len(prompts)} prompts. There may be cache misses or inconsistencies between the prompts and the cache such as extra items in the cache.**"
361 | )
362 |
--------------------------------------------------------------------------------
/generation/lvd.py:
--------------------------------------------------------------------------------
1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline
2 | from diffusers import DPMSolverMultistepScheduler
3 | from models.unet_3d_condition import UNet3DConditionModel
4 | from utils import parse, vis
5 | from prompt import negative_prompt
6 | import utils
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | import os
11 |
12 | version = "lvd"
13 |
14 | # %%
15 | # H, W are generation H and W. box_H and box_W are for scaling the boxes to [0, 1].
16 | pipe, base_attn_dim, H, W, box_H, box_W = None, None, None, None, None, None
17 |
18 |
19 | def init(base_model):
20 | global pipe, base_attn_dim, H, W, box_H, box_W
21 | if base_model == "modelscope512":
22 | model_key = "damo-vilab/text-to-video-ms-1.7b"
23 | base_attn_dim = (64, 64)
24 | H, W = 512, 512
25 | box_H, box_W = parse.size
26 | elif base_model == "modelscope256":
27 | model_key = "damo-vilab/text-to-video-ms-1.7b"
28 | base_attn_dim = (32, 32)
29 | H, W = 256, 256
30 | box_H, box_W = parse.size
31 | elif base_model == "zeroscope":
32 | model_key = "cerspense/zeroscope_v2_576w"
33 | base_attn_dim = (40, 72)
34 | H, W = 320, 576
35 | box_H, box_W = parse.size
36 | else:
37 | raise ValueError(f"Unknown base model: {base_model}")
38 |
39 | unet = UNet3DConditionModel.from_pretrained(model_key, subfolder="unet").to(
40 | torch.float16
41 | )
42 | pipe = TextToVideoSDPipeline.from_pretrained(
43 | model_key, unet=unet, torch_dtype=torch.float16
44 | )
45 | # The default one is DDIMScheduler
46 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
47 | pipe.to("cuda")
48 | pipe.enable_vae_slicing()
49 |
50 | # No auxiliary guidance
51 | pipe.guidance_models = None
52 |
53 | return H, W
54 |
55 |
56 | # %%
57 | upsample_scale, upsample_mode = 1, "bilinear"
58 |
59 | # %%
60 | # For visualizations: set to False to not save
61 | return_guidance_saved_attn = False
62 | # This is the main attn, not the attn for guidance.
63 | save_keys = []
64 |
65 | # %%
66 | overall_guidance_attn_keys = [
67 | ("down", 1, 0, 0),
68 | ("down", 2, 0, 0),
69 | ("down", 2, 1, 0),
70 | ("up", 1, 0, 0),
71 | ("up", 1, 1, 0),
72 | ("up", 2, 2, 0),
73 | ]
74 |
75 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn
76 | cross_attention_kwargs = {
77 | "save_attn_to_dict": {},
78 | "save_keys": save_keys,
79 | # This is for visualizations
80 | # 'offload_cross_attn_to_cpu': True
81 | }
82 |
83 |
84 | # %%
85 | def run(
86 | parsed_layout,
87 | seed,
88 | num_inference_steps=40,
89 | num_frames=16,
90 | repeat_ind=None,
91 | save_annotated_videos=False,
92 | loss_scale=5.0,
93 | loss_threshold=200.0,
94 | max_iter=5,
95 | max_index_step=10,
96 | fg_top_p=0.75,
97 | bg_top_p=0.75,
98 | fg_weight=1.0,
99 | bg_weight=4.0,
100 | attn_sync_weight=0.0,
101 | boxdiff_loss_scale=0.0,
102 | boxdiff_normed=True,
103 | com_loss_scale=0.0,
104 | use_ratio_based_loss=False,
105 | save_formats=["gif", "joblib"],
106 | ):
107 | condition = parse.parsed_layout_to_condition(
108 | parsed_layout,
109 | tokenizer=pipe.tokenizer,
110 | height=box_H,
111 | width=box_W,
112 | num_condition_frames=num_frames,
113 | verbose=True,
114 | )
115 | prompt, bboxes, phrases, object_positions, token_map = (
116 | condition.prompt,
117 | condition.boxes,
118 | condition.phrases,
119 | condition.object_positions,
120 | condition.token_map,
121 | )
122 |
123 | backward_guidance_kwargs = dict(
124 | bboxes=bboxes,
125 | object_positions=object_positions,
126 | loss_scale=loss_scale,
127 | loss_threshold=loss_threshold,
128 | max_iter=max_iter,
129 | max_index_step=max_index_step,
130 | fg_top_p=fg_top_p,
131 | bg_top_p=bg_top_p,
132 | fg_weight=fg_weight,
133 | bg_weight=bg_weight,
134 | use_ratio_based_loss=use_ratio_based_loss,
135 | guidance_attn_keys=overall_guidance_attn_keys,
136 | exclude_bg_heads=False,
137 | upsample_scale=upsample_scale,
138 | upsample_mode=upsample_mode,
139 | base_attn_dim=base_attn_dim,
140 | attn_sync_weight=attn_sync_weight,
141 | boxdiff_loss_scale=boxdiff_loss_scale,
142 | boxdiff_normed=boxdiff_normed,
143 | com_loss_scale=com_loss_scale,
144 | verbose=True,
145 | )
146 |
147 | if repeat_ind is not None:
148 | save_suffix = repeat_ind
149 |
150 | else:
151 | save_suffix = f"seed{seed}"
152 |
153 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif"
154 | if os.path.exists(save_path):
155 | print(f"Skipping {save_path}")
156 | return
157 |
158 | print("Generating")
159 | generator = torch.Generator(device="cuda").manual_seed(seed)
160 |
161 | video_frames = pipe(
162 | prompt,
163 | negative_prompt=negative_prompt,
164 | num_inference_steps=num_inference_steps,
165 | height=H,
166 | width=W,
167 | num_frames=num_frames,
168 | cross_attention_kwargs=cross_attention_kwargs,
169 | generator=generator,
170 | guidance_callback=None,
171 | backward_guidance_kwargs=backward_guidance_kwargs,
172 | return_guidance_saved_attn=return_guidance_saved_attn,
173 | guidance_type="main",
174 | ).frames
175 | video_frames = (video_frames[0] * 255.0).astype(np.uint8)
176 |
177 | # %%
178 |
179 | if save_annotated_videos:
180 | annotated_frames = [
181 | np.array(
182 | utils.draw_box(
183 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases
184 | )
185 | )
186 | for i, video_frame in enumerate(video_frames)
187 | ]
188 | vis.save_frames(
189 | f"{save_path}/video_seed{seed}_with_box",
190 | frames=annotated_frames,
191 | formats="gif",
192 | )
193 |
194 | vis.save_frames(
195 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats
196 | )
197 |
--------------------------------------------------------------------------------
/generation/lvd_gligen.py:
--------------------------------------------------------------------------------
1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline
2 | from diffusers import DPMSolverMultistepScheduler
3 | from models.unet_3d_condition import UNet3DConditionModel
4 | from utils import parse, vis
5 | from prompt import negative_prompt
6 | import utils
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | import os
11 |
12 | version = "lvd-gligen"
13 |
14 | # %%
15 | # H, W are generation H and W. box_H and box_W are for scaling the boxes to [0, 1].
16 | pipe, H, W, box_H, box_W = None, None, None, None, None
17 |
18 |
19 | def init(base_model):
20 | global pipe, H, W, box_H, box_W
21 | if base_model == "modelscope256":
22 | model_key = "longlian/text-to-video-lvd-ms"
23 | H, W = 256, 256
24 | box_H, box_W = parse.size
25 | elif base_model == "zeroscope":
26 | model_key = "longlian/text-to-video-lvd-zs"
27 | H, W = 320, 576
28 | box_H, box_W = parse.size
29 | else:
30 | raise ValueError(f"Unknown base model: {base_model}")
31 |
32 | pipe = TextToVideoSDPipeline.from_pretrained(
33 | model_key, trust_remote_code=True, torch_dtype=torch.float16
34 | )
35 | # The default one is DDIMScheduler
36 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
37 | pipe.to("cuda")
38 | pipe.enable_vae_slicing()
39 |
40 | # No auxiliary guidance
41 | pipe.guidance_models = None
42 |
43 | return H, W
44 |
45 |
46 | # %%
47 | upsample_scale, upsample_mode = 1, "bilinear"
48 |
49 | # %%
50 |
51 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn
52 | cross_attention_kwargs = {
53 | # This is for visualizations
54 | # 'offload_cross_attn_to_cpu': True
55 | }
56 |
57 |
58 | # %%
59 | def run(
60 | parsed_layout,
61 | seed,
62 | num_inference_steps=40,
63 | num_frames=16,
64 | gligen_scheduled_sampling_beta=1.0,
65 | repeat_ind=None,
66 | save_annotated_videos=False,
67 | save_formats=["gif", "joblib"],
68 | ):
69 | condition = parse.parsed_layout_to_condition(
70 | parsed_layout,
71 | tokenizer=pipe.tokenizer,
72 | height=box_H,
73 | width=box_W,
74 | num_condition_frames=num_frames,
75 | verbose=True,
76 | )
77 | prompt, bboxes, phrases, object_positions, token_map = (
78 | condition.prompt,
79 | condition.boxes,
80 | condition.phrases,
81 | condition.object_positions,
82 | condition.token_map,
83 | )
84 |
85 | if repeat_ind is not None:
86 | save_suffix = repeat_ind
87 |
88 | else:
89 | save_suffix = f"seed{seed}"
90 |
91 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif"
92 | if os.path.exists(save_path):
93 | print(f"Skipping {save_path}")
94 | return
95 |
96 | print("Generating")
97 | generator = torch.Generator(device="cuda").manual_seed(seed)
98 |
99 | lvd_gligen_boxes = []
100 | lvd_gligen_phrases = []
101 | for i in range(num_frames):
102 | lvd_gligen_boxes.append(
103 | [
104 | bboxes_item[i]
105 | for phrase, bboxes_item in zip(phrases, bboxes)
106 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0]
107 | ]
108 | )
109 | lvd_gligen_phrases.append(
110 | [
111 | phrase
112 | for phrase, bboxes_item in zip(phrases, bboxes)
113 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0]
114 | ]
115 | )
116 |
117 | video_frames = pipe(
118 | prompt,
119 | negative_prompt=negative_prompt,
120 | num_inference_steps=num_inference_steps,
121 | height=H,
122 | width=W,
123 | num_frames=num_frames,
124 | cross_attention_kwargs=cross_attention_kwargs,
125 | generator=generator,
126 | lvd_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
127 | lvd_gligen_boxes=lvd_gligen_boxes,
128 | lvd_gligen_phrases=lvd_gligen_phrases,
129 | ).frames
130 | # `diffusers` has a backward-breaking change
131 | # video_frames = (video_frames[0] * 255.).astype(np.uint8)
132 |
133 | # %%
134 |
135 | if save_annotated_videos:
136 | annotated_frames = [
137 | np.array(
138 | utils.draw_box(
139 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases
140 | )
141 | )
142 | for i, video_frame in enumerate(video_frames)
143 | ]
144 | vis.save_frames(
145 | f"{save_path}/video_seed{seed}_with_box",
146 | frames=annotated_frames,
147 | formats="gif",
148 | )
149 |
150 | vis.save_frames(
151 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats
152 | )
153 |
--------------------------------------------------------------------------------
/generation/lvd_plus.py:
--------------------------------------------------------------------------------
1 | from models.controllable_pipeline_text_to_video_synth import TextToVideoSDPipeline
2 | from diffusers import DPMSolverMultistepScheduler
3 | from models.unet_3d_condition import UNet3DConditionModel
4 | from utils import parse, vis
5 | from prompt import negative_prompt
6 | import utils
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | import os
11 |
12 | version = "lvd-plus"
13 |
14 | # %%
15 | # H, W are generation H and W. box_H and box_W are for scaling the boxes to [0, 1].
16 | pipe, base_attn_dim, H, W, box_H, box_W = None, None, None, None, None, None
17 |
18 |
19 | def init(base_model):
20 | global pipe, base_attn_dim, H, W, box_H, box_W
21 | if base_model == "modelscope256":
22 | model_key = "longlian/text-to-video-lvd-ms"
23 | base_attn_dim = (32, 32)
24 | H, W = 256, 256
25 | box_H, box_W = parse.size
26 | else:
27 | raise ValueError(f"Unknown base model: {base_model}")
28 |
29 | unet = UNet3DConditionModel.from_pretrained(
30 | model_key, subfolder="unet", revision="weights_only"
31 | ).to(torch.float16)
32 | pipe = TextToVideoSDPipeline.from_pretrained(
33 | model_key, unet=unet, torch_dtype=torch.float16, revision="weights_only"
34 | )
35 | # The default one is DDIMScheduler
36 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
37 | pipe.to("cuda")
38 | pipe.enable_vae_slicing()
39 |
40 | # No auxiliary guidance
41 | pipe.guidance_models = None
42 |
43 | return H, W
44 |
45 |
46 | # %%
47 | upsample_scale, upsample_mode = 1, "bilinear"
48 |
49 | # %%
50 | # For visualizations: set to False to not save
51 | return_guidance_saved_attn = False
52 | # This is the main attn, not the attn for guidance.
53 | save_keys = []
54 |
55 | # %%
56 | overall_guidance_attn_keys = [
57 | ("down", 1, 0, 0),
58 | ("down", 2, 0, 0),
59 | ("down", 2, 1, 0),
60 | ("up", 1, 0, 0),
61 | ("up", 1, 1, 0),
62 | ("up", 2, 2, 0),
63 | ]
64 |
65 | # Seems like `enable_model_cpu_offload` performs deepcopy so `save_attn_to_dict` does not save the attn
66 | cross_attention_kwargs = {
67 | "save_attn_to_dict": {},
68 | "save_keys": save_keys,
69 | # This is for visualizations
70 | # 'offload_cross_attn_to_cpu': True
71 | }
72 |
73 |
74 | # %%
75 | def run(
76 | parsed_layout,
77 | seed,
78 | num_inference_steps=40,
79 | num_frames=16,
80 | gligen_scheduled_sampling_beta=1.0,
81 | repeat_ind=None,
82 | save_annotated_videos=False,
83 | loss_scale=5.0,
84 | loss_threshold=200.0,
85 | max_iter=5,
86 | max_index_step=10,
87 | fg_top_p=0.75,
88 | bg_top_p=0.75,
89 | fg_weight=1.0,
90 | bg_weight=4.0,
91 | attn_sync_weight=0.0,
92 | boxdiff_loss_scale=0.0,
93 | boxdiff_normed=True,
94 | com_loss_scale=0.0,
95 | use_ratio_based_loss=False,
96 | save_formats=["gif", "joblib"],
97 | ):
98 | condition = parse.parsed_layout_to_condition(
99 | parsed_layout,
100 | tokenizer=pipe.tokenizer,
101 | height=box_H,
102 | width=box_W,
103 | num_condition_frames=num_frames,
104 | verbose=True,
105 | )
106 | prompt, bboxes, phrases, object_positions, token_map = (
107 | condition.prompt,
108 | condition.boxes,
109 | condition.phrases,
110 | condition.object_positions,
111 | condition.token_map,
112 | )
113 |
114 | backward_guidance_kwargs = dict(
115 | bboxes=bboxes,
116 | object_positions=object_positions,
117 | loss_scale=loss_scale,
118 | loss_threshold=loss_threshold,
119 | max_iter=max_iter,
120 | max_index_step=max_index_step,
121 | fg_top_p=fg_top_p,
122 | bg_top_p=bg_top_p,
123 | fg_weight=fg_weight,
124 | bg_weight=bg_weight,
125 | use_ratio_based_loss=use_ratio_based_loss,
126 | guidance_attn_keys=overall_guidance_attn_keys,
127 | exclude_bg_heads=False,
128 | upsample_scale=upsample_scale,
129 | upsample_mode=upsample_mode,
130 | base_attn_dim=base_attn_dim,
131 | attn_sync_weight=attn_sync_weight,
132 | boxdiff_loss_scale=boxdiff_loss_scale,
133 | boxdiff_normed=boxdiff_normed,
134 | com_loss_scale=com_loss_scale,
135 | verbose=True,
136 | )
137 |
138 | if repeat_ind is not None:
139 | save_suffix = repeat_ind
140 |
141 | else:
142 | save_suffix = f"seed{seed}"
143 |
144 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif"
145 | if os.path.exists(save_path):
146 | print(f"Skipping {save_path}")
147 | return
148 |
149 | print("Generating")
150 | generator = torch.Generator(device="cuda").manual_seed(seed)
151 |
152 | # print(bboxes, phrases)
153 |
154 | lvd_gligen_boxes = []
155 | lvd_gligen_phrases = []
156 | for i in range(num_frames):
157 | lvd_gligen_boxes.append(
158 | [
159 | bboxes_item[i]
160 | for phrase, bboxes_item in zip(phrases, bboxes)
161 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0]
162 | ]
163 | )
164 | lvd_gligen_phrases.append(
165 | [
166 | phrase
167 | for phrase, bboxes_item in zip(phrases, bboxes)
168 | if bboxes_item[i] != [0.0, 0.0, 0.0, 0.0]
169 | ]
170 | )
171 |
172 | video_frames = pipe(
173 | prompt,
174 | negative_prompt=negative_prompt,
175 | num_inference_steps=num_inference_steps,
176 | height=H,
177 | width=W,
178 | num_frames=num_frames,
179 | cross_attention_kwargs=cross_attention_kwargs,
180 | generator=generator,
181 | gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
182 | gligen_boxes=lvd_gligen_boxes,
183 | gligen_phrases=lvd_gligen_phrases,
184 | guidance_callback=None,
185 | backward_guidance_kwargs=backward_guidance_kwargs,
186 | return_guidance_saved_attn=return_guidance_saved_attn,
187 | guidance_type="main",
188 | ).frames
189 | video_frames = (video_frames[0] * 255.0).astype(np.uint8)
190 |
191 | # %%
192 |
193 | if save_annotated_videos:
194 | annotated_frames = [
195 | np.array(
196 | utils.draw_box(
197 | Image.fromarray(video_frame), [bbox[i] for bbox in bboxes], phrases
198 | )
199 | )
200 | for i, video_frame in enumerate(video_frames)
201 | ]
202 | vis.save_frames(
203 | f"{save_path}/video_seed{seed}_with_box",
204 | frames=annotated_frames,
205 | formats="gif",
206 | )
207 |
208 | vis.save_frames(
209 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats
210 | )
211 |
--------------------------------------------------------------------------------
/generation/modelscope_dpm.py:
--------------------------------------------------------------------------------
1 | from diffusers import TextToVideoSDPipeline
2 | from diffusers import DPMSolverMultistepScheduler
3 | from utils import parse, vis
4 | from prompt import negative_prompt
5 | import torch
6 | import numpy as np
7 | import os
8 |
9 | version = "modelscope"
10 |
11 | # %%
12 | model_key = "damo-vilab/text-to-video-ms-1.7b"
13 |
14 | pipe = TextToVideoSDPipeline.from_pretrained(model_key, torch_dtype=torch.float16)
15 | # The default one is DDIMScheduler
16 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
17 | # pipe.unet = UNet3DConditionModel.from_pretrained(
18 | # model_key, subfolder="unet").to(torch.float16)
19 | pipe.to("cuda")
20 | pipe.enable_vae_slicing()
21 | # No auxiliary guidance
22 | pipe.guidance_models = None
23 |
24 | # %%
25 | H, W = None, None
26 |
27 |
28 | def init(option):
29 | global H, W
30 | if option == "":
31 | H, W = 512, 512
32 | elif option == "256":
33 | H, W = 256, 256
34 | else:
35 | raise ValueError(f"Unknown option: {option}")
36 |
37 | return H, W
38 |
39 |
40 | def run(
41 | parsed_layout,
42 | seed,
43 | *,
44 | num_inference_steps=40,
45 | num_frames=16,
46 | repeat_ind=None,
47 | save_formats=["gif", "joblib"],
48 | ):
49 | prompt = parsed_layout["Prompt"]
50 |
51 | if repeat_ind is not None:
52 | save_suffix = repeat_ind
53 | else:
54 | save_suffix = f"seed{seed}"
55 |
56 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif"
57 | if os.path.exists(save_path):
58 | print(f"Skipping {save_path}")
59 | return
60 |
61 | print("Generating")
62 | generator = torch.Generator(device="cuda").manual_seed(seed)
63 |
64 | video_frames = pipe(
65 | prompt,
66 | negative_prompt=negative_prompt,
67 | num_inference_steps=num_inference_steps,
68 | height=H,
69 | width=W,
70 | num_frames=num_frames,
71 | cross_attention_kwargs=None,
72 | generator=generator,
73 | ).frames
74 |
75 | video_frames = (video_frames[0] * 255.0).astype(np.uint8)
76 |
77 | # %%
78 | vis.save_frames(
79 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats
80 | )
81 |
--------------------------------------------------------------------------------
/generation/zeroscope_dpm.py:
--------------------------------------------------------------------------------
1 | from diffusers import TextToVideoSDPipeline, VideoToVideoSDPipeline
2 | from diffusers import DPMSolverMultistepScheduler
3 | from utils import parse, vis
4 | from prompt import negative_prompt
5 | import torch
6 | from PIL import Image
7 | import numpy as np
8 | import os
9 |
10 | version = "zeroscope"
11 |
12 | # %%
13 | H, W = None, None
14 |
15 | pipe = TextToVideoSDPipeline.from_pretrained(
16 | "cerspense/zeroscope_v2_576w",
17 | # unet = UNet3DConditionModel.from_pretrained(
18 | # "cerspense/zeroscope_v2_576w", subfolder="unet"
19 | # ).to(torch.float16),
20 | torch_dtype=torch.float16,
21 | )
22 | # The default one is DDIMScheduler
23 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
24 | pipe.to("cuda")
25 | pipe.enable_vae_slicing()
26 | pipe_xl = None
27 |
28 |
29 | def init(option):
30 | global pipe_xl, H, W
31 |
32 | if option == "":
33 | H, W = 320, 576
34 | elif option == "xl":
35 | # the base model is still in 320, 576. The xl model outputs (576, 1024).
36 | H, W = 320, 576
37 |
38 | pipe_xl = VideoToVideoSDPipeline.from_pretrained(
39 | "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
40 | )
41 | pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config(
42 | pipe.scheduler.config
43 | )
44 | # pipe_xl.enable_model_cpu_offload()
45 | pipe_xl.to("cuda")
46 | pipe_xl.enable_vae_slicing()
47 | else:
48 | raise ValueError(f"Unknown option: {option}")
49 | # WIP
50 | return H, W
51 |
52 |
53 | # %%
54 | def run(
55 | parsed_layout,
56 | seed,
57 | *,
58 | num_inference_steps=40,
59 | num_frames=24,
60 | repeat_ind=None,
61 | save_formats=["gif", "joblib"],
62 | ):
63 | prompt = parsed_layout["Prompt"]
64 |
65 | if repeat_ind is not None:
66 | save_suffix = repeat_ind
67 | else:
68 | save_suffix = f"seed{seed}"
69 |
70 | save_path = f"{parse.img_dir}/video_{save_suffix}.gif"
71 | if os.path.exists(save_path):
72 | print(f"Skipping {save_path}")
73 | return
74 |
75 | print("Generating")
76 | generator = torch.Generator(device="cuda").manual_seed(seed)
77 |
78 | video_frames = pipe(
79 | prompt,
80 | negative_prompt=negative_prompt,
81 | num_inference_steps=num_inference_steps,
82 | height=H,
83 | width=W,
84 | num_frames=num_frames,
85 | cross_attention_kwargs=None,
86 | generator=generator,
87 | ).frames
88 | video_frames = (video_frames[0] * 255.0).astype(np.uint8)
89 |
90 | if pipe_xl is not None:
91 | print("Refining")
92 | video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
93 | video_frames_xl = pipe_xl(
94 | prompt,
95 | negative_prompt=negative_prompt,
96 | num_inference_steps=num_inference_steps,
97 | video=video,
98 | strength=0.6,
99 | generator=generator,
100 | ).frames
101 |
102 | # Convert the refined frames (not the base frames, which were converted above)
video_frames_xl = (video_frames_xl[0] * 255.0).astype(np.uint8)
103 |
104 | print("Saving")
105 | vis.save_frames(
106 | f"{parse.img_dir}/video_xl_{save_suffix}",
107 | video_frames_xl,
108 | formats=save_formats,
109 | )
110 | else:
111 | vis.save_frames(
112 | f"{parse.img_dir}/video_{save_suffix}", video_frames, formats=save_formats
113 | )
114 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 |
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # For compatibility
4 | from utils import torch_device
5 |
6 |
7 | def encode_prompts(
8 | tokenizer,
9 | text_encoder,
10 | prompts,
11 | negative_prompt="",
12 | return_full_only=False,
13 | one_uncond_input_only=False,
14 | ):
15 | if negative_prompt == "":
16 | print("Note that negative_prompt is an empty string")
17 |
18 | text_input = tokenizer(
19 | prompts,
20 | padding="max_length",
21 | max_length=tokenizer.model_max_length,
22 | truncation=True,
23 | return_tensors="pt",
24 | )
25 |
26 | max_length = text_input.input_ids.shape[-1]
27 | if one_uncond_input_only:
28 | num_uncond_input = 1
29 | else:
30 | num_uncond_input = len(prompts)
31 | uncond_input = tokenizer(
32 | [negative_prompt] * num_uncond_input,
33 | padding="max_length",
34 | max_length=max_length,
35 | return_tensors="pt",
36 | )
37 |
38 | with torch.no_grad():
39 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
40 | cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
41 |
42 | if one_uncond_input_only:
43 | return uncond_embeddings, cond_embeddings
44 |
45 | text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
46 |
47 | if return_full_only:
48 | return text_embeddings
49 | return text_embeddings, uncond_embeddings, cond_embeddings
50 |
51 |
52 | def process_input_embeddings(input_embeddings):
53 | assert isinstance(input_embeddings, (tuple, list))
54 | if len(input_embeddings) == 3:
55 | # input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings
56 | # Assume `uncond_embeddings` is full (has batch size the same as cond_embeddings)
57 | _, uncond_embeddings, cond_embeddings = input_embeddings
58 | assert (
59 | uncond_embeddings.shape[0] == cond_embeddings.shape[0]
60 | ), f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
61 | return input_embeddings
62 | elif len(input_embeddings) == 2:
63 | # input_embeddings: uncond_embeddings, cond_embeddings
64 | # uncond_embeddings may have only one item
65 | uncond_embeddings, cond_embeddings = input_embeddings
66 | if uncond_embeddings.shape[0] == 1:
67 | uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)
68 | # We follow the convention: negative (unconditional) prompt comes first
69 | text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
70 | return text_embeddings, uncond_embeddings, cond_embeddings
71 | else:
72 | raise ValueError(f"input_embeddings length: {len(input_embeddings)}")
73 |
74 |
75 | def attn_list_to_tensor(cross_attention_probs):
76 | # timestep, CrossAttnBlock, Transformer2DModel, 1xBasicTransformerBlock
77 |
78 | num_cross_attn_block = len(cross_attention_probs[0])
79 | cross_attention_probs_all = []
80 |
81 | for i in range(num_cross_attn_block):
82 | # cross_attention_probs_timestep[i]: Transformer2DModel
83 | # 1xBasicTransformerBlock is skipped
84 | cross_attention_probs_current = []
85 | for cross_attention_probs_timestep in cross_attention_probs:
86 | cross_attention_probs_current.append(
87 | torch.stack([item for item in cross_attention_probs_timestep[i]], dim=0)
88 | )
89 |
90 | cross_attention_probs_current = torch.stack(
91 | cross_attention_probs_current, dim=0
92 | )
93 | cross_attention_probs_all.append(cross_attention_probs_current)
94 |
95 | return cross_attention_probs_all
96 |
--------------------------------------------------------------------------------
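
A short usage sketch for `encode_prompts` above, illustrating the convention the rest of the code relies on: the unconditional (negative-prompt) embeddings come first in the concatenated batch. The CLIP checkpoint name here is only a stand-in for illustration; in the repo the tokenizer and text encoder come from the loaded diffusion model, and `torch_device` is the repo's default device.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from models.models import encode_prompts
from utils import torch_device

# Stand-in text encoder; the repo uses the one bundled with the diffusion model.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)

text_embeddings, uncond_embeddings, cond_embeddings = encode_prompts(
    tokenizer,
    text_encoder,
    prompts=["a red ball in a garden"],
    negative_prompt="blurry, low-quality",
)

# Convention: the negative (unconditional) embeddings come first.
assert torch.equal(text_embeddings[:1], uncond_embeddings)
assert torch.equal(text_embeddings[1:], cond_embeddings)
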
/models/pipelines.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 | from utils import guidance, schedule
5 | import utils
6 | from PIL import Image
7 | import gc
8 | import numpy as np
9 | from .attention import GatedSelfAttentionDense
10 | from .models import process_input_embeddings, torch_device
11 | import warnings
12 |
13 | DEFAULT_GUIDANCE_ATTN_KEYS = [
14 | ("down", 2, 0, 0),
15 | ("down", 2, 1, 0),
16 | ("up", 1, 0, 0),
17 | ("up", 1, 1, 0),
18 | ]
19 |
20 |
21 | def latent_backward_guidance(
22 | scheduler,
23 | unet,
24 | cond_embeddings,
25 | index,
26 | bboxes,
27 | object_positions,
28 | t,
29 | latents,
30 | loss,
31 | loss_scale=30,
32 | loss_threshold=0.2,
33 | max_iter=5,
34 | max_index_step=10,
35 | cross_attention_kwargs=None,
36 | guidance_attn_keys=None,
37 | verbose=False,
38 | return_saved_attn=False,
39 | clear_cache=False,
40 | **kwargs,
41 | ):
42 | """
43 | return_saved_attn: return the saved attention for visualizations
44 | """
45 |
46 | iteration = 0
47 |
48 | saved_attn_to_return = None
49 |
50 | if index < max_index_step:
51 | if isinstance(max_iter, list):
52 | max_iter = max_iter[index]
53 |
54 | if verbose:
55 | print(
56 | f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
57 | )
58 |
59 | with torch.set_grad_enabled(True):
60 | while (
61 | loss.item() / loss_scale > loss_threshold
62 | and iteration < max_iter
63 | and index < max_index_step
64 | ):
65 | saved_attn = {}
66 | full_cross_attention_kwargs = {
67 | "save_attn_to_dict": saved_attn,
68 | "save_keys": guidance_attn_keys,
69 | }
70 |
71 | if cross_attention_kwargs is not None:
72 | full_cross_attention_kwargs.update(cross_attention_kwargs)
73 |
74 | latents.requires_grad_(True)
75 | latent_model_input = latents
76 | latent_model_input = scheduler.scale_model_input(latent_model_input, t)
77 |
78 | unet(
79 | latent_model_input,
80 | t,
81 | encoder_hidden_states=cond_embeddings,
82 | cross_attention_kwargs=full_cross_attention_kwargs,
83 | )
84 |
85 | if return_saved_attn == "first":
86 | if iteration == 0:
87 | saved_attn_to_return = {
88 | k: v.detach().cpu() for k, v in saved_attn.items()
89 | }
90 | elif return_saved_attn == "last":
91 | if iteration == max_iter - 1:
92 | # It will not save if the current call returns before the last iteration
93 | saved_attn_to_return = {
94 | k: v.detach().cpu() for k, v in saved_attn.items()
95 | }
96 | elif return_saved_attn:
97 | raise ValueError(return_saved_attn)
98 |
99 | # TODO: could return the attention maps for the required blocks only and not necessarily the final output
100 | # update latents with guidance
101 | loss = (
102 | guidance.compute_ca_lossv3(
103 | saved_attn=saved_attn,
104 | bboxes=bboxes,
105 | object_positions=object_positions,
106 | guidance_attn_keys=guidance_attn_keys,
107 | index=index,
108 | verbose=verbose,
109 | **kwargs,
110 | )
111 | * loss_scale
112 | )
113 |
114 | if torch.isnan(loss):
115 | print("**Loss is NaN**")
116 |
117 | del full_cross_attention_kwargs, saved_attn
118 | # call gc.collect() here may release some memory
119 |
120 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
121 |
122 | latents.requires_grad_(False)
123 |
124 | if hasattr(scheduler, "alphas_cumprod"):
125 | warnings.warn("Using guidance scaled with alphas_cumprod")
126 | # Scaling with classifier guidance
127 | alpha_prod_t = scheduler.alphas_cumprod[t]
128 | # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
129 | # DDIM: https://arxiv.org/pdf/2010.02502.pdf
130 | scale = (1 - alpha_prod_t) ** (0.5)
131 |
132 | latents = latents - scale * grad_cond
133 | else:
134 | # NOTE: no scaling is performed
135 | warnings.warn("No scaling in guidance is performed")
136 | latents = latents - grad_cond
137 | iteration += 1
138 |
139 | if clear_cache:
140 | gc.collect()
141 | torch.cuda.empty_cache()
142 |
143 | if verbose:
144 | print(
145 | f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
146 | )
147 |
148 | if return_saved_attn:
149 | return latents, loss, saved_attn_to_return
150 | return latents, loss
151 |
152 |
153 | @torch.no_grad()
154 | def encode(model_dict, image, generator):
155 | """
156 | image should be a PIL object or numpy array with range 0 to 255
157 | """
158 |
159 | vae, dtype = model_dict.vae, model_dict.dtype
160 |
161 | if isinstance(image, Image.Image):
162 | w, h = image.size
163 | assert (
164 | w % 8 == 0 and h % 8 == 0
165 | ), f"h ({h}) and w ({w}) should be a multiple of 8"
166 | # w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
167 | # image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
168 | image = np.array(image)
169 |
170 | if isinstance(image, np.ndarray):
171 | assert (
172 | image.dtype == np.uint8
173 | ), f"Should have dtype uint8 (dtype: {image.dtype})"
174 | image = image.astype(np.float32) / 255.0
175 | image = image[None, ...]
176 | image = image.transpose(0, 3, 1, 2)
177 | image = 2.0 * image - 1.0
178 | image = torch.from_numpy(image)
179 |
180 | assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"
181 |
182 | image = image.to(device=torch_device, dtype=dtype)
183 | latents = vae.encode(image).latent_dist.sample(generator)
184 |
185 | latents = vae.config.scaling_factor * latents
186 |
187 | return latents
188 |
189 |
190 | @torch.no_grad()
191 | def decode(vae, latents):
192 | # scale and decode the image latents with vae
193 | scaled_latents = 1 / 0.18215 * latents
194 | with torch.no_grad():
195 | image = vae.decode(scaled_latents).sample
196 |
197 | image = (image / 2 + 0.5).clamp(0, 1)
198 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
199 | images = (image * 255).round().astype("uint8")
200 |
201 | return images
202 |
203 |
204 | def generate_semantic_guidance(
205 | model_dict,
206 | latents,
207 | input_embeddings,
208 | num_inference_steps,
209 | bboxes,
210 | phrases,
211 | object_positions,
212 | guidance_scale=7.5,
213 | semantic_guidance_kwargs=None,
214 | return_cross_attn=False,
215 | return_saved_cross_attn=False,
216 | saved_cross_attn_keys=None,
217 | return_cond_ca_only=False,
218 | return_token_ca_only=None,
219 | offload_guidance_cross_attn_to_cpu=False,
220 | offload_cross_attn_to_cpu=False,
221 | offload_latents_to_cpu=True,
222 | return_box_vis=False,
223 | show_progress=True,
224 | save_all_latents=False,
225 | dynamic_num_inference_steps=False,
226 | fast_after_steps=None,
227 | fast_rate=2,
228 | additional_guidance_cross_attention_kwargs={},
229 | custom_latent_backward_guidance=None,
230 | ):
231 | """
232 | object_positions: object indices in text tokens
233 | return_cross_attn: should be deprecated. Use `return_saved_cross_attn` and the new format.
234 | """
235 | vae, tokenizer, text_encoder, unet, scheduler, dtype = (
236 | model_dict.vae,
237 | model_dict.tokenizer,
238 | model_dict.text_encoder,
239 | model_dict.unet,
240 | model_dict.scheduler,
241 | model_dict.dtype,
242 | )
243 | text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
244 |
245 | # Just in case that we have in-place ops
246 | latents = latents.clone()
247 |
248 | if save_all_latents:
249 | # offload to cpu to save space
250 | if offload_latents_to_cpu:
251 | latents_all = [latents.cpu()]
252 | else:
253 | latents_all = [latents]
254 |
255 | scheduler.set_timesteps(num_inference_steps)
256 | if fast_after_steps is not None:
257 | scheduler.timesteps = schedule.get_fast_schedule(
258 | scheduler.timesteps, fast_after_steps, fast_rate
259 | )
260 |
261 | if dynamic_num_inference_steps:
262 | original_num_inference_steps = scheduler.num_inference_steps
263 |
264 | cross_attention_probs_down = []
265 | cross_attention_probs_mid = []
266 | cross_attention_probs_up = []
267 |
268 | loss = torch.tensor(10000.0)
269 |
270 | # TODO: we can also save necessary tokens only to save memory.
271 | # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
272 | guidance_cross_attention_kwargs = {
273 | "offload_cross_attn_to_cpu": offload_guidance_cross_attn_to_cpu,
274 | "enable_flash_attn": False,
275 | **additional_guidance_cross_attention_kwargs,
276 | }
277 |
278 | if return_saved_cross_attn:
279 | saved_attns = []
280 |
281 | main_cross_attention_kwargs = {
282 | "offload_cross_attn_to_cpu": offload_cross_attn_to_cpu,
283 | "return_cond_ca_only": return_cond_ca_only,
284 | "return_token_ca_only": return_token_ca_only,
285 | "save_keys": saved_cross_attn_keys,
286 | }
287 |
288 | # Repeating keys leads to different weights for each key.
289 | # assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}"
290 |
291 | for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)):
292 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
293 |
294 | if bboxes:
295 | if custom_latent_backward_guidance:
296 | latents, loss = custom_latent_backward_guidance(
297 | scheduler,
298 | unet,
299 | cond_embeddings,
300 | index,
301 | bboxes,
302 | object_positions,
303 | t,
304 | latents,
305 | loss,
306 | cross_attention_kwargs=guidance_cross_attention_kwargs,
307 | **semantic_guidance_kwargs,
308 | )
309 | else:
310 | # If `guidance_attn_keys` is None here, make sure it is included in `semantic_guidance_kwargs`; the default value has been removed.
311 | latents, loss = latent_backward_guidance(
312 | scheduler,
313 | unet,
314 | cond_embeddings,
315 | index,
316 | bboxes,
317 | object_positions,
318 | t,
319 | latents,
320 | loss,
321 | cross_attention_kwargs=guidance_cross_attention_kwargs,
322 | **semantic_guidance_kwargs,
323 | )
324 |
325 | # predict the noise residual
326 | with torch.no_grad():
327 | latent_model_input = torch.cat([latents] * 2)
328 | latent_model_input = scheduler.scale_model_input(
329 | latent_model_input, timestep=t
330 | )
331 |
332 | main_cross_attention_kwargs["save_attn_to_dict"] = {}
333 |
334 | unet_output = unet(
335 | latent_model_input,
336 | t,
337 | encoder_hidden_states=text_embeddings,
338 | return_cross_attention_probs=return_cross_attn,
339 | cross_attention_kwargs=main_cross_attention_kwargs,
340 | )
341 | noise_pred = unet_output.sample
342 |
343 | if return_cross_attn:
344 | cross_attention_probs_down.append(
345 | unet_output.cross_attention_probs_down
346 | )
347 | cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid)
348 | cross_attention_probs_up.append(unet_output.cross_attention_probs_up)
349 |
350 | if return_saved_cross_attn:
351 | saved_attns.append(main_cross_attention_kwargs["save_attn_to_dict"])
352 |
353 | del main_cross_attention_kwargs["save_attn_to_dict"]
354 |
355 | # perform guidance
356 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
357 | noise_pred = noise_pred_uncond + guidance_scale * (
358 | noise_pred_text - noise_pred_uncond
359 | )
360 |
361 | if dynamic_num_inference_steps:
362 | schedule.dynamically_adjust_inference_steps(scheduler, index, t)
363 |
364 | # compute the previous noisy sample x_t -> x_t-1
365 | latents = scheduler.step(noise_pred, t, latents).prev_sample
366 |
367 | if save_all_latents:
368 | if offload_latents_to_cpu:
369 | latents_all.append(latents.cpu())
370 | else:
371 | latents_all.append(latents)
372 |
373 | if dynamic_num_inference_steps:
374 | # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
375 | scheduler.num_inference_steps = original_num_inference_steps
376 |
377 | images = decode(vae, latents)
378 |
379 | ret = [latents, images]
380 |
381 | if return_cross_attn:
382 | ret.append(
383 | (
384 | cross_attention_probs_down,
385 | cross_attention_probs_mid,
386 | cross_attention_probs_up,
387 | )
388 | )
389 | if return_saved_cross_attn:
390 | ret.append(saved_attns)
391 | if return_box_vis:
392 | pil_images = [
393 | utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images
394 | ]
395 | ret.append(pil_images)
396 | if save_all_latents:
397 | latents_all = torch.stack(latents_all, dim=0)
398 | ret.append(latents_all)
399 | return tuple(ret)
400 |
--------------------------------------------------------------------------------
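
To make the control flow in `latent_backward_guidance` and `generate_semantic_guidance` easier to follow, here is a toy-scale sketch of what a single denoising step does: first nudge the latents down the gradient of a layout loss (scaled with sqrt(1 - alpha_cumprod_t) when the scheduler exposes `alphas_cumprod`, as in classifier guidance), then apply the usual classifier-free guidance combination of the noise predictions. `toy_unet` and `layout_loss` are placeholders for the real UNet and `guidance.compute_ca_lossv3`; only the two update rules mirror the code above.

import torch

def toy_unet(latents, t, text_embeddings):
    # Stand-in "noise prediction": any differentiable function of the inputs.
    return latents * 0.1 + text_embeddings.mean()

def layout_loss(latents):
    # Stand-in for the cross-attention layout loss computed from saved_attn.
    return (latents ** 2).mean()

latents = torch.randn(1, 4, 8, 8)
cond = torch.randn(1, 77, 768)
uncond = torch.zeros_like(cond)
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)
t, guidance_scale, loss_scale = 500, 7.5, 30.0

# 1. Backward guidance: step the latents down the gradient of the layout loss,
#    scaled with sqrt(1 - alpha_cumprod_t) as in classifier guidance.
latents.requires_grad_(True)
loss = layout_loss(latents) * loss_scale
grad_cond = torch.autograd.grad(loss, [latents])[0]
latents = (latents - (1 - alphas_cumprod[t]) ** 0.5 * grad_cond).detach()

# 2. Classifier-free guidance on the noise prediction.
with torch.no_grad():
    noise_uncond = toy_unet(latents, t, uncond)
    noise_text = toy_unet(latents, t, cond)
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
# scheduler.step(noise_pred, t, latents).prev_sample would then give x_{t-1}.
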
/models/transformer_temporal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Optional
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from diffusers.configuration_utils import ConfigMixin, register_to_config
21 | from diffusers.utils import BaseOutput
22 | from .attention import BasicTransformerBlock
23 | from diffusers.models.modeling_utils import ModelMixin
24 |
25 |
26 | @dataclass
27 | class TransformerTemporalModelOutput(BaseOutput):
28 | """
29 | The output of [`TransformerTemporalModel`].
30 |
31 | Args:
32 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33 | The hidden states output conditioned on `encoder_hidden_states` input.
34 | """
35 |
36 | sample: torch.FloatTensor
37 |
38 |
39 | class TransformerTemporalModel(ModelMixin, ConfigMixin):
40 | """
41 | A Transformer model for video-like data.
42 |
43 | Parameters:
44 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46 | in_channels (`int`, *optional*):
47 | The number of channels in the input and output (specify if the input is **continuous**).
48 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52 | This is fixed during training since it is used to learn a number of position embeddings.
53 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
54 | attention_bias (`bool`, *optional*):
55 | Configure if the `TransformerBlock` attention should contain a bias parameter.
56 | double_self_attention (`bool`, *optional*):
57 | Configure if each `TransformerBlock` should contain two self-attention layers.
58 | """
59 |
60 | @register_to_config
61 | def __init__(
62 | self,
63 | num_attention_heads: int = 16,
64 | attention_head_dim: int = 88,
65 | in_channels: Optional[int] = None,
66 | out_channels: Optional[int] = None,
67 | num_layers: int = 1,
68 | dropout: float = 0.0,
69 | norm_num_groups: int = 32,
70 | cross_attention_dim: Optional[int] = None,
71 | attention_bias: bool = False,
72 | sample_size: Optional[int] = None,
73 | activation_fn: str = "geglu",
74 | norm_elementwise_affine: bool = True,
75 | double_self_attention: bool = True,
76 | ):
77 | super().__init__()
78 | self.num_attention_heads = num_attention_heads
79 | self.attention_head_dim = attention_head_dim
80 | inner_dim = num_attention_heads * attention_head_dim
81 |
82 | self.in_channels = in_channels
83 |
84 | self.norm = torch.nn.GroupNorm(
85 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
86 | )
87 | self.proj_in = nn.Linear(in_channels, inner_dim)
88 |
89 | # 3. Define transformers blocks
90 | self.transformer_blocks = nn.ModuleList(
91 | [
92 | BasicTransformerBlock(
93 | inner_dim,
94 | num_attention_heads,
95 | attention_head_dim,
96 | dropout=dropout,
97 | cross_attention_dim=cross_attention_dim,
98 | activation_fn=activation_fn,
99 | attention_bias=attention_bias,
100 | double_self_attention=double_self_attention,
101 | norm_elementwise_affine=norm_elementwise_affine,
102 | )
103 | for d in range(num_layers)
104 | ]
105 | )
106 |
107 | self.proj_out = nn.Linear(inner_dim, in_channels)
108 |
109 | def forward(
110 | self,
111 | hidden_states,
112 | encoder_hidden_states=None,
113 | timestep=None,
114 | class_labels=None,
115 | num_frames=1,
116 | cross_attention_kwargs=None,
117 | return_dict: bool = True,
118 | ):
119 | """
120 | The [`TransformerTemporal`] forward method.
121 |
122 | Args:
123 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
124 | Input hidden_states.
125 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
126 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
127 | self-attention.
128 | timestep ( `torch.long`, *optional*):
129 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
130 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
131 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
132 | `AdaLayerZeroNorm`.
133 | return_dict (`bool`, *optional*, defaults to `True`):
134 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
135 | tuple.
136 |
137 | Returns:
138 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
139 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
140 | returned, otherwise a `tuple` where the first element is the sample tensor.
141 | """
142 | # 1. Input
143 | batch_frames, channel, height, width = hidden_states.shape
144 | batch_size = batch_frames // num_frames
145 |
146 | residual = hidden_states
147 |
148 | hidden_states = hidden_states[None, :].reshape(
149 | batch_size, num_frames, channel, height, width
150 | )
151 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
152 |
153 | hidden_states = self.norm(hidden_states)
154 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(
155 | batch_size * height * width, num_frames, channel
156 | )
157 |
158 | hidden_states = self.proj_in(hidden_states)
159 |
160 | base_attn_key = cross_attention_kwargs["attn_key"]
161 |
162 | # 2. Blocks
163 | for block_ind, block in enumerate(self.transformer_blocks):
164 | cross_attention_kwargs["attn_key"] = base_attn_key + [block_ind]
165 |
166 | hidden_states = block(
167 | hidden_states,
168 | encoder_hidden_states=encoder_hidden_states,
169 | timestep=timestep,
170 | cross_attention_kwargs=cross_attention_kwargs,
171 | class_labels=class_labels,
172 | )
173 |
174 | # 3. Output
175 | hidden_states = self.proj_out(hidden_states)
176 | hidden_states = (
177 | hidden_states[None, None, :]
178 | .reshape(batch_size, height, width, num_frames, channel)
179 | .permute(0, 3, 4, 1, 2)
180 | .contiguous()
181 | )
182 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
183 |
184 | output = hidden_states + residual
185 |
186 | if not return_dict:
187 | return (output,)
188 |
189 | return TransformerTemporalModelOutput(sample=output)
190 |
--------------------------------------------------------------------------------
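
Most of the forward pass above is shape bookkeeping: frames arrive folded into the batch dimension, and temporal attention needs every spatial location to become a batch element with the frame axis as the sequence axis. The following standalone snippet checks that the reshape/permute round trip used in `TransformerTemporalModel.forward` is exact once the normalization, projections, and attention blocks are dropped.

import torch

batch_size, num_frames, channel, height, width = 2, 6, 8, 4, 4
hidden_states = torch.randn(batch_size * num_frames, channel, height, width)
residual = hidden_states

# Input path: (B*F, C, H, W) -> (B, C, F, H, W) -> (B*H*W, F, C).
x = hidden_states.reshape(batch_size, num_frames, channel, height, width)
x = x.permute(0, 2, 1, 3, 4)
x = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

# ... the temporal attention blocks operate on x here ...

# Output path: invert the reshape exactly as in the model.
x = (
    x[None, None, :]
    .reshape(batch_size, height, width, num_frames, channel)
    .permute(0, 3, 4, 1, 2)
    .contiguous()
    .reshape(batch_size * num_frames, channel, height, width)
)
assert torch.equal(x, residual)
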
/prompt.py:
--------------------------------------------------------------------------------
1 | # make the response more structured
2 | templatev0_1_chat = [
3 | {
4 | "role": "system",
5 | "content": """You are an intelligent bounding box generator for videos. You don't need to generate the videos themselves but need to generate the bounding boxes. I will provide you with a caption for a video with six frames, with two frames per second. Your task is to generate a list of realistic bounding boxes for the objects mentioned in the caption for each frame as well as a background keyword. The video frames are of size 512x512. The top-left corner has coordinates [0, 0]. The bottom-right corner has coordinnates [512, 512]. The bounding boxes should not overlap or go beyond the frame boundaries.
6 |
7 | Each frame should be represented as `[{'id': unique object identifier incrementing from 0, 'name': object name, 'box': [box top-left x-coordinate, box top-left y-coordinate, box width, box height]}, ...]`. Each box should not include more than one object. Your generated frames must encapsulate the whole scenario depicted by the caption. Assume objects move and interact based on real-world physics, considering aspects such as gravity and elasticity. Assume the camera follows perspective geometry. Boxes for an object should have the same id across the frames, even if the object may disappear and reappear. If needed, you can make reasonable guesses. Provide a concise reasoning statement that is not longer than a few sentences before each generation. Refer to the examples below for the desired format. Never use markdown or other formats not in the examples. Do not start each frame with `-`. Do not include any comments in your response.""",
8 | },
9 | {
10 | "role": "user",
11 | "content": "Caption: A woman walking from the left to the right and a man jumping on the right in a room",
12 | },
13 | {
14 | "role": "assistant",
15 | "content": """Reasoning: A woman is walking from the left to the right so her x-coordinate should increase with her y-coordinate fixed. A man is jumping on the right so his x-coordinate should be large, and his y-coordinate should first decrease (upward movement) and then increase (downward movement due to gravity).
16 | Frame 1: [{'id': 0, 'name': 'walking woman', 'box': [0, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 290, 120, 180]}]
17 | Frame 2: [{'id': 0, 'name': 'walking woman', 'box': [50, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 205, 120, 200]}]
18 | Frame 3: [{'id': 0, 'name': 'walking woman', 'box': [100, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 175, 120, 200]}]
19 | Frame 4: [{'id': 0, 'name': 'walking woman', 'box': [150, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 175, 120, 200]}]
20 | Frame 5: [{'id': 0, 'name': 'walking woman', 'box': [200, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 205, 120, 200]}]
21 | Frame 6: [{'id': 0, 'name': 'walking woman', 'box': [250, 270, 120, 200]}, {'id': 1, 'name': 'jumping man', 'box': [380, 290, 120, 180]}]
22 | Background keyword: room""",
23 | },
24 | {
25 | "role": "user",
26 | "content": "Caption: A red ball is thrown from the left to the right in a garden",
27 | },
28 | {
29 | "role": "assistant",
30 | "content": """Reasoning: A ball is thrown from the left to the right, so its x-coordinate should increase. Due to gravity, its y-coordinate should increase, and the speed should be faster in later frames until it hits the ground. Due to its elasticity, the ball bounces back when it hits the ground.
31 | Frame 1: [{'id': 0, 'name': 'red ball', 'box': [0, 206, 50, 50]}]
32 | Frame 2: [{'id': 0, 'name': 'red ball', 'box': [80, 246, 50, 50]}]
33 | Frame 3: [{'id': 0, 'name': 'red ball', 'box': [160, 326, 50, 50]}]
34 | Frame 4: [{'id': 0, 'name': 'red ball', 'box': [240, 446, 50, 50]}]
35 | Frame 5: [{'id': 0, 'name': 'red ball', 'box': [320, 366, 50, 50]}]
36 | Frame 6: [{'id': 0, 'name': 'red ball', 'box': [400, 446, 50, 50]}]
37 | Background keyword: garden""",
38 | },
39 | {"role": "user", "content": "Caption: The camera is moving away from a painting"},
40 | {
41 | "role": "assistant",
42 | "content": """Reasoning: Due to perspective geometry, the painting will be smaller in later timesteps as the distance between the camera and the object is larger.
43 | Frame 1: [{'id': 0, 'name': 'painting', 'box': [156, 181, 200, 150]}]
44 | Frame 2: [{'id': 0, 'name': 'painting', 'box': [166, 189, 180, 135]}]
45 | Frame 3: [{'id': 0, 'name': 'painting', 'box': [176, 196, 160, 120]}]
46 | Frame 4: [{'id': 0, 'name': 'painting', 'box': [186, 204, 140, 105]}]
47 | Frame 5: [{'id': 0, 'name': 'painting', 'box': [196, 211, 120, 90]}]
48 | Frame 6: [{'id': 0, 'name': 'painting', 'box': [206, 219, 100, 75]}]
49 | Background keyword: room""",
50 | },
51 | ]
52 |
53 | templates = {
54 | "v0.1": templatev0_1_chat,
55 | }
56 |
57 | template_versions = list(templates.keys())
58 |
59 |
60 | def get_num_parsed_layout_frames(template_version):
61 | return 6
62 |
63 |
64 | # 6 frames
65 | required_lines = [f"Frame {i+1}:" for i in range(6)] + ["Background keyword:"]
66 | required_lines_ast = [True] * 6 + [False]
67 |
68 | strip_before = required_lines[0]
69 |
70 | stop = "\n\n"
71 |
72 | prompts_demo = [
73 | "A bear walks from the left to the right",
74 | ]
75 |
76 | prompt_types = ["demo", "lvd"]
77 |
78 | negative_prompt = (
79 | "dull, gray, unrealistic, colorless, blurry, low-quality, weird, abrupt"
80 | )
81 |
82 |
83 | def get_prompts(prompt_type, return_predicates=False):
84 | if prompt_type.startswith("lvd"):
85 | from utils.eval.lvd import get_lvd_full_prompts, get_lvd_full_prompt_predicates
86 |
87 | if return_predicates:
88 | prompts = get_lvd_full_prompt_predicates(prompt_type)
89 | else:
90 | prompts = get_lvd_full_prompts(prompt_type)
91 | elif prompt_type == "demo":
92 | assert (
93 | not return_predicates
94 | ), "Predicates are not supported for this prompt type"
95 | prompts = prompts_demo
96 | else:
97 | raise ValueError(f"Unknown prompt type: {prompt_type}")
98 |
99 | return prompts
100 |
101 |
102 | if __name__ == "__main__":
103 | if True:
104 | prompt_type = "demo"
105 |
106 | assert prompt_type in prompt_types, f"prompt_type {prompt_type} does not exist"
107 |
108 | prompts = get_prompts(prompt_type)
109 | prompt = prompts[-1]
110 | else:
111 | prompt = input("Prompt: ")
112 |
113 | template_key = "v0.1"
114 | template = templates[template_key]
115 |
116 | if isinstance(template, list):
117 | template = (
118 | "\n\n".join([item["content"] for item in template])
119 | + "\n\nCaption: {prompt}\nReasoning:"
120 | )
121 |
122 | prompt_full = template.replace("{prompt}", prompt.strip().rstrip(".")).strip()
123 | print(prompt_full)
124 |
125 | if False:
126 | import json
127 |
128 | print(json.dumps(prompt_full.strip("\n")))
129 |
--------------------------------------------------------------------------------
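
`required_lines` and `required_lines_ast` above describe how a raw completion is consumed: each `Frame i:` line is a Python literal (a list of box dicts), while the `Background keyword:` line is plain text. The repo's actual parsing lives in `utils/llm.py` and `utils/parse.py`; the snippet below is only a minimal stand-alone illustration of that format, using the red-ball example response.

import ast

response = """Frame 1: [{'id': 0, 'name': 'red ball', 'box': [0, 206, 50, 50]}]
Frame 2: [{'id': 0, 'name': 'red ball', 'box': [80, 246, 50, 50]}]
Background keyword: garden"""

layout = {}
for line in response.splitlines():
    key, _, value = line.partition(":")
    if key.startswith("Frame"):
        # Frame lines are Python literals, hence ast.literal_eval.
        layout[key] = ast.literal_eval(value.strip())
    else:
        # The background keyword line stays as plain text.
        layout[key] = value.strip()

print(layout["Frame 1"][0]["box"])   # [0, 206, 50, 50]
print(layout["Background keyword"])  # garden
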
/prompt_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | from prompt import get_prompts, template_versions
3 | from utils import parse, utils
4 | from utils.parse import show_video_boxes, size
5 | from utils.llm import get_llm_kwargs, get_full_prompt, model_names, get_parsed_layout
6 | from utils import cache
7 | import argparse
8 | import time
9 |
10 | # This only applies to visualization in this file.
11 | scale_boxes = False
12 |
13 | if scale_boxes:
14 | print("Scaling the bounding box to fit the scene")
15 | else:
16 | print("Not scaling the bounding box to fit the scene")
17 |
18 | H, W = size
19 |
20 |
21 | def visualize_layout(parsed_layout):
22 | condition = parse.parsed_layout_to_condition(
23 | parsed_layout, tokenizer=None, height=H, width=W, verbose=True
24 | )
25 |
26 | show_video_boxes(condition, ind=ind, save=True)
27 |
28 | print(f"Visualize masks at {parse.img_dir}")
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--prompt-type", type=str, default="demo")
34 | parser.add_argument("--model", choices=model_names, required=True)
35 | parser.add_argument("--template_version", choices=template_versions, required=True)
36 | parser.add_argument(
37 | "--auto-query", action="store_true", help="Auto query using the API"
38 | )
39 | parser.add_argument(
40 | "--always-save",
41 | action="store_true",
42 | help="Always save the layout without confirming",
43 | )
44 | parser.add_argument("--no-visualize", action="store_true", help="No visualizations")
45 | parser.add_argument(
46 | "--visualize-cache-hit", action="store_true", help="Save boxes for cache hit"
47 | )
48 | parser.add_argument(
49 | "--unnormalize-boxes-before-save",
50 | action="store_true",
51 | help="Unnormalize the boxes before saving. This should be enabled if the prompt asks the LLM to return normalized boxes.",
52 | )
53 | args = parser.parse_args()
54 |
55 | visualize_cache_hit = args.visualize_cache_hit
56 |
57 | template_version = args.template_version
58 |
59 | model, llm_kwargs = get_llm_kwargs(
60 | model=args.model, template_version=template_version
61 | )
62 | template = llm_kwargs.template
63 | # Need to parse json format for json templates
64 | json_template = "json" in template_version
65 |
66 | # This is for visualizing bounding boxes
67 | parse.img_dir = (
68 | f"img_generations/imgs_{args.prompt_type}_template{template_version}"
69 | )
70 | if not args.no_visualize:
71 | os.makedirs(parse.img_dir, exist_ok=True)
72 |
73 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}{"_" + template_version if args.template_version != "v5" else ""}_{model}.json'
74 |
75 | os.makedirs(os.path.dirname(cache.cache_path), exist_ok=True)
76 | cache.cache_format = "json"
77 |
78 | cache.init_cache()
79 |
80 | prompts_query = get_prompts(args.prompt_type)
81 |
82 | max_attempts = 1
83 |
84 | for ind, prompt in enumerate(prompts_query):
85 | if isinstance(prompt, list):
86 | # prompt, seed
87 | prompt = prompt[0]
88 | prompt = prompt.strip().rstrip(".")
89 |
90 | resp = cache.get_cache(prompt)
91 | if resp is None:
92 | print(f"Cache miss: {prompt}")
93 |
94 | if not args.auto_query:
95 | print("#########")
96 | prompt_full = get_full_prompt(template=template, prompt=prompt)
97 | print(prompt_full)
98 | print("#########")
99 | resp = None
100 |
101 | attempts = 0
102 | while True:
103 | attempts += 1
104 | try:
105 | # The resp from `get_parsed_layout` has already been structured
106 | if args.auto_query:
107 | parsed_layout, resp = get_parsed_layout(
108 | prompt,
109 | llm_kwargs=llm_kwargs,
110 | json_template=json_template,
111 | verbose=False,
112 | )
113 | print("Response:", resp)
114 | else:
115 | resp = utils.multiline_input(
116 | prompt="Please enter LLM response (use an empty line to end): "
117 | )
118 | parsed_layout, resp = get_parsed_layout(
119 | prompt,
120 | llm_kwargs=llm_kwargs,
121 | override_response=resp,
122 | max_partial_response_retries=1,
123 | json_template=json_template,
124 | verbose=False,
125 | )
126 |
127 | except (ValueError, SyntaxError, TypeError) as e:
128 | if attempts > max_attempts:
129 | print("Retrying too many times, skipping")
130 | break
131 | print(
132 | f"Encountered invalid data with prompt {prompt} and response {resp}: {e}, retrying"
133 | )
134 | time.sleep(1)
135 | continue
136 |
137 | if not args.no_visualize:
138 | visualize_layout(parsed_layout)
139 | if not args.always_save:
140 | save = input("Save (y/n)? ").strip()
141 | else:
142 | save = "y"
143 | if save == "y" or save == "Y":
144 | cache.add_cache(prompt, resp)
145 | else:
146 | print("Not saved. Will generate the same prompt again.")
147 | continue
148 | break
149 | else:
150 | print(f"Cache hit: {prompt}")
151 |
152 | parsed_layout, resp = get_parsed_layout(
153 | prompt,
154 | llm_kwargs=llm_kwargs,
155 | override_response=resp,
156 | max_partial_response_retries=1,
157 | json_template=json_template,
158 | verbose=False,
159 | )
160 |
161 | if visualize_cache_hit:
162 | visualize_layout(parsed_layout)
163 |
--------------------------------------------------------------------------------
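
For reference, the `parsed_layout` that `get_parsed_layout` returns (and that `visualize_layout`, the generation code, and the eval scripts all consume) is a plain dict: a `Prompt` key, one `Frame i` entry per frame holding boxes in the LLM's `[x, y, w, h]` pixel format on the 512x512 canvas, and a `Background keyword`. A hand-written example, shortened to two frames, based on the red-ball template in prompt.py:

parsed_layout = {
    "Prompt": "A red ball is thrown from the left to the right in a garden",
    "Frame 1": [{"id": 0, "name": "red ball", "box": [0, 206, 50, 50]}],
    "Frame 2": [{"id": 0, "name": "red ball", "box": [80, 246, 50, 50]}],
    "Background keyword": "garden",
}
# parse.parsed_layout_to_condition converts these boxes into the condition
# format used downstream (xyxy, ranging from 0 to 1).
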
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu113
2 | numpy
3 | scipy
4 | torch==2.0.0
5 | diffusers==0.27.2
6 | transformers==4.36.2
7 | opencv-python==4.7.0.72
8 | opencv-contrib-python==4.7.0.72
9 | inflect==6.0.4
10 | easydict
11 | accelerate==0.21.0
12 | gradio==3.35.2
13 | pydantic==1.10.7
14 | scikit-video==1.1.11
15 | imageio==2.34.1
16 | pyjson5==1.6.6
17 | joblib==1.4.2
18 |
--------------------------------------------------------------------------------
/scripts/eval_owl_vit.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5 |
6 | import argparse
7 | import torch
8 | from utils.llm import get_full_model_name, model_names
9 | from transformers import OwlViTProcessor, OwlViTForObjectDetection
10 | from glob import glob
11 | from utils.eval import to_gen_box_format, evaluate_with_layout, class_aware_nms, nms
12 | from tqdm import tqdm
13 | import numpy as np
14 | import json
15 | import joblib
16 | from prompt import get_prompts, prompt_types
17 |
18 | torch.set_grad_enabled(False)
19 |
20 |
21 | def keep_one_box_per_class(boxes, scores, labels):
22 | # Keep the box with the highest score per class
23 |
24 | boxes_output, scores_output, labels_output = [], [], []
25 | labels_unique = np.unique(labels)
26 | for label in labels_unique:
27 | label_mask = labels == label
28 | boxes_label = boxes[label_mask]
29 | scores_label = scores[label_mask]
30 | max_score_index = scores_label.argmax()
31 | box, score = boxes_label[max_score_index], scores_label[max_score_index]
32 | boxes_output.append(box)
33 | scores_output.append(score)
34 | labels_output.append(label)
35 |
36 | return np.array(boxes_output), np.array(scores_output), np.array(labels_output)
37 |
38 |
39 | def eval_prompt(
40 | p,
41 | predicate,
42 | path,
43 | processor,
44 | model,
45 | score_threshold=0.1,
46 | nms_threshold=0.5,
47 | use_class_aware_nms=False,
48 | num_eval_frames=6,
49 | use_cuda=True,
50 | verbose=False,
51 | ):
52 | video = joblib.load(path)
53 | texts = [predicate.texts]
54 |
55 | parsed_layout = {"Prompt": p, "Background keyword": None}
56 |
57 | eval_frame_indices = (
58 | np.round(np.linspace(0, len(video) - 1, num_eval_frames)).astype(int).tolist()
59 | )
60 |
61 | assert len(set(eval_frame_indices)) == len(
62 | eval_frame_indices
63 | ), f"Eval indices not unique: {eval_frame_indices}"
64 |
65 | print(f"Eval indices: {eval_frame_indices}")
66 |
67 | frame_ind = 1
68 | for eval_frame_index in eval_frame_indices:
69 | image = video[eval_frame_index]
70 | inputs = processor(text=texts, images=image, return_tensors="pt")
71 | if use_cuda:
72 | inputs = inputs.to("cuda")
73 | outputs = model(**inputs)
74 |
75 | height, width, _ = image.shape
76 |
77 | # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
78 | target_sizes = torch.Tensor([[height, width]])
79 | if use_cuda:
80 | target_sizes = target_sizes.cuda()
81 | # Convert outputs (bounding boxes and class logits) to COCO API
82 | results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
83 |
84 | i = 0 # Retrieve predictions for the first image for the corresponding text queries
85 | text = texts[i]
86 | boxes, scores, labels = (
87 | results[i]["boxes"],
88 | results[i]["scores"],
89 | results[i]["labels"],
90 | )
91 | boxes = boxes.cpu()
92 | # xyxy ranging from 0 to 1
93 | boxes = np.array(
94 | [
95 | [x_min / width, y_min / height, x_max / width, y_max / height]
96 | for (x_min, y_min, x_max, y_max), score in zip(boxes, scores)
97 | if score >= score_threshold
98 | ]
99 | )
100 | labels = np.array(
101 | [
102 | label.cpu().numpy()
103 | for label, score in zip(labels, scores)
104 | if score >= score_threshold
105 | ]
106 | )
107 | scores = np.array(
108 | [score.cpu().numpy() for score in scores if score >= score_threshold]
109 | )
110 |
111 | # print(f"Pre-NMS:")
112 | # for box, score, label in zip(boxes, scores, labels):
113 | # box = [round(i, 2) for i in box.tolist()]
114 | # print(
115 | # f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}")
116 |
117 | print(f"Post-NMS (frame frame_ind):")
118 |
119 | if use_class_aware_nms:
120 | boxes, scores, labels = class_aware_nms(
121 | boxes, scores, labels, nms_threshold
122 | )
123 | else:
124 | boxes, scores, labels = nms(boxes, scores, labels, nms_threshold)
125 |
126 | for box, score, label in zip(boxes, scores, labels):
127 | box = [round(i, 2) for i in box.tolist()]
128 | print(
129 | f"Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}"
130 | )
131 |
132 | if verbose:
133 | print(
134 | f"prompt: {p}, texts: {texts}, boxes: {boxes}, labels: {labels}, scores: {scores}"
135 | )
136 |
137 | # Here we are not using a tracker so the box id could mismatch when we have multiple objects with the same label.
138 | # For numeracy, we do not need tracking (mismatch is ok). For other tasks, we only include the box with max confidence.
139 | if predicate.one_box_per_class:
140 | boxes, scores, labels = keep_one_box_per_class(boxes, scores, labels)
141 |
142 | for box, score, label in zip(boxes, scores, labels):
143 | box = [round(i, 2) for i in box.tolist()]
144 | print(
145 | f"After selection one box per class: Detected {text[label]} ({label}) with confidence {round(score.item(), 3)} at location {box}"
146 | )
147 |
148 | det_boxes = []
149 | label_counts = {}
150 |
151 | # This ensures boxes of different labels will not be matched to each other.
152 | for box, score, label in zip(boxes, scores, labels):
153 | if label not in label_counts:
154 | label_counts[label] = 0
155 | # Here we convert to gen box format (same as LLM output), xywh (in pixels). This is for compatibility with first stage evaluation. This will be converted to condition format (xyxy, ranging from 0 to 1).
156 | det_boxes.append(
157 | {
158 | "id": label * 100 + label_counts[label],
159 | "name": text[label],
160 | "box": to_gen_box_format(box, width, height, rounding=True),
161 | "score": score,
162 | }
163 | )
164 | label_counts[label] += 1
165 |
166 | parsed_layout[f"Frame {frame_ind}"] = det_boxes
167 |
168 | frame_ind += 1
169 |
170 | print(f"parsed_layout: {parsed_layout}")
171 |
172 | eval_type, eval_success = evaluate_with_layout(
173 | parsed_layout,
174 | predicate,
175 | num_parsed_layout_frames=num_eval_frames,
176 | height=height,
177 | width=width,
178 | verbose=verbose,
179 | )
180 |
181 | return eval_type, eval_success
182 |
183 |
184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument("--prompt-type", type=str, default="lvd")
187 | parser.add_argument("--run_base_path", type=str)
188 | parser.add_argument("--run_start_ind", default=0, type=int)
189 | parser.add_argument("--num_prompts", default=None, type=int)
190 | parser.add_argument("--num_eval_frames", default=6, type=int)
191 | parser.add_argument("--skip_first_prompts", default=0, type=int)
192 | parser.add_argument("--detection_score_threshold", default=0.05, type=float)
193 | parser.add_argument("--nms_threshold", default=0.5, type=float)
194 | parser.add_argument("--class-aware-nms", action="store_true")
195 | parser.add_argument("--save-eval", action="store_true")
196 | parser.add_argument("--verbose", action="store_true")
197 | parser.add_argument("--no-cuda", action="store_true")
198 | args = parser.parse_args()
199 |
200 | np.set_printoptions(precision=2)
201 |
202 | prompt_predicates = get_prompts(args.prompt_type, return_predicates=True)
203 | num_eval_frames = args.num_eval_frames
204 |
205 | print(f"Number of prompts (predicates): {len(prompt_predicates)}")
206 | print(f"Number of evaluating frames: {num_eval_frames}")
207 |
208 | processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
209 | owl_vit_model = OwlViTForObjectDetection.from_pretrained(
210 | "google/owlvit-base-patch32"
211 | )
212 | owl_vit_model.eval()
213 |
214 | use_cuda = not args.no_cuda
215 |
216 | if use_cuda:
217 | owl_vit_model.cuda()
218 |
219 | eval_success_counts = {}
220 | eval_all_counts = {}
221 |
222 | eval_successes = {}
223 |
224 | for ind, (prompt, predicate) in enumerate(tqdm(prompt_predicates)):
225 | if isinstance(prompt, list):
226 | # prompt and kwargs
227 | prompt = prompt[0]
228 | prompt = prompt.strip().rstrip(".")
229 | if ind < args.skip_first_prompts:
230 | continue
231 | if args.num_prompts is not None and ind >= (
232 | args.skip_first_prompts + args.num_prompts
233 | ):
234 | continue
235 |
236 | search_path = f"{args.run_base_path}/{ind+args.run_start_ind}/video_*.joblib"
237 |
238 | # NOTE: sorted with string type
239 | path = sorted(glob(search_path))
240 | if len(path) == 0:
241 | print(f"***No image matching {search_path}, skipping***")
242 | continue
243 | elif len(path) > 1:
244 | print(f"***More than one images match {search_path}: {path}, skipping***")
245 | continue
246 | path = path[0]
247 | print(f"Video path: {path} ({path.replace('.joblib', '.gif')})")
248 |
249 | eval_type, eval_success = eval_prompt(
250 | prompt,
251 | predicate,
252 | path,
253 | processor,
254 | owl_vit_model,
255 | score_threshold=args.detection_score_threshold,
256 | nms_threshold=args.nms_threshold,
257 | use_class_aware_nms=args.class_aware_nms,
258 | num_eval_frames=num_eval_frames,
259 | use_cuda=use_cuda,
260 | verbose=args.verbose,
261 | )
262 |
263 | print(f"Eval success (eval_type):", eval_success)
264 |
265 | if eval_type not in eval_all_counts:
266 | eval_success_counts[eval_type] = 0
267 | eval_all_counts[eval_type] = 0
268 | eval_successes[eval_type] = []
269 |
270 | eval_success_counts[eval_type] += int(eval_success)
271 | eval_all_counts[eval_type] += 1
272 | eval_successes[eval_type].append(bool(eval_success))
273 |
274 | summary = []
275 | eval_success_count, eval_all_count = 0, 0
276 | for k, v in eval_all_counts.items():
277 | rate = eval_success_counts[k] / eval_all_counts[k]
278 | print(
279 | f"Eval type: {k}, success: {eval_success_counts[k]}/{eval_all_counts[k]}, rate: {round(rate, 2):.2f}"
280 | )
281 | eval_success_count += eval_success_counts[k]
282 | eval_all_count += eval_all_counts[k]
283 | summary.append(rate)
284 |
285 | rate = eval_success_count / eval_all_count
286 | print(f"Overall: success: {eval_success_count}/{eval_all_count}, rate: {rate:.2f}")
287 | summary.append(rate)
288 |
289 | summary_str = "/".join([f"{round(rate, 2):.2f}" for rate in summary])
290 | print(f"Summary: {summary_str}")
291 |
292 | if args.save_eval:
293 | save_eval = {}
294 | save_eval["success_counts"] = eval_success_counts
295 | save_eval["sample_counts"] = eval_all_counts
296 | save_eval["successes"] = eval_successes
297 | save_eval["success_counts_overall"] = eval_success_conut
298 | save_eval["sample_counts_overall"] = eval_all_count
299 |
300 | # Reference: https://stackoverflow.com/questions/58408054/typeerror-object-of-type-bool-is-not-json-serializable
301 |
302 | with open(f"{args.run_base_path}/eval.json", "w") as f:
303 | json.dump(save_eval, f, indent=4)
304 |
--------------------------------------------------------------------------------
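
The comment in `eval_prompt` notes that detections are converted back into the "gen box format" (the LLM's xywh pixel format) via `to_gen_box_format`, so that the second-stage evaluation can reuse the first-stage code path. Below is a sketch of what such a conversion does for the normalized xyxy boxes produced after NMS above; the repo's real helper lives in `utils/eval` and may differ in details such as rounding.

def xyxy_norm_to_xywh_pixels(box, width, height, rounding=True):
    # Illustrative stand-in for utils.eval.to_gen_box_format:
    # normalized [x_min, y_min, x_max, y_max] -> pixel [x, y, w, h].
    x_min, y_min, x_max, y_max = box
    out = [
        x_min * width,
        y_min * height,
        (x_max - x_min) * width,
        (y_max - y_min) * height,
    ]
    return [round(v) for v in out] if rounding else out

print(xyxy_norm_to_xywh_pixels([0.25, 0.5, 0.75, 1.0], width=576, height=320))
# [144, 160, 288, 160]
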
/scripts/eval_stage_one.py:
--------------------------------------------------------------------------------
1 | # This script allows evaluating stage one and saving the generated prompts to cache
2 |
3 | import sys
4 | import os
5 |
6 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
7 |
8 | import json
9 | import argparse
10 | from prompt import get_prompts, get_num_parsed_layout_frames, template_versions
11 | from utils.llm import get_llm_kwargs, get_parsed_layout_with_cache, model_names
12 | from utils.eval import evaluate_with_layout
13 | from utils import parse, cache
14 | import numpy as np
15 | from tqdm import tqdm
16 |
17 | eval_success_counts = {}
18 | eval_all_counts = {}
19 |
20 | if __name__ == "__main__":
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--prompt-type", type=str, default="lvd")
23 | parser.add_argument("--model", choices=model_names, required=True)
24 | parser.add_argument("--template_version", choices=template_versions, required=True)
25 | parser.add_argument("--skip_first_prompts", default=0, type=int)
26 | parser.add_argument("--num_prompts", default=None, type=int)
27 | parser.add_argument("--show-cache-access", action="store_true")
28 | parser.add_argument("--verbose", action="store_true")
29 | args = parser.parse_args()
30 |
31 | np.set_printoptions(precision=2)
32 |
33 | template_version = args.template_version
34 |
35 | json_template = "json" in template_version
36 |
37 | model, llm_kwargs = get_llm_kwargs(
38 | model=args.model, template_version=template_version
39 | )
40 |
41 | cache.cache_format = "json"
42 | cache.cache_path = f'cache/cache_{args.prompt_type.replace("lmd_", "")}_{template_version}_{model}.json'
43 | cache.init_cache()
44 |
45 | prompt_predicates = get_prompts(args.prompt_type, return_predicates=True)
46 | print(f"Number of prompts (predicates): {len(prompt_predicates)}")
47 |
48 | height, width = parse.size_h, parse.size_w
49 |
50 | for ind, (prompt, predicate) in enumerate(tqdm(prompt_predicates)):
51 | if isinstance(prompt, list):
52 | # prompt and kwargs
53 | prompt = prompt[0]
54 | prompt = prompt.strip().rstrip(".")
55 | if ind < args.skip_first_prompts:
56 | continue
57 | if args.num_prompts is not None and ind >= (
58 | args.skip_first_prompts + args.num_prompts
59 | ):
60 | continue
61 |
62 | parsed_layout = get_parsed_layout_with_cache(
63 | prompt, llm_kwargs, json_template=json_template, verbose=args.verbose
64 | )
65 | num_parsed_layout_frames = get_num_parsed_layout_frames(template_version)
66 | eval_type, eval_success = evaluate_with_layout(
67 | parsed_layout,
68 | predicate,
69 | num_parsed_layout_frames,
70 | height=height,
71 | width=width,
72 | verbose=args.verbose,
73 | )
74 |
75 | print(f"Eval success (eval_type):", eval_success)
76 |
77 | if eval_type not in eval_all_counts:
78 | eval_success_counts[eval_type] = 0
79 | eval_all_counts[eval_type] = 0
80 | eval_success_counts[eval_type] += int(eval_success)
81 | eval_all_counts[eval_type] += 1
82 |
83 | eval_success_count, eval_all_count = 0, 0
84 | for k, v in eval_all_counts.items():
85 | print(
86 | f"Eval type: {k}, success: {eval_success_counts[k]}/{eval_all_counts[k]}, rate: {eval_success_counts[k]/eval_all_counts[k]:.2f}"
87 | )
88 | eval_success_count += eval_success_counts[k]
89 | eval_all_count += eval_all_counts[k]
90 |
91 | print(
92 | f"Overall: success: {eval_success_conut}/{eval_all_count}, rate: {eval_success_conut/eval_all_count:.2f}"
93 | )
94 |
95 | if args.show_cache_access:
96 | # Print what are accessed in the cache (may have multiple values in each key)
97 | # Not including the newly added items
98 | print(json.dumps(cache.cache_queries))
99 | print("Number of accessed keys:", len(cache.cache_queries))
100 |
--------------------------------------------------------------------------------
/scripts/upsample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import joblib
3 | import torch
4 | import imageio
5 | import os
6 | from PIL import Image
7 | import numpy as np
8 | from tqdm import tqdm
9 | import cv2
10 | from diffusers import StableDiffusionXLImg2ImgPipeline, VideoToVideoSDPipeline
11 | from diffusers import DPMSolverMultistepScheduler
12 |
13 |
14 | def prepare_init_upsampled(video_path, horizontal):
15 | video = joblib.load(video_path)
16 |
17 | if horizontal:
18 | video = [
19 | Image.fromarray(frame).resize((1024, 576), Image.LANCZOS) for frame in video
20 | ]
21 | else:
22 | video = [
23 | Image.fromarray(frame).resize((1024, 1024), Image.LANCZOS)
24 | for frame in video
25 | ]
26 |
27 | return video
28 |
29 |
30 | def save_images_to_video(images, output_video_path, frame_rate=8.0):
31 | """Save a list of images to a video file."""
32 | if not len(images):
33 | print("No images to process.")
34 | return
35 |
36 | # Assuming all images are the same size, get dimensions from the first image
37 | height, width, layers = images[0].shape
38 |
39 | # Define the codec and create VideoWriter object
40 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec definition
41 | video = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
42 |
43 | for img in images:
44 | bgr_image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert RGB to BGR
45 | video.write(bgr_image)
46 |
47 | video.release() # Release the video writer
48 | print(f"Video saved to {output_video_path}")
49 |
50 |
51 | def upsample_zsxl(
52 | video_path,
53 | prompt,
54 | horizontal,
55 | negative_prompt,
56 | seed,
57 | strength,
58 | use_zssdxl,
59 | output_mp4,
60 | fps=8,
61 | ):
62 | save_path = video_path.replace(
63 | ".joblib", "_zsxl" if strength == 0.35 else f"_zsxl_s{strength}"
64 | )
65 | if not os.path.exists(save_path + ".joblib"):
66 | video = prepare_init_upsampled(video_path, horizontal)
67 | g = torch.manual_seed(seed)
68 | video_frames_xl = pipe_xl(
69 | prompt,
70 | negative_prompt=negative_prompt,
71 | video=video,
72 | strength=strength,
73 | generator=g,
74 | ).frames[0]
75 | assert not os.path.exists(save_path + ".joblib"), save_path + ".joblib"
76 | video_frames_xl = (video_frames_xl * 255.0).astype(np.uint8)
77 | if output_mp4:
78 | save_images_to_video(
79 | video_frames_xl, save_path + ".mp4", frame_rate=fps + 0.01
80 | )
81 | imageio.mimsave(
82 | save_path + ".gif",
83 | video_frames_xl,
84 | format="gif",
85 | loop=0,
86 | duration=1000 * 1 / fps,
87 | )
88 | joblib.dump(video_frames_xl, save_path + ".joblib", compress=("bz2", 3))
89 | print(f"Zeroscope XL upsampled image saved at: {save_path + '.gif'}")
90 | else:
91 | print(f"{save_path + '.joblib'} exists, skipping")
92 |
93 | if use_zssdxl:
94 | upsample_sdxl(
95 | save_path + ".joblib",
96 | prompt,
97 | horizontal,
98 | negative_prompt,
99 | seed,
100 | strength=0.1,
101 | )
102 |
103 |
104 | def upsample_sdxl(video_path, prompt, horizontal, negative_prompt, seed, strength):
105 | save_path = video_path.replace(
106 | ".joblib", "_sdxl" if strength == 0.35 else f"_sdxl_s{strength}"
107 | )
108 | if not os.path.exists(save_path + ".joblib"):
109 | video = prepare_init_upsampled(video_path, horizontal)
110 | pipe_sdxl.set_progress_bar_config(disable=True)
111 | video_frames_sdxl = []
112 | for video_frame in tqdm(video):
113 | g = torch.manual_seed(seed)
114 | image = pipe_sdxl(
115 | prompt,
116 | image=video_frame,
117 | negative_prompt=negative_prompt,
118 | strength=strength,
119 | generator=g,
120 | ).images[0]
121 | video_frames_sdxl.append(np.asarray(image))
122 | assert not os.path.exists(save_path + ".joblib"), save_path + ".joblib"
123 | imageio.mimsave(save_path + ".gif", video_frames_sdxl, format="gif", loop=0)
124 | joblib.dump(video_frames_sdxl, save_path + ".joblib", compress=("bz2", 3))
125 |
126 | print(f"SDXL upsampled image saved at: {save_path + '.gif'}")
127 | else:
128 | print(f"{save_path + '.joblib'} exists, skipping")
129 |
130 |
131 | if __name__ == "__main__":
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument(
134 | "--videos",
135 | nargs="+",
136 | required=True,
137 | type=str,
138 | help="path to videos in joblib format",
139 | )
140 | parser.add_argument("--prompts", nargs="+", required=True, type=str, help="prompts")
141 | parser.add_argument("--seed", type=int, default=1)
142 | parser.add_argument("--strength", type=float, default=0.35)
143 | parser.add_argument(
144 | "--negative_prompt",
145 | type=str,
146 | default="dull, gray, unrealistic, colorless, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
147 | )
148 | parser.add_argument("--use_zsxl", action="store_true")
149 | parser.add_argument("--use_sdxl", action="store_true")
150 | parser.add_argument("--use_zssdxl", action="store_true")
151 | parser.add_argument(
152 | "--horizontal",
153 | action="store_true",
154 | help="If True, the video is assumed to be horizontal (576x320 to 1024x576). If False, squared (512x512 to 1024x1024).",
155 | )
156 | parser.add_argument("--output-mp4", action="store_true", help="Store mp4 videos.")
157 |
158 | args = parser.parse_args()
159 |
160 | if args.use_zsxl:
161 | pipe_xl = VideoToVideoSDPipeline.from_pretrained(
162 | "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
163 | )
164 | pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config(
165 | pipe_xl.scheduler.config
166 | )
167 | pipe_xl.enable_model_cpu_offload()
168 | pipe_xl.enable_vae_slicing()
169 |
170 | if args.use_sdxl or args.use_zssdxl:
171 | pipe_sdxl = StableDiffusionXLImg2ImgPipeline.from_pretrained(
172 | "stabilityai/stable-diffusion-xl-refiner-1.0",
173 | torch_dtype=torch.float16,
174 | variant="fp16",
175 | use_safetensors=True,
176 | )
177 | pipe_sdxl = pipe_sdxl.to("cuda")
178 |
179 | if len(args.prompts) == 1 and len(args.videos) > 1:
180 | args.prompts = args.prompts * len(args.videos)
181 |
182 | for video_path, prompt in tqdm(zip(args.videos, args.prompts)):
183 | video_path = video_path.replace(".gif", ".joblib")
184 | print(f"Video path: {video_path}, prompt: {prompt}")
185 |
186 | if args.use_zsxl:
187 | upsample_zsxl(
188 | video_path=video_path,
189 | prompt=prompt,
190 | horizontal=args.horizontal,
191 | negative_prompt=args.negative_prompt,
192 | seed=args.seed,
193 | strength=args.strength,
194 | use_zssdxl=args.use_zssdxl,
195 | output_mp4=args.output_mp4,
196 | )
197 |
198 | if args.use_sdxl:
199 | upsample_sdxl(
200 | video_path=video_path,
201 | prompt=prompt,
202 | horizontal=args.horizontal,
203 | negative_prompt=args.negative_prompt,
204 | seed=args.seed,
205 | strength=args.strength,
206 | )
207 |
--------------------------------------------------------------------------------
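
A minimal sketch of the output naming convention used by `upsample_sdxl` above (the input path is hypothetical): with the default `--strength 0.35` the suffix is just `_sdxl`; any other strength is embedded in the file name.

    video_path = "outputs/example.joblib"   # hypothetical input
    strength = 0.1
    save_path = video_path.replace(
        ".joblib", "_sdxl" if strength == 0.35 else f"_sdxl_s{strength}"
    )
    # save_path == "outputs/example_sdxl_s0.1"; the script then writes
    # save_path + ".gif" and save_path + ".joblib"
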
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
--------------------------------------------------------------------------------
/utils/attn.py:
--------------------------------------------------------------------------------
1 | # Visualization-related functions are in utils/vis.py
2 | import numbers
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 | import utils
8 |
9 |
10 | def get_token_attnv2(
11 | token_id,
12 | saved_attns,
13 | attn_key,
14 | visualize_step_start=10,
15 | input_ca_has_condition_only=False,
16 | return_np=False,
17 | ):
18 | """
19 | saved_attns: a list of saved_attn (list is across timesteps)
20 |
21 | moves to cpu by default
22 | """
23 | saved_attns = saved_attns[visualize_step_start:]
24 |
25 | saved_attns = [saved_attn[attn_key].cpu() for saved_attn in saved_attns]
26 |
27 | attn = torch.stack(saved_attns, dim=0).mean(dim=0)
28 |
29 | # print("attn shape", attn.shape)
30 |
31 | # attn: (batch, head, spatial, text)
32 |
33 | if not input_ca_has_condition_only:
34 | assert (
35 | attn.shape[0] == 2
36 | ), f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
37 | attn = attn[1]
38 | else:
39 | assert (
40 | attn.shape[0] == 1
41 | ), f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
42 | attn = attn[0]
43 | attn = attn.mean(dim=0)[:, token_id]
44 | H, W = utils.get_hw_from_attn_dim(attn_dim=attn.shape[0])
45 | attn = attn.reshape((H, W))
46 |
47 | if return_np:
48 | return attn.numpy()
49 |
50 | return attn
51 |
52 |
53 | def shift_saved_attns_item(
54 | saved_attns_item, offset, guidance_attn_keys, horizontal_shift_only=False
55 | ):
56 | """
57 | `horizontal_shift_only`: only shift horizontally. If you use `offset` from `compose_latents_with_alignment` with `horizontal_shift_only=True`, the `offset` already has y_offset = 0 and this option is not needed.
58 | """
59 | x_offset, y_offset = offset
60 | if horizontal_shift_only:
61 | y_offset = 0.0
62 |
63 | new_saved_attns_item = {}
64 | for k in guidance_attn_keys:
65 | attn_map = saved_attns_item[k]
66 |
67 | attn_size = attn_map.shape[-2]
68 | attn_h, attn_w = utils.get_hw_from_attn_dim(attn_dim=attn_size)
69 | # Example dimensions: [batch_size, num_heads, 8, 8, num_tokens]
70 | attn_map = attn_map.unflatten(2, (attn_h, attn_w))
71 | attn_map = utils.shift_tensor(
72 | attn_map, x_offset, y_offset, offset_normalized=True, ignore_last_dim=True
73 | )
74 | attn_map = attn_map.flatten(2, 3)
75 |
76 | new_saved_attns_item[k] = attn_map
77 |
78 | return new_saved_attns_item
79 |
80 |
81 | def shift_saved_attns(saved_attns, offset, guidance_attn_keys, **kwargs):
82 | # Iterate over timesteps
83 | shifted_saved_attns = [
84 | shift_saved_attns_item(saved_attns_item, offset, guidance_attn_keys, **kwargs)
85 | for saved_attns_item in saved_attns
86 | ]
87 |
88 | return shifted_saved_attns
89 |
90 |
91 | class GaussianSmoothing(nn.Module):
92 | """
93 | Apply gaussian smoothing on a
94 |     1d, 2d or 3d tensor. Filtering is performed separately for each channel
95 | in the input using a depthwise convolution.
96 | Arguments:
97 | channels (int, sequence): Number of channels of the input tensors. Output will
98 | have this number of channels as well.
99 | kernel_size (int, sequence): Size of the gaussian kernel.
100 | sigma (float, sequence): Standard deviation of the gaussian kernel.
101 | dim (int, optional): The number of dimensions of the data.
102 | Default value is 2 (spatial).
103 |
104 | Credit: https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10
105 | """
106 |
107 | def __init__(self, channels, kernel_size, sigma, dim=2):
108 | super(GaussianSmoothing, self).__init__()
109 | if isinstance(kernel_size, numbers.Number):
110 | kernel_size = [kernel_size] * dim
111 | if isinstance(sigma, numbers.Number):
112 | sigma = [sigma] * dim
113 |
114 | # The gaussian kernel is the product of the
115 | # gaussian function of each dimension.
116 | kernel = 1
117 | meshgrids = torch.meshgrid(
118 | [torch.arange(size, dtype=torch.float32) for size in kernel_size]
119 | )
120 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
121 | mean = (size - 1) / 2
122 | kernel *= (
123 | 1
124 | / (std * math.sqrt(2 * math.pi))
125 | * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
126 | )
127 |
128 | # Make sure sum of values in gaussian kernel equals 1.
129 | kernel = kernel / torch.sum(kernel)
130 |
131 | # Reshape to depthwise convolutional weight
132 | kernel = kernel.view(1, 1, *kernel.size())
133 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
134 |
135 | self.register_buffer("weight", kernel)
136 | self.groups = channels
137 |
138 | if dim == 1:
139 | self.conv = F.conv1d
140 | elif dim == 2:
141 | self.conv = F.conv2d
142 | elif dim == 3:
143 | self.conv = F.conv3d
144 | else:
145 | raise RuntimeError(
146 | "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)
147 | )
148 |
149 | def forward(self, input):
150 | """
151 | Apply gaussian filter to input.
152 | Arguments:
153 | input (torch.Tensor): Input to apply gaussian filter on.
154 | Returns:
155 | filtered (torch.Tensor): Filtered output.
156 | """
157 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
158 |
--------------------------------------------------------------------------------
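
A minimal usage sketch for `GaussianSmoothing` above (tensor shapes are illustrative). The module performs an unpadded depthwise convolution, so the input is reflect-padded here to keep the spatial size:

    import torch
    import torch.nn.functional as F
    from utils.attn import GaussianSmoothing

    smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=1.0, dim=2)
    attn_map = torch.rand(1, 1, 16, 16)                      # (batch, channels, H, W)
    padded = F.pad(attn_map, (1, 1, 1, 1), mode="reflect")   # keep 16x16 after the 3x3 conv
    smoothed = smoothing(padded)                             # -> (1, 1, 16, 16)
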
/utils/cache.py:
--------------------------------------------------------------------------------
1 | # If a prompt has been queried no more times than we have cached responses for it, we return the cached responses sequentially. Otherwise we query and store the new response into the cache.
2 | # A new cache needs to be set up if the hyperparameters or the template change.
3 |
4 | import os
5 | import pickle, json
6 |
7 | cache_path = ""
8 | cache_format = "json"
9 |
10 | global_cache = {}
11 |
12 | # Always obtain the first item (should be enabled for notebook debugging only)
13 | force_first_item = False
14 |
15 | # The cache records how many times each key has been accessed, so that repeated queries with the same key return successive cached values.
16 | global_cache_index = {}
17 | # This stores the queries for export and debugging
18 | cache_queries = {}
19 |
20 |
21 | def reset_cache_access():
22 | global global_cache_index, cache_queries
23 | global_cache_index = {}
24 | cache_queries = {}
25 |
26 |
27 | def values_accessed():
28 | return sum(global_cache_index.values())
29 |
30 |
31 | def init_cache(allow_nonexist=True):
32 | global global_cache
33 | assert cache_path, "Need to set cache path"
34 |
35 | print(f"Cache path: {cache_path}")
36 |
37 | if not allow_nonexist:
38 | assert os.path.exists(cache_path), f"{cache_path} does not exist"
39 |
40 | if os.path.exists(cache_path):
41 | if cache_format == "pickle":
42 | with open(cache_path, "rb") as f:
43 | global_cache = pickle.load(f)
44 | elif cache_format == "json":
45 | with open(cache_path, "r") as f:
46 | global_cache = json.load(f)
47 |
48 |
49 | def get_cache(key):
50 | if key not in global_cache:
51 | global_cache[key] = []
52 |
53 | if key not in global_cache_index:
54 | global_cache_index[key] = 0
55 |
56 | current_items = global_cache[key]
57 | current_index = global_cache_index[key]
58 | if len(current_items) > current_index:
59 | if not force_first_item:
60 | global_cache_index[key] += 1
61 | if key not in cache_queries:
62 | cache_queries[key] = []
63 | cache_queries[key].append(current_items[current_index])
64 | return current_items[current_index]
65 |
66 | return None
67 |
68 |
69 | def add_cache(key, value):
70 | global_cache_index[key] += 1
71 | global_cache[key].append(value)
72 |
73 | if cache_format == "pickle":
74 | with open(cache_path, "wb") as f:
75 | pickle.dump(global_cache, f)
76 | elif cache_format == "json":
77 | with open(cache_path, "w") as f:
78 | json.dump(global_cache, f, indent=4)
79 |
80 | return value
81 |
82 |
83 | def pkl_to_json(filename):
84 | assert "pkl" in filename, filename
85 | with open(filename, "rb") as f:
86 | cache = pickle.load(f)
87 | del f
88 | filename = filename.replace("pkl", "json")
89 | assert not os.path.exists(filename)
90 | print(cache)
91 | with open(filename, "w") as f:
92 | json.dump(cache, f, indent=4)
93 |
--------------------------------------------------------------------------------
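
A minimal sketch of the intended query-then-cache flow for the helpers above, mirroring how `get_parsed_layout_with_cache` in `utils/llm.py` uses them (the path and key are hypothetical, and `query_llm` is a placeholder for the actual LLM call):

    from utils import cache

    cache.cache_path = "cache/my_cache.json"
    cache.init_cache()                   # loads the JSON cache if it exists

    key = "A realistic lively video of a scene with two cats"  # hypothetical prompt
    response = cache.get_cache(key)      # None on a cache miss
    if response is None:
        response = query_llm(key)        # hypothetical LLM call
        cache.add_cache(key, response)   # appends the value and rewrites the cache file
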
/utils/eval/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval import *
2 |
--------------------------------------------------------------------------------
/utils/eval/eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import parse
3 |
4 |
5 | def nms(
6 | bounding_boxes,
7 | confidence_score,
8 | labels,
9 | threshold,
10 | input_in_pixels=False,
11 | return_array=True,
12 | ):
13 | """
14 |     This NMS processes boxes of all labels jointly: a box can suppress overlapping boxes of any label, not only boxes with the same label.
15 |
16 | Adapted from https://github.com/amusi/Non-Maximum-Suppression/blob/master/nms.py
17 | """
18 | # If no bounding boxes, return empty list
19 | if len(bounding_boxes) == 0:
20 | return np.array([]), np.array([]), np.array([])
21 |
22 | # Bounding boxes
23 | boxes = np.array(bounding_boxes)
24 |
25 | # coordinates of bounding boxes
26 | start_x = boxes[:, 0]
27 | start_y = boxes[:, 1]
28 | end_x = boxes[:, 2]
29 | end_y = boxes[:, 3]
30 |
31 | # Confidence scores of bounding boxes
32 | score = np.array(confidence_score)
33 |
34 | # Picked bounding boxes
35 | picked_boxes = []
36 | picked_score = []
37 | picked_labels = []
38 |
39 | # Compute areas of bounding boxes
40 | if input_in_pixels:
41 | areas = (end_x - start_x + 1) * (end_y - start_y + 1)
42 | else:
43 | areas = (end_x - start_x) * (end_y - start_y)
44 |
45 | # Sort by confidence score of bounding boxes
46 | order = np.argsort(score)
47 |
48 | # Iterate bounding boxes
49 | while order.size > 0:
50 | # The index of largest confidence score
51 | index = order[-1]
52 |
53 | # Pick the bounding box with largest confidence score
54 | picked_boxes.append(bounding_boxes[index])
55 | picked_score.append(confidence_score[index])
56 | picked_labels.append(labels[index])
57 |
58 |         # Compute coordinates of the intersection for intersection-over-union (IoU)
59 | x1 = np.maximum(start_x[index], start_x[order[:-1]])
60 | x2 = np.minimum(end_x[index], end_x[order[:-1]])
61 | y1 = np.maximum(start_y[index], start_y[order[:-1]])
62 | y2 = np.minimum(end_y[index], end_y[order[:-1]])
63 |
64 | # Compute areas of intersection-over-union
65 | if input_in_pixels:
66 | w = np.maximum(0.0, x2 - x1 + 1)
67 | h = np.maximum(0.0, y2 - y1 + 1)
68 | else:
69 | w = np.maximum(0.0, x2 - x1)
70 | h = np.maximum(0.0, y2 - y1)
71 | intersection = w * h
72 |
73 | # Compute the ratio between intersection and union
74 | ratio = intersection / (areas[index] + areas[order[:-1]] - intersection)
75 |
76 | left = np.where(ratio < threshold)
77 | order = order[left]
78 |
79 | if return_array:
80 | picked_boxes, picked_score, picked_labels = (
81 | np.array(picked_boxes),
82 | np.array(picked_score),
83 | np.array(picked_labels),
84 | )
85 |
86 | return picked_boxes, picked_score, picked_labels
87 |
88 |
89 | def class_aware_nms(
90 | bounding_boxes, confidence_score, labels, threshold, input_in_pixels=False
91 | ):
92 | """
93 | This NMS processes boxes of each label individually.
94 | """
95 | # If no bounding boxes, return empty list
96 | if len(bounding_boxes) == 0:
97 | return np.array([]), np.array([]), np.array([])
98 |
99 | picked_boxes, picked_score, picked_labels = [], [], []
100 |
101 | labels_unique = np.unique(labels)
102 | for label in labels_unique:
103 | bounding_boxes_label = [
104 | bounding_box
105 | for i, bounding_box in enumerate(bounding_boxes)
106 | if labels[i] == label
107 | ]
108 | confidence_score_label = [
109 | confidence_score_item
110 | for i, confidence_score_item in enumerate(confidence_score)
111 | if labels[i] == label
112 | ]
113 | labels_label = [label] * len(bounding_boxes_label)
114 | picked_boxes_label, picked_score_label, picked_labels_label = nms(
115 | bounding_boxes_label,
116 | confidence_score_label,
117 | labels_label,
118 | threshold=threshold,
119 | input_in_pixels=input_in_pixels,
120 | return_array=False,
121 | )
122 | picked_boxes += picked_boxes_label
123 | picked_score += picked_score_label
124 | picked_labels += picked_labels_label
125 |
126 | picked_boxes, picked_score, picked_labels = (
127 | np.array(picked_boxes),
128 | np.array(picked_score),
129 | np.array(picked_labels),
130 | )
131 |
132 | return picked_boxes, picked_score, picked_labels
133 |
134 |
135 | def evaluate_with_layout(
136 | parsed_layout, predicate, num_parsed_layout_frames, height, width, verbose=False
137 | ):
138 | condition = parse.parsed_layout_to_condition(
139 | parsed_layout,
140 | tokenizer=None,
141 | height=height,
142 | width=width,
143 | num_parsed_layout_frames=num_parsed_layout_frames,
144 | num_condition_frames=num_parsed_layout_frames,
145 | strip_phrases=True,
146 | verbose=True,
147 | )
148 |
149 | print("condition:", condition)
150 |
151 | prompt_type = predicate.type
152 | success = predicate(condition, verbose=verbose)
153 |
154 | return prompt_type, success
155 |
156 |
157 | def to_gen_box_format(box, width, height, rounding):
158 | # Input: xyxy, ranging from 0 to 1
159 | # Output: xywh, unnormalized (in pixels)
160 | x_min, y_min, x_max, y_max = box
161 | if rounding:
162 | return [
163 | round(x_min * width),
164 | round(y_min * height),
165 | round((x_max - x_min) * width),
166 | round((y_max - y_min) * height),
167 | ]
168 | return [
169 | x_min * width,
170 | y_min * height,
171 | (x_max - x_min) * width,
172 | (y_max - y_min) * height,
173 | ]
174 |
--------------------------------------------------------------------------------
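
A minimal sketch of the `nms` helper above with normalized xyxy boxes (`input_in_pixels=False`); the boxes, scores, and labels are made up. In `nms` a high-scoring box can suppress overlapping boxes of any label, whereas `class_aware_nms` only lets boxes of the same label suppress each other:

    from utils.eval import nms

    boxes = [[0.10, 0.10, 0.50, 0.50], [0.12, 0.12, 0.52, 0.52], [0.60, 0.60, 0.90, 0.90]]
    scores = [0.9, 0.8, 0.7]
    labels = ["cat", "cat", "dog"]

    kept_boxes, kept_scores, kept_labels = nms(boxes, scores, labels, threshold=0.5)
    # The second "cat" box overlaps the first with IoU > 0.5 and is suppressed;
    # the "dog" box does not overlap and is kept.
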
/utils/eval/lvd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from functools import partial
3 | from .utils import (
4 | p,
5 | predicate_numeracy,
6 | predicate_attribution,
7 | predicate_visibility,
8 | predicate_1obj_dynamic_spatial,
9 | predicate_2obj_dynamic_spatial,
10 | predicate_sequentialv2,
11 | )
12 |
13 | prompt_prefix = "A realistic lively video of a scene"
14 | prompt_top_down_prefix = "A realistic lively video of a top-down viewed scene"
15 |
16 | evaluate_classes = [
17 | ("moving car", "car"),
18 | ("lively cat", "cat"),
19 | ("flying bird", "bird"),
20 | ("moving ball", "ball"),
21 | ("walking dog", "dog"),
22 | ]
23 | evaluate_classes_no_attribute = [
24 | evaluate_class_no_attribute
25 | for evaluate_class, evaluate_class_no_attribute in evaluate_classes
26 | ]
27 |
28 |
29 | def get_prompt_predicates_numeracy(min_num=1, max_num=5, repeat=2):
30 | modifier = ""
31 |
32 | prompt_predicates = []
33 |
34 | for number in range(min_num, max_num + 1):
35 | for object_name, object_name_no_attribute in evaluate_classes:
36 | if prompt_prefix:
37 | prompt = f"{prompt_prefix} with {p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}"
38 | else:
39 | prompt = f"{p.number_to_words(number) if number < 21 else number}{modifier} {p.plural(object_name) if number > 1 else object_name}"
40 | prompt = prompt.strip()
41 |
42 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
43 | query_names = (object_name_no_attribute,)
44 | predicate = partial(predicate_numeracy, query_names, number)
45 | predicate.type = "numeracy"
46 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"]
47 | # We don't have tracking, but mismatch does not matter for numeracy.
48 | predicate.one_box_per_class = False
49 | prompt_predicate = prompt, predicate
50 |
51 | prompt_predicates += [prompt_predicate] * repeat
52 |
53 | return prompt_predicates
54 |
55 |
56 | def process_object_name(object_name):
57 | if isinstance(object_name, tuple):
58 | query_names = object_name
59 | object_name = object_name[0]
60 | else:
61 | query_names = (object_name,)
62 |
63 | return object_name, query_names
64 |
65 |
66 | def get_prompt_predicates_attribution(num_prompts=100, repeat=1):
67 | prompt_predicates = []
68 |
69 | intended_count1, intended_count2 = 1, 1
70 |
71 | modifiers = [
72 | "red",
73 | "orange",
74 | "yellow",
75 | "green",
76 | "blue",
77 | "purple",
78 | "pink",
79 | "brown",
80 | "black",
81 | "white",
82 | "gray",
83 | ]
84 |
85 | for ind in range(num_prompts):
86 | np.random.seed(ind)
87 | modifier1, modifier2 = np.random.choice(modifiers, 2, replace=False)
88 | object_name1, object_name2 = np.random.choice(
89 | evaluate_classes_no_attribute, 2, replace=False
90 | )
91 |
92 | object_name1, query_names1 = process_object_name(object_name1)
93 | object_name2, query_names2 = process_object_name(object_name2)
94 |
95 | if prompt_prefix:
96 | prompt = f"{prompt_prefix} with {p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}"
97 | else:
98 | prompt = (
99 | f"{p.a(modifier1)} {object_name1} and {p.a(modifier2)} {object_name2}"
100 | )
101 | prompt = prompt.strip()
102 |
103 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
104 | predicate = partial(
105 | predicate_attribution,
106 | query_names1,
107 | query_names2,
108 | modifier1,
109 | modifier2,
110 | intended_count1,
111 | intended_count2,
112 | )
113 |
114 | prompt_predicate = prompt, predicate
115 |
116 | predicate.type = "attribution"
117 | predicate.texts = [
118 | f"a photo of {p.a(modifier1)} {object_name1}",
119 | f"a photo of {p.a(modifier2)} {object_name2}",
120 | ]
121 | # Limit to one box per class.
122 | predicate.one_box_per_class = True
123 |
124 | prompt_predicates += [prompt_predicate] * repeat
125 |
126 | return prompt_predicates
127 |
128 |
129 | def get_prompt_predicates_visibility(repeat=2):
130 | prompt_predicates = []
131 |
132 | for object_name, object_name_no_attribute in evaluate_classes:
133 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
134 | query_names = (object_name_no_attribute,)
135 |
136 | for i in range(2):
137 | # i == 0: appeared
138 | # i == 1: disappeared
139 |
140 | if i == 0:
141 | prompt = f"{prompt_prefix} in which {p.a(object_name)} appears only in the second half of the video"
142 | # Shouldn't use lambda here since query_names (and number) might change.
143 | predicate = partial(predicate_visibility, query_names, True)
144 | prompt_predicate = prompt, predicate
145 | else:
146 | prompt = f"{prompt_prefix} in which {p.a(object_name)} appears only in the first half of the video"
147 | # Shouldn't use lambda here since query_names (and number) might change.
148 | predicate = partial(predicate_visibility, query_names, False)
149 | prompt_predicate = prompt, predicate
150 |
151 | predicate.type = "visibility"
152 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"]
153 | # Limit to one box per class.
154 | predicate.one_box_per_class = True
155 |
156 | prompt_predicates += [prompt_predicate] * repeat
157 |
158 | return prompt_predicates
159 |
160 |
161 | def get_prompt_predicates_1obj_dynamic_spatial(repeat=1, left_right_only=True):
162 | prompt_predicates = []
163 |
164 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. This is because LVD uses `condition` rather than `gen_boxes` format as the input. `condition` has processed the coordinates.
165 | locations = [
166 | (
167 | "left",
168 | "right",
169 | lambda box1, box2: (box1[0] + box1[2]) / 2 < (box2[0] + box2[2]) / 2,
170 | ),
171 | (
172 | "right",
173 | "left",
174 | lambda box1, box2: (box1[0] + box1[2]) / 2 > (box2[0] + box2[2]) / 2,
175 | ),
176 | ]
177 | if not left_right_only:
178 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format.
179 | locations += [
180 | (
181 | "top",
182 | "bottom",
183 | lambda box1, box2: (box1[1] + box1[3]) / 2 < (box2[1] + box2[3]) / 2,
184 | ),
185 | (
186 | "bottom",
187 | "top",
188 | lambda box1, box2: (box1[1] + box1[3]) / 2 > (box2[1] + box2[3]) / 2,
189 | ),
190 | ]
191 |
192 | # We use object names without motion attributes for spatial since the attribute words may interfere with the intended motion.
193 | for object_name_no_attribute in evaluate_classes_no_attribute:
194 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
195 | query_names = (object_name_no_attribute,)
196 |
197 | for location1, location2, verify_fn in locations:
198 | prompt = f"{prompt_prefix} with {p.a(object_name_no_attribute)} moving from the {location1} to the {location2}"
199 | prompt = prompt.strip()
200 |
201 | predicate = partial(predicate_1obj_dynamic_spatial, query_names, verify_fn)
202 | prompt_predicate = prompt, predicate
203 | predicate.type = "dynamic_spatial"
204 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"]
205 | # Limit to one box per class.
206 | predicate.one_box_per_class = True
207 |
208 | prompt_predicates += [prompt_predicate] * repeat
209 |
210 | return prompt_predicates
211 |
212 |
213 | def get_prompt_predicates_2obj_dynamic_spatial(
214 | num_prompts=10, repeat=1, left_right_only=True
215 | ):
216 | prompt_predicates = []
217 |
218 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format. This is because LVD uses `condition` rather than `gen_boxes` format as the input. `condition` has processed the coordinates.
219 | locations = [
220 | (
221 | "left",
222 | "right",
223 | lambda box1, box2: (box1[0] + box1[2]) / 2 < (box2[0] + box2[2]) / 2,
224 | ),
225 | (
226 | "right",
227 | "left",
228 | lambda box1, box2: (box1[0] + box1[2]) / 2 > (box2[0] + box2[2]) / 2,
229 | ),
230 | ]
231 | if not left_right_only:
232 | # NOTE: the boxes are in (x_min, y_min, x_max, y_max) format.
233 | locations += [
234 | (
235 | "top",
236 | "bottom",
237 | lambda box1, box2: (box1[1] + box1[3]) / 2 < (box2[1] + box2[3]) / 2,
238 | ),
239 | (
240 | "bottom",
241 | "top",
242 | lambda box1, box2: (box1[1] + box1[3]) / 2 > (box2[1] + box2[3]) / 2,
243 | ),
244 | ]
245 |
246 | # We use object names without motion attributes for spatial since the attribute words may interfere with the intended motion.
247 | for ind in range(num_prompts):
248 | np.random.seed(ind)
249 | for location1, location2, verify_fn in locations:
250 | object_name1, object_name2 = np.random.choice(
251 | evaluate_classes_no_attribute, 2, replace=False
252 | )
253 |
254 | object_name1, query_names1 = process_object_name(object_name1)
255 | object_name2, query_names2 = process_object_name(object_name2)
256 |
257 | prompt = f"{prompt_prefix} with {p.a(object_name1)} moving from the {location1} of {p.a(object_name2)} to its {location2}"
258 | prompt = prompt.strip()
259 |
260 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
261 | predicate = partial(
262 | predicate_2obj_dynamic_spatial, query_names1, query_names2, verify_fn
263 | )
264 | prompt_predicate = prompt, predicate
265 | predicate.type = "dynamic_spatial"
266 | predicate.texts = [
267 | f"a photo of {p.a(object_name1)}",
268 | f"a photo of {p.a(object_name2)}",
269 | ]
270 | # Limit to one box per class.
271 | predicate.one_box_per_class = True
272 | prompt_predicates += [prompt_predicate] * repeat
273 |
274 | return prompt_predicates
275 |
276 |
277 | def get_prompt_predicates_sequential(repeat=1):
278 | prompt_predicates = []
279 |
280 | locations = [
281 | ("lower left", "lower right", "upper right"),
282 | ("lower left", "upper left", "upper right"),
283 | ("lower right", "lower left", "upper left"),
284 | ("lower right", "upper right", "upper left"),
285 | ]
286 | verify_fns = {
287 | # lower: y is large
288 | "lower left": lambda box: (box[1] + box[3]) / 2 > 0.5
289 | and (box[0] + box[2]) / 2 < 0.5,
290 | "lower right": lambda box: (box[1] + box[3]) / 2 > 0.5
291 | and (box[0] + box[2]) / 2 > 0.5,
292 | "upper left": lambda box: (box[1] + box[3]) / 2 < 0.5
293 | and (box[0] + box[2]) / 2 < 0.5,
294 | "upper right": lambda box: (box[1] + box[3]) / 2 < 0.5
295 | and (box[0] + box[2]) / 2 > 0.5,
296 | }
297 |
298 | for object_name_no_attribute in evaluate_classes_no_attribute:
299 | # `query_names` needs to match with `texts` since `query_names` will be searched in the detection of `texts`
300 | query_names = (object_name_no_attribute,)
301 |
302 | for location1, location2, location3 in locations:
303 |             # We check that the object is present in the first, middle, and last frames, and that it is at the intended location in each of those frames.
304 | prompt = f"{prompt_top_down_prefix} in which {p.a(object_name_no_attribute)} initially on the {location1} of the scene. It first moves to the {location2} of the scene and then moves to the {location3} of the scene."
305 |
306 | # Shouldn't use lambda here since query_names (and number) might change.
307 | predicate = partial(
308 | predicate_sequentialv2,
309 | query_names,
310 | verify_fns[location1],
311 | verify_fns[location2],
312 | verify_fns[location3],
313 | )
314 | # predicate = partial(predicate_sequentialv2, query_names, verify_fn1, verify_fn2)
315 | prompt_predicate = prompt, predicate
316 | predicate.type = "sequential"
317 | predicate.texts = [f"a photo of {p.a(object_name_no_attribute)}"]
318 | # Limit to one box per class.
319 | predicate.one_box_per_class = True
320 | prompt_predicates += [prompt_predicate] * repeat
321 |
322 | return prompt_predicates
323 |
324 |
325 | def get_lvd_full_prompt_predicates(prompt_type=None):
326 | # numeracy: 100 prompts, number 1 to 4, 5 classes, repeat 5 times
327 | prompt_predicates_numeracy = get_prompt_predicates_numeracy(max_num=4, repeat=5)
328 | # attribution: 100 prompts, two objects in each prompt, each with attributes (randomly sampled)
329 | prompt_predicates_attribution = get_prompt_predicates_attribution(num_prompts=100)
330 | # visibility: 100 prompts, 5 classes, appear/disappear, repeat 10 times
331 | prompt_predicates_visibility = get_prompt_predicates_visibility(repeat=10)
332 | # dynamic spatial: 100 prompts
333 | # 1 object: 50 prompts, 5 classes, left/right, repeat 5 times
334 | # 2 objects: 50 prompts, randomly sample two objects 25 times, left/right
335 | prompt_predicates_1obj_dynamic_spatial = get_prompt_predicates_1obj_dynamic_spatial(
336 | repeat=5
337 | )
338 | prompt_predicates_2obj_dynamic_spatial = get_prompt_predicates_2obj_dynamic_spatial(
339 | num_prompts=25
340 | )
341 | prompt_predicates_dynamic_spatial = (
342 | prompt_predicates_1obj_dynamic_spatial + prompt_predicates_2obj_dynamic_spatial
343 | )
344 | # sequential: 100 prompts, 5 classes, 4 location triplets, repeat 5 times
345 | prompt_predicates_sequential = get_prompt_predicates_sequential(repeat=5)
346 |
347 | prompt_predicates_static_all = (
348 | prompt_predicates_numeracy + prompt_predicates_attribution
349 | )
350 | prompts_predicates_dynamic_all = (
351 | prompt_predicates_visibility
352 | + prompt_predicates_dynamic_spatial
353 | + prompt_predicates_sequential
354 | )
355 |
356 | # Each one has 100 prompts
357 | prompt_predicates_all = (
358 | prompt_predicates_numeracy
359 | + prompt_predicates_attribution
360 | + prompt_predicates_visibility
361 | + prompt_predicates_dynamic_spatial
362 | + prompt_predicates_sequential
363 | )
364 |
365 | prompt_predicates = {
366 | "lvd": prompt_predicates_all,
367 | "lvd_static": prompt_predicates_static_all,
368 | "lvd_numeracy": prompt_predicates_numeracy,
369 | "lvd_attribution": prompt_predicates_attribution,
370 | "lvd_dynamic": prompts_predicates_dynamic_all,
371 | "lvd_dynamic_spatial": prompt_predicates_dynamic_spatial,
372 | "lvd_visibility": prompt_predicates_visibility,
373 | "lvd_sequential": prompt_predicates_sequential,
374 | }
375 |
376 | if prompt_type is not None:
377 | return prompt_predicates[prompt_type]
378 | else:
379 | return prompt_predicates
380 |
381 |
382 | def get_lvd_full_prompts(prompt_type):
383 | prompt_predicates = get_lvd_full_prompt_predicates(prompt_type)
384 | if prompt_type is not None:
385 | return [item[0] for item in prompt_predicates]
386 | else:
387 | return {k: [item[0] for item in v] for k, v in prompt_predicates.items()}
388 |
389 |
390 | if __name__ == "__main__":
391 |     prompt_predicates = get_lvd_full_prompt_predicates("lvd")
392 |
393 | print(
394 | np.unique(
395 | [predicate.type for prompt, predicate in prompt_predicates],
396 | return_counts=True,
397 | )
398 | )
399 | print(len(prompt_predicates))
400 |
--------------------------------------------------------------------------------
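
A minimal sketch of how the benchmark splits above are consumed (each split yields 100 `(prompt, predicate)` pairs; the printed values follow from the generators above):

    from utils.eval.lvd import get_lvd_full_prompt_predicates, get_lvd_full_prompts

    numeracy = get_lvd_full_prompt_predicates("lvd_numeracy")
    print(len(numeracy))              # 100
    prompt, predicate = numeracy[0]
    print(prompt)                     # "A realistic lively video of a scene with one moving car"
    print(predicate.type)             # "numeracy"
    print(predicate.texts)            # ["a photo of a car"]

    # Prompts only (e.g., for stage-1 layout generation):
    numeracy_prompts = get_lvd_full_prompts("lvd_numeracy")
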
/utils/eval/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import inflect
3 | import re
4 | from utils.parse import Condition
5 |
6 | p = inflect.engine()
7 |
8 | def find_word_after(text, word):
9 | pattern = r"\b" + re.escape(word) + r"\s+(.+)"
10 | match = re.search(pattern, text)
11 | if match:
12 | return match.group(1)
13 | else:
14 | return None
15 |
16 |
17 | word_to_num_mapping = {p.number_to_words(i): i for i in range(1, 21)}
18 |
19 | locations_xyxy = {
20 | ("left", "right"): (lambda box1, box2: (box1[0] + box1[2]) < (box2[0] + box2[2])),
21 | ("right", "left"): (lambda box1, box2: (box1[0] + box1[2]) > (box2[0] + box2[2])),
22 | ("top", "bottom"): (lambda box1, box2: (box1[1] + box1[3]) < (box2[1] + box2[3])),
23 | ("bottom", "top"): (lambda box1, box2: (box1[1] + box1[3]) > (box2[1] + box2[3])),
24 | }
25 |
26 | locations_xywh = {
27 | ("left", "right"): (
28 | lambda box1, box2: box1[0] + box1[2] / 2 < box2[0] + box2[2] / 2
29 | ),
30 | ("right", "left"): (
31 | lambda box1, box2: box1[0] + box1[2] / 2 > box2[0] + box2[2] / 2
32 | ),
33 | ("top", "bottom"): (
34 | lambda box1, box2: box1[1] + box1[3] / 2 < box2[1] + box2[3] / 2
35 | ),
36 | ("bottom", "top"): (
37 | lambda box1, box2: box1[1] + box1[3] / 2 > box2[1] + box2[3] / 2
38 | ),
39 | }
40 |
41 |
42 | def singular(noun):
43 | singular_noun = p.singular_noun(noun)
44 | if singular_noun is False:
45 | return noun
46 | return singular_noun
47 |
48 |
49 | def get_box(condition: Condition, name_include):
50 | # This prevents substring match on non-word boundaries: carrot vs car
51 | box_match = [
52 | any(
53 | [
54 | (
55 | (name_include_item + " ") in phrase
56 | or phrase.endswith(name_include_item)
57 | )
58 | for name_include_item in name_include
59 | ]
60 | )
61 | for phrase in condition.phrases
62 | ]
63 |
64 | if not any(box_match):
65 | return None
66 |
67 | boxes = condition.boxes
68 | box_ind = np.min(np.where(box_match)[0])
69 |
70 | return boxes[box_ind]
71 |
72 |
73 | def get_box_counts(condition):
74 | if len(condition.boxes) == 0:
75 | # No boxes
76 | return None
77 |
78 | box_counts = None
79 |
80 | for i, box in enumerate(condition.boxes):
81 | if i == 0:
82 | num_frames = len(box)
83 | box_counts = [0 for _ in range(num_frames)]
84 | else:
85 | assert num_frames == len(box), f"{num_frames} != {len(box)}"
86 | valid_frames = box_to_valid_frames(box)
87 |
88 | for frame_index, valid in enumerate(valid_frames):
89 | if valid:
90 | box_counts[frame_index] += 1
91 |
92 | return box_counts
93 |
94 |
95 | def predicate_numeracy(query_names, intended_count, condition, verbose=False):
96 | assert len(query_names) == 1
97 | name_include = query_names
98 | box_match = [
99 | any(
100 | [
101 | (
102 | (name_include_item + " ") in phrase
103 | or phrase.endswith(name_include_item)
104 | )
105 | for name_include_item in name_include
106 | ]
107 | )
108 | for phrase in condition.phrases
109 | ]
110 |
111 | # We do not have tracking in stage 2 evaluation, so let's put this assertion to be safe.
112 |     # This could only be a problem for stage 1 where additional non-relevant boxes are generated, but so far we have not seen additional boxes in stage 1.
113 | assert len(box_match) == len(
114 | condition.boxes
115 | ), "Currently do not support the case where other boxes are also generated"
116 |
117 | box_counts = get_box_counts(condition)
118 |
119 | if box_counts is None:
120 | majority_box_counts = 0
121 | else:
122 | majority_box_counts = np.bincount(box_counts).argmax()
123 |
124 | object_count = majority_box_counts
125 | if verbose:
126 | print(
127 | f"box_counts: {box_counts}, object_count: {object_count}, intended_count: {intended_count} (condition: {condition}, query_names: {query_names})"
128 | )
129 |
130 | success = object_count == intended_count
131 |
132 | return success
133 |
134 |
135 | def box_to_valid_frames(object_box):
136 | object_box = np.array(object_box)
137 | x, y, w, h = object_box[:, 0], object_box[:, 1], object_box[:, 2], object_box[:, 3]
138 | # If the box has 0 width or height, it is not valid.
139 | valid_frames = (w != 0) & (h != 0)
140 |
141 | return valid_frames
142 |
143 |
144 | def predicate_visibility(query_names, test_appearance, condition, verbose=False):
145 |     # condition: a `Condition` (from utils.parse) with parallel `phrases` and per-frame `boxes`
146 |
147 | object_box = get_box(condition, query_names)
148 | if not object_box:
149 | return False
150 |
151 | valid_frames = box_to_valid_frames(object_box)
152 |
153 | num_frames = len(valid_frames)
154 | first_half_index = num_frames // 2
155 |
156 | # Ignore the two frames in the middle since there may be discrepancies between the LLM's understanding of the middle frame and the middle frame after interpolation (in generation) and sampling (in evaluation).
157 | valid_frames_first_half, valid_frames_second_half = (
158 | valid_frames[: first_half_index - 1],
159 | valid_frames[first_half_index + 1 :],
160 | )
161 | present_in_first_half, present_in_second_half = (
162 | any(valid_frames_first_half),
163 | any(valid_frames_second_half),
164 | )
165 |
166 | if test_appearance:
167 | # Test appearing: we ensure the object is not in the first half but needs to be present in the second half.
168 | success = (not present_in_first_half) and present_in_second_half
169 | else:
170 | # Test disappearing: we ensure the object is in the first half but needs to be absent in the second half.
171 | success = present_in_first_half and (not present_in_second_half)
172 |
173 | if verbose:
174 | print(
175 | f"Test appearance: {test_appearance}, valid_frames: {valid_frames}, appeared at first half: {present_in_first_half}, appeared at second half: {present_in_second_half}"
176 | )
177 |
178 | return success
179 |
180 |
181 | def predicate_attribution(
182 | query_names1,
183 | query_names2,
184 | modifier1,
185 | modifier2,
186 | intended_count1,
187 | intended_count2,
188 | condition,
189 | verbose=False,
190 | ):
191 | # Attribution does not use count now
192 | assert intended_count1 == 1 and intended_count2 == 1
193 |
194 | if modifier1:
195 | query_names1 = [f"{modifier1} {item}" for item in query_names1]
196 | object_box1 = get_box(condition, name_include=query_names1)
197 |
198 | if object_box1 is None:
199 | return False
200 |
201 | valid_frames1 = box_to_valid_frames(object_box1)
202 | if valid_frames1.mean() < 0.5:
203 |         # Detected in fewer than half of the frames
204 | return False
205 |
206 | if query_names2 is None:
207 | # Only one object
208 | return True
209 |
210 | if modifier2:
211 | query_names2 = [f"{modifier2} {item}" for item in query_names2]
212 | object_box2 = get_box(condition, name_include=query_names2)
213 |
214 | if object_box2 is None:
215 | return False
216 |
217 | valid_frames2 = box_to_valid_frames(object_box2)
218 | if valid_frames2.mean() < 0.5:
219 |         # Detected in fewer than half of the frames
220 | return False
221 |
222 | if verbose:
223 | print(f"Object box 1: {object_box1}, Object box 2: {object_box2}")
224 |
225 | return True
226 |
227 |
228 | def predicate_1obj_dynamic_spatial(query_names, verify_fn, condition, verbose=False):
229 | object_box = get_box(condition, query_names)
230 | if not object_box:
231 | return False
232 |
233 | valid_frames = box_to_valid_frames(object_box)
234 | if not valid_frames[0] or not valid_frames[-1]:
235 | return False
236 |
237 | # For example, from the left to the right: object in the first frame is on the left compared to the object in the last frame
238 | success = verify_fn(object_box[0], object_box[-1])
239 |
240 | return success
241 |
242 |
243 | def predicate_2obj_dynamic_spatial(
244 | query_names1, query_names2, verify_fn, condition, verbose=False
245 | ):
246 | object_box1 = get_box(condition, query_names1)
247 | object_box2 = get_box(condition, query_names2)
248 |
249 | if verbose:
250 | print(f"object_box1: {object_box1}, object_box2: {object_box2}")
251 |
252 | if not object_box1 or not object_box2:
253 | return False
254 |
255 | valid_frames1 = box_to_valid_frames(object_box1)
256 | valid_frames2 = box_to_valid_frames(object_box2)
257 | if (
258 | not valid_frames1[0]
259 | or not valid_frames2[0]
260 | or not valid_frames1[-1]
261 | or not valid_frames2[-1]
262 | ):
263 | return False
264 |
265 |     # For example, `object 1` moving from the left of `object 2` to the right: object 1 in the first frame is on the left compared to object 2 in the first frame; object 1 in the last frame is on the right compared to object 2 in the last frame
266 | success1 = verify_fn(object_box1[0], object_box2[0])
267 | success2 = verify_fn(object_box2[-1], object_box1[-1])
268 | success = success1 and success2
269 |
270 | return success
271 |
272 |
273 | def predicate_sequentialv2(
274 | query_names, verify_fn1, verify_fn2, verify_fn3, condition, verbose=False
275 | ):
276 |     # condition: a `Condition` (from utils.parse) with parallel `phrases` and per-frame `boxes`
277 |
278 | object_box = get_box(condition, query_names)
279 | if verbose:
280 | print(f"object_box: {object_box}")
281 |
282 | if not object_box:
283 | return False
284 |
285 | valid_frames = box_to_valid_frames(object_box)
286 | if verbose:
287 | print(f"valid_frames: {valid_frames}")
288 |
289 | num_frames = len(valid_frames)
290 | middle_frame_index = num_frames // 2
291 |
292 | # Need to be present in the first, the middle, and the last frame
293 | if (
294 | not valid_frames[0]
295 | or not valid_frames[middle_frame_index]
296 | or not valid_frames[-1]
297 | ):
298 | return False
299 |
300 | # Need to be on the right place in the first, middle, and last frame
301 | success1 = verify_fn1(object_box[0])
302 | success2 = verify_fn2(object_box[middle_frame_index])
303 | success3 = verify_fn3(object_box[-1])
304 |
305 | if verbose:
306 | print(
307 | f"success1: {success1} ({object_box[0]}), success2: {success2} ({object_box[middle_frame_index]}), success3: {success3} ({object_box[-1]})"
308 | )
309 |
310 | success = success1 and success2 and success3
311 | return success
312 |
--------------------------------------------------------------------------------
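
A minimal sketch of the per-frame box convention the predicates above rely on: boxes are `(x, y, w, h)` per frame, and a zero-sized box means the object is absent in that frame (the values are made up):

    from utils.eval.utils import box_to_valid_frames

    object_box = [
        [0, 0, 0, 0],        # absent in frame 0
        [10, 20, 30, 40],    # present in frame 1
        [12, 22, 30, 40],    # present in frame 2
    ]
    print(box_to_valid_frames(object_box))   # [False  True  True]
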
/utils/latents.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from . import utils
4 | from utils import torch_device
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype):
9 | """
10 | in_channels: often obtained with `unet.config.in_channels`
11 | """
12 | # Obtain with torch.float32 and cast to float16 if needed
13 | # Directly obtaining latents in float16 will lead to different latents
14 | latents_base = torch.randn(
15 | (batch_size, in_channels, height // 8, width // 8),
16 | generator=generator,
17 | dtype=dtype,
18 | ).to(torch_device, dtype=dtype)
19 |
20 | return latents_base
21 |
22 |
23 | def get_scaled_latents(
24 | batch_size, in_channels, height, width, generator, dtype, scheduler
25 | ):
26 | latents_base = get_unscaled_latents(
27 | batch_size, in_channels, height, width, generator, dtype
28 | )
29 | latents_base = latents_base * scheduler.init_noise_sigma
30 | return latents_base
31 |
32 |
33 | def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
34 |     """
35 |     Blend `latents_fg` into `latents_bg` within `fg_mask`, mixing in a `fg_blending_ratio` share of the foreground noise.
36 |     """
37 |     assert not torch.allclose(
38 |         latents_bg, latents_fg
39 |     ), "latents_bg should be independent of latents_fg"
40 |
41 | dtype = latents_bg.dtype
42 | latents = (
43 | latents_bg * (1.0 - fg_mask)
44 | + (
45 | latents_bg * np.sqrt(1.0 - fg_blending_ratio)
46 | + latents_fg * np.sqrt(fg_blending_ratio)
47 | )
48 | * fg_mask
49 | )
50 | latents = latents.to(dtype=dtype)
51 |
52 | return latents
53 |
54 |
55 | @torch.no_grad()
56 | def compose_latents(
57 | model_dict,
58 | latents_all_list,
59 | mask_tensor_list,
60 | num_inference_steps,
61 | overall_batch_size,
62 | height,
63 | width,
64 | latents_bg=None,
65 | bg_seed=None,
66 | compose_box_to_bg=True,
67 | use_fast_schedule=False,
68 | fast_after_steps=None,
69 | ):
70 | unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
71 |
72 | if latents_bg is None:
73 | generator = torch.manual_seed(
74 | bg_seed
75 |         )  # Seed the generator to create the initial latent noise
76 | latents_bg = get_scaled_latents(
77 | overall_batch_size,
78 | unet.config.in_channels,
79 | height,
80 | width,
81 | generator,
82 | dtype,
83 | scheduler,
84 | )
85 |
86 | # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
87 | if use_fast_schedule:
88 | # If we use fast schedule, we only compose the frozen steps because the later steps do not match.
89 | composed_latents = torch.zeros(
90 | (fast_after_steps + 1, *latents_bg.shape), dtype=dtype
91 | )
92 | else:
93 | # Otherwise we compose all steps so that we don't need to compose again if we change the frozen steps.
94 | composed_latents = torch.zeros(
95 | (num_inference_steps + 1, *latents_bg.shape), dtype=dtype
96 | )
97 | composed_latents[0] = latents_bg
98 |
99 | foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
100 |
101 | mask_size = np.array([mask_tensor.sum().item() for mask_tensor in mask_tensor_list])
102 | # Compose the largest mask first
103 | mask_order = np.argsort(-mask_size)
104 |
105 | if compose_box_to_bg:
106 | # This has two functionalities:
107 |         # 1. copies the right initial latents from the right place (for centered single-object generation), 2. copies the right initial latents (since we have foreground blending) for centered/original single-object generation.
108 | for mask_idx in mask_order:
109 | latents_all, mask_tensor = (
110 | latents_all_list[mask_idx],
111 | mask_tensor_list[mask_idx],
112 | )
113 |
114 | # Note: need to be careful to not copy from zeros due to shifting.
115 | mask_tensor = utils.binary_mask_to_box_mask(mask_tensor, to_device=False)
116 |
117 | mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
118 | composed_latents[0] = (
119 | composed_latents[0] * (1.0 - mask_tensor_expanded)
120 | + latents_all[0] * mask_tensor_expanded
121 | )
122 |
123 | # This is still needed with `compose_box_to_bg` to ensure the foreground latent is still visible and to compute foreground indices.
124 | for mask_idx in mask_order:
125 | latents_all, mask_tensor = (
126 | latents_all_list[mask_idx],
127 | mask_tensor_list[mask_idx],
128 | )
129 | foreground_indices = (
130 | foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
131 | )
132 | mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
133 | if use_fast_schedule:
134 | composed_latents = (
135 | composed_latents * (1.0 - mask_tensor_expanded)
136 | + latents_all[: fast_after_steps + 1] * mask_tensor_expanded
137 | )
138 | else:
139 | composed_latents = (
140 | composed_latents * (1.0 - mask_tensor_expanded)
141 | + latents_all * mask_tensor_expanded
142 | )
143 |
144 | composed_latents, foreground_indices = (
145 | composed_latents.to(torch_device),
146 | foreground_indices.to(torch_device),
147 | )
148 | return composed_latents, foreground_indices
149 |
150 |
151 | def align_with_bboxes(
152 | latents_all_list, mask_tensor_list, bboxes, horizontal_shift_only=False
153 | ):
154 | """
155 | Each offset in `offset_list` is `(x_offset, y_offset)` (normalized).
156 | """
157 | new_latents_all_list, new_mask_tensor_list, offset_list = [], [], []
158 | for latents_all, mask_tensor, bbox in zip(
159 | latents_all_list, mask_tensor_list, bboxes
160 | ):
161 | x_src_center, y_src_center = utils.binary_mask_to_center(
162 | mask_tensor, normalize=True
163 | )
164 | x_min_dest, y_min_dest, x_max_dest, y_max_dest = bbox
165 | x_dest_center, y_dest_center = (
166 | (x_min_dest + x_max_dest) / 2,
167 | (y_min_dest + y_max_dest) / 2,
168 | )
169 | # print("src (x,y):", x_src_center, y_src_center, "dest (x,y):", x_dest_center, y_dest_center)
170 | x_offset, y_offset = x_dest_center - x_src_center, y_dest_center - y_src_center
171 | if horizontal_shift_only:
172 | y_offset = 0.0
173 | offset = x_offset, y_offset
174 | latents_all = utils.shift_tensor(
175 | latents_all, x_offset, y_offset, offset_normalized=True
176 | )
177 | mask_tensor = utils.shift_tensor(
178 | mask_tensor, x_offset, y_offset, offset_normalized=True
179 | )
180 | new_latents_all_list.append(latents_all)
181 | new_mask_tensor_list.append(mask_tensor)
182 | offset_list.append(offset)
183 |
184 | return new_latents_all_list, new_mask_tensor_list, offset_list
185 |
186 |
187 | @torch.no_grad()
188 | def compose_latents_with_alignment(
189 | model_dict,
190 | latents_all_list,
191 | mask_tensor_list,
192 | num_inference_steps,
193 | overall_batch_size,
194 | height,
195 | width,
196 | align_with_overall_bboxes=True,
197 | overall_bboxes=None,
198 | horizontal_shift_only=False,
199 | **kwargs,
200 | ):
201 | if align_with_overall_bboxes and len(latents_all_list):
202 | expanded_overall_bboxes = utils.expand_overall_bboxes(overall_bboxes)
203 | latents_all_list, mask_tensor_list, offset_list = align_with_bboxes(
204 | latents_all_list,
205 | mask_tensor_list,
206 | bboxes=expanded_overall_bboxes,
207 | horizontal_shift_only=horizontal_shift_only,
208 | )
209 | else:
210 | offset_list = [(0.0, 0.0) for _ in range(len(latents_all_list))]
211 | composed_latents, foreground_indices = compose_latents(
212 | model_dict,
213 | latents_all_list,
214 | mask_tensor_list,
215 | num_inference_steps,
216 | overall_batch_size,
217 | height,
218 | width,
219 | **kwargs,
220 | )
221 | return composed_latents, foreground_indices, offset_list
222 |
223 |
224 | def get_input_latents_list(
225 | model_dict,
226 | bg_seed,
227 | fg_seed_start,
228 | fg_blending_ratio,
229 | height,
230 | width,
231 | so_prompt_phrase_box_list=None,
232 | so_boxes=None,
233 | verbose=False,
234 | ):
235 | """
236 | Note: the returned input latents are scaled by `scheduler.init_noise_sigma`
237 |
238 | fg_seed_start: int or list. If int, `fg_seed = fg_seed_start + idx`. If list, `fg_seed = fg_seed_start[idx]`.
239 | """
240 | unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
241 |
242 | generator_bg = torch.manual_seed(
243 | bg_seed
244 |     )  # Seed the generator to create the initial latent noise
245 | latents_bg = get_unscaled_latents(
246 | batch_size=1,
247 | in_channels=unet.config.in_channels,
248 | height=height,
249 | width=width,
250 | generator=generator_bg,
251 | dtype=dtype,
252 | )
253 |
254 | input_latents_list = []
255 |
256 | if so_boxes is None:
257 | # For compatibility
258 | so_boxes = [item[-1] for item in so_prompt_phrase_box_list]
259 |
260 |     # Changing this changes the foreground initial noise
261 | for idx, obj_box in enumerate(so_boxes):
262 | H, W = height // 8, width // 8
263 | fg_mask = utils.proportion_to_mask(obj_box, H, W)
264 |
265 | if verbose >= 2:
266 | plt.imshow(fg_mask.cpu().numpy())
267 | plt.show()
268 |
269 | if isinstance(fg_seed_start, list):
270 | fg_seed = fg_seed_start[idx]
271 | else:
272 | fg_seed = fg_seed_start + idx
273 | assert (
274 | bg_seed != fg_seed
275 | ), f"Need to use different seeds for bg and fg: fg_seed ({fg_seed}) and bg_seed ({bg_seed})"
276 |
277 | generator_fg = torch.manual_seed(fg_seed)
278 | latents_fg = get_unscaled_latents(
279 | batch_size=1,
280 | in_channels=unet.config.in_channels,
281 | height=height,
282 | width=width,
283 | generator=generator_fg,
284 | dtype=dtype,
285 | )
286 |
287 | fg_blending_ratio_item = (
288 | fg_blending_ratio[idx]
289 | if isinstance(fg_blending_ratio, list)
290 | else fg_blending_ratio
291 | )
292 | input_latents = blend_latents(
293 | latents_bg, latents_fg, fg_mask, fg_blending_ratio=fg_blending_ratio_item
294 | )
295 |
296 | input_latents = input_latents * scheduler.init_noise_sigma
297 |
298 | input_latents_list.append(input_latents)
299 |
300 | latents_bg = latents_bg * scheduler.init_noise_sigma
301 |
302 | return input_latents_list, latents_bg
303 |
--------------------------------------------------------------------------------
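
A minimal sketch of `blend_latents` above (the tensor shapes and box region are illustrative): inside `fg_mask` the result is mostly background noise with a small `fg_blending_ratio` share of foreground noise, which keeps the composed latents consistent with the background outside the box.

    import torch
    from utils.latents import blend_latents

    latents_bg = torch.randn(1, 4, 40, 72)   # (batch, channels, height // 8, width // 8)
    latents_fg = torch.randn(1, 4, 40, 72)
    fg_mask = torch.zeros(40, 72)
    fg_mask[10:30, 20:50] = 1.0              # hypothetical box region

    blended = blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01)
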
/utils/llm.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from prompt import templates, stop, required_lines, required_lines_ast
3 | from easydict import EasyDict
4 | from utils.cache import get_cache, add_cache
5 | import ast
6 | import traceback
7 | import time
8 | import pyjson5
9 |
10 | model_names = [
11 | "vicuna",
12 | "vicuna-13b",
13 | "vicuna-13b-v1.3",
14 | "vicuna-33b-v1.3",
15 | "Llama-2-7b-hf",
16 | "Llama-2-13b-hf",
17 | "Llama-2-70b-hf",
18 | "FreeWilly2",
19 | "gpt-3.5-turbo",
20 | "gpt-3.5",
21 | "gpt-4",
22 | "gpt-4-1106-preview",
23 | ]
24 |
25 |
26 | def get_full_chat_prompt(template, prompt, suffix=None, query_prefix="Caption: "):
27 |     # query_prefix should be "Caption: " to match the prompt format. This was fixed as of template v1.7.2.
28 | if isinstance(template, str):
29 | full_prompt = [
30 | {"role": "system", "content": "You are a helpful assistant."},
31 | {
32 | "role": "user",
33 | "content": get_full_prompt(template, prompt, suffix).strip(),
34 | },
35 | ]
36 | else:
37 | print("**Using chat prompt**")
38 | assert suffix is None
39 | full_prompt = [*template, {"role": "user", "content": query_prefix + prompt}]
40 | return full_prompt
41 |
42 |
43 | def get_full_prompt(template, prompt, suffix=None):
44 | assert isinstance(template, str), "Chat template requires `get_full_chat_prompt`"
45 | full_prompt = template.replace("{prompt}", prompt)
46 | if suffix:
47 | full_prompt = full_prompt.strip() + suffix
48 | return full_prompt
49 |
50 |
51 | def get_full_model_name(model):
52 | if model == "gpt-3.5":
53 | model = "gpt-3.5-turbo"
54 | elif model == "vicuna":
55 | model = "vicuna-13b"
56 | elif model == "gpt-4":
57 | model = "gpt-4"
58 |
59 | return model
60 |
61 |
62 | def get_llm_kwargs(model, template_version):
63 | model = get_full_model_name(model)
64 |
65 | print(f"Using template: {template_version}")
66 |
67 | template = templates[template_version]
68 |
69 | if (
70 | "vicuna" in model.lower()
71 | or "llama" in model.lower()
72 | or "freewilly" in model.lower()
73 | ):
74 | api_base = "http://localhost:8000/v1"
75 | max_tokens = 900
76 | temperature = 0.25
77 | headers = {}
78 | else:
79 | from utils.api_key import api_key
80 |
81 | api_base = "https://api.openai.com/v1"
82 | max_tokens = 900
83 | temperature = 0.25
84 | headers = {"Authorization": f"Bearer {api_key}"}
85 |
86 | llm_kwargs = EasyDict(
87 | model=model,
88 | template=template,
89 | api_base=api_base,
90 | max_tokens=max_tokens,
91 | temperature=temperature,
92 | headers=headers,
93 | stop=stop,
94 | )
95 |
96 | return model, llm_kwargs
97 |
98 |
99 | def get_layout(prompt, llm_kwargs, suffix="", query_prefix="Caption: ", verbose=False):
100 | # No cache in this function
101 | model, template, api_base, max_tokens, temperature, stop, headers = (
102 | llm_kwargs.model,
103 | llm_kwargs.template,
104 | llm_kwargs.api_base,
105 | llm_kwargs.max_tokens,
106 | llm_kwargs.temperature,
107 | llm_kwargs.stop,
108 | llm_kwargs.headers,
109 | )
110 |
111 | if verbose:
112 | print("prompt:", prompt, "with suffix", suffix)
113 |
114 | done = False
115 | attempts = 0
116 | while not done:
117 | if "gpt" in model:
118 | r = requests.post(
119 | f"{api_base}/chat/completions",
120 | json={
121 | "model": model,
122 | "messages": get_full_chat_prompt(
123 | template, prompt, suffix, query_prefix=query_prefix
124 | ),
125 | "max_tokens": max_tokens,
126 | "temperature": temperature,
127 | "stop": stop if isinstance(template, str) else None,
128 | },
129 | headers=headers,
130 | )
131 | else:
132 | r = requests.post(
133 | f"{api_base}/completions",
134 | json={
135 | "model": model,
136 | "prompt": get_full_prompt(template, prompt, suffix).strip(),
137 | "max_tokens": max_tokens,
138 | "temperature": temperature,
139 | "stop": stop,
140 | },
141 | headers=headers,
142 | )
143 |
144 | done = r.status_code == 200
145 |
146 | if not done:
147 | print(r.json())
148 | attempts += 1
149 | if attempts >= 3 and "gpt" in model:
150 | print("Retrying after 1 minute")
151 | time.sleep(60)
152 | if attempts >= 5 and "gpt" in model:
153 | print("Exiting due to many non-successful attempts")
154 | exit()
155 |
156 | if "gpt" in model:
157 | if verbose > 1:
158 | print(f"***{r.json()}***")
159 | response = r.json()["choices"][0]["message"]["content"]
160 | else:
161 | response = r.json()["choices"][0]["text"]
162 |
163 | if verbose:
164 | print("resp", response)
165 |
166 | return response
167 |
168 |
169 | def get_parsed_layout(*args, json_template=False, **kwargs):
170 | if json_template:
171 | return get_parsed_layout_json_resp(*args, **kwargs)
172 | else:
173 | return get_parsed_layout_text_resp(*args, **kwargs)
174 |
175 |
176 | def get_parsed_layout_text_resp(
177 | prompt,
178 | llm_kwargs=None,
179 | max_partial_response_retries=1,
180 | override_response=None,
181 | strip_chars=" \t\n`",
182 | save_leading_text=True,
183 | **kwargs,
184 | ):
185 | """
186 | override_response: override the LLM response (will not query the LLM), useful for parsing existing response
187 | """
188 | if override_response is not None:
189 | assert (
190 | max_partial_response_retries == 1
191 | ), "override_response is specified so no partial queries are allowed"
192 |
193 | process_index = 0
194 | retries = 0
195 | suffix = None
196 | parsed_layout = {}
197 | reconstructed_response = ""
198 | while process_index < len(required_lines):
199 | retries += 1
200 | if retries > max_partial_response_retries:
201 | raise ValueError(
202 | f"Erroring due to many non-successful attempts on prompt: {prompt} with response {response}"
203 | )
204 | if override_response is not None:
205 | response = override_response
206 | else:
207 | response = get_layout(
208 | prompt, llm_kwargs=llm_kwargs, suffix=suffix, **kwargs
209 | )
210 | # print(f"Current response: {response}")
211 | if required_lines[process_index] in response:
212 | response_split = response.split(required_lines[process_index])
213 |
214 | # print(f"Unused leading text: {response_split[0]}")
215 |
216 | if save_leading_text:
217 | reconstructed_response += (
218 | response_split[0] + required_lines[process_index]
219 | )
220 | response = response_split[1]
221 |
222 | while process_index < len(required_lines):
223 | required_line = required_lines[process_index]
224 | next_required_line = (
225 | required_lines[process_index + 1]
226 | if process_index + 1 < len(required_lines)
227 | else ""
228 | )
229 | if next_required_line in response:
230 | if next_required_line != "":
231 | required_line_idx = response.find(next_required_line)
232 | line_content = response[:required_line_idx].strip(strip_chars)
233 | else:
234 | line_content = response.strip(strip_chars)
235 | if required_lines_ast[process_index]:
236 | # LLMs sometimes give comments starting with " - "
237 | line_content = line_content.split(" - ")[0].strip()
238 |
239 | # LLMs sometimes give a list: `- name: content`
240 | if line_content.startswith("-"):
241 | line_content = line_content[
242 | line_content.find("-") + 1 :
243 | ].strip()
244 |
245 | try:
246 | line_content = ast.literal_eval(line_content)
247 | except SyntaxError as e:
248 | print(
249 | f"Encountered SyntaxError with content {line_content}: {e}"
250 | )
251 | raise e
252 | parsed_layout[required_line.rstrip(":")] = line_content
253 | reconstructed_response += response[
254 | : required_line_idx + len(next_required_line)
255 | ]
256 | response = response[required_line_idx + len(next_required_line) :]
257 | process_index += 1
258 | else:
259 | break
260 | if process_index == 0:
261 | # Nothing matches, retry without adding suffix
262 | continue
263 | elif process_index < len(required_lines):
264 | # Partial matches, add the match to the suffix and retry
265 | suffix = (
266 | "\n"
267 | + response.rstrip(strip_chars)
268 | + "\n"
269 | + required_lines[process_index]
270 | )
271 |
272 | parsed_layout["Prompt"] = prompt
273 |
274 | return parsed_layout, reconstructed_response
275 |
276 |
277 | def get_parsed_layout_json_resp(
278 | prompt,
279 | llm_kwargs=None,
280 | max_partial_response_retries=1,
281 | override_response=None,
282 | strip_chars=" \t\n`",
283 | save_leading_text=True,
284 | **kwargs,
285 | ):
286 | """
287 |     override_response: override the LLM response (will not query the LLM); useful for parsing an existing response
288 |     save_leading_text: ignored since we do not allow leading text in JSON
289 |     max_partial_response_retries: ignored since we do not allow partial responses in JSON
290 | """
291 | assert (
292 | max_partial_response_retries == 1
293 | ), "no partial queries are allowed in with JSON format templates"
294 | if override_response is not None:
295 | response = override_response
296 | else:
297 | response = get_layout(prompt, llm_kwargs=llm_kwargs, suffix=None, **kwargs)
298 |
299 | response = response.strip(strip_chars)
300 |
301 |     # Alternatively, we could use `str.removeprefix` (Python 3.9+).
302 | response = (
303 | response[len("Response:") :] if response.startswith("Response:") else response
304 | )
305 |
306 | response = response.strip(strip_chars)
307 |
308 | # print("Response:", response)
309 |
310 | try:
311 | parsed_layout = pyjson5.loads(response)
312 | except (
313 | ValueError,
314 | pyjson5.Json5Exception,
315 | pyjson5.Json5EOF,
316 | pyjson5.Json5DecoderException,
317 | pyjson5.Json5IllegalCharacter,
318 | ) as e:
319 | print(
320 | f"Encountered exception in parsing the response with content {response}: {e}"
321 | )
322 | raise e
323 |
324 | reconstructed_response = response
325 |
326 | parsed_layout["Prompt"] = prompt
327 |
328 | return parsed_layout, reconstructed_response
329 |
330 |
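
The reason for `pyjson5` rather than the standard `json` module is that JSON5 tolerates the comments, single quotes, and trailing commas that LLMs frequently emit. A small sketch with illustrative values only:

```python
import pyjson5

llm_style_output = """
{
  // LLM-written comment
  'Background keyword': 'beach',
  'Frame 1': [{'id': 0, 'name': 'a dog', 'box': [100, 280, 150, 120]},],
}
"""
layout = pyjson5.loads(llm_style_output)  # json.loads would raise on this input
print(layout["Frame 1"][0]["box"])        # [100, 280, 150, 120]
```
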
331 | def get_parsed_layout_with_cache(
332 | prompt,
333 | llm_kwargs,
334 | verbose=False,
335 | max_retries=3,
336 | cache_miss_allowed=True,
337 | json_template=False,
338 | **kwargs,
339 | ):
340 | """
341 |     Get parsed_layout with cache support. This function only adds to the cache after all lines have been obtained and parsed. If a partial response is obtained, the LLM is queried again with the previous partial response as a suffix.
342 | """
343 |     # Note that the cache path needs to be set correctly: `get_cache` does not check whether the cache was generated with the given model in the given setting.
344 |
345 | response = get_cache(prompt)
346 |
347 | if response is not None:
348 | print(f"Cache hit: {prompt}")
349 | parsed_layout, _ = get_parsed_layout(
350 | prompt,
351 | llm_kwargs=llm_kwargs,
352 | max_partial_response_retries=1,
353 | override_response=response,
354 | json_template=json_template,
355 | )
356 | return parsed_layout
357 |
358 | print(f"Cache miss: {prompt}")
359 |
360 | assert cache_miss_allowed, "Cache miss is not allowed"
361 |
362 | done = False
363 | retries = 0
364 | while not done:
365 | retries += 1
366 | if retries >= max_retries:
367 | raise ValueError(
368 | f"Erroring due to many non-successful attempts on prompt: {prompt}"
369 | )
370 | try:
371 | parsed_layout, reconstructed_response = get_parsed_layout(
372 | prompt, llm_kwargs=llm_kwargs, json_template=json_template, **kwargs
373 | )
374 | except Exception as e:
375 | print(f"Error: {e}, retrying")
376 | traceback.print_exc()
377 | continue
378 |
379 | done = True
380 |
381 | add_cache(prompt, reconstructed_response)
382 |
383 | if verbose:
384 | print(f"parsed_layout = {parsed_layout}")
385 |
386 | return parsed_layout
387 |
--------------------------------------------------------------------------------
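
A rough usage sketch of the cache-backed entry point above. It assumes the cache in `utils/cache.py` has already been pointed at one of the files under `cache/` (that setup lives in the cache module and is not shown here), and that `llm_kwargs` is whatever dictionary `get_layout` expects for the chosen model:

```python
# from utils.llm import get_parsed_layout_with_cache  # assumed import path

prompt = "A bear walking from the left to the right of a river"
parsed_layout = get_parsed_layout_with_cache(
    prompt,
    llm_kwargs=llm_kwargs,     # placeholder: built elsewhere for the chosen model
    verbose=True,
    cache_miss_allowed=False,  # fail fast if the prompt is not already in the cache
    json_template=False,       # parse the line-by-line text template
)
```
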
/utils/parse.py:
--------------------------------------------------------------------------------
1 | from matplotlib.patches import Polygon
2 | from matplotlib.collections import PatchCollection
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import os
6 | from . import guidance
7 | from collections import namedtuple
8 | import imageio
9 | import shutil
10 |
11 | Condition = namedtuple(
12 | "Condition", ["prompt", "boxes", "phrases", "object_positions", "token_map"]
13 | )
14 |
15 | img_dir = "imgs"
16 |
17 | # h, w used in the layouts
18 | size = (512, 512)
19 | size_h, size_w = size
20 | # print(f"Using box scale: {size}")
21 |
22 |
23 | def draw_boxes(condition, frame_index=None):
24 | boxes, phrases = condition.boxes, condition.phrases
25 |
26 | ax = plt.gca()
27 | ax.set_autoscale_on(False)
28 | polygons = []
29 | color = []
30 | for box_ind, (box, name) in enumerate(zip(boxes, phrases)):
31 | if isinstance(box, dict):
32 | if frame_index not in box:
33 | continue
34 | else:
35 |             if frame_index is not None and frame_index >= len(box):
36 | continue
37 |
38 | box = box[frame_index] if frame_index is not None else box
39 | # Each phrase may be a list to allow different phrase per timestep
40 | name = (
41 | name[frame_index]
42 | if frame_index is not None and isinstance(name, (dict, list, tuple))
43 | else name
44 | )
45 |
46 | # This ensures different frames have the same box color.
47 | rng = np.random.default_rng(box_ind)
48 | c = rng.random((1, 3)) * 0.6 + 0.4
49 | [bbox_x, bbox_y, bbox_x_max, bbox_y_max] = box
50 | if bbox_x_max <= bbox_x or bbox_y_max <= bbox_y:
51 |             # Skip all-zero (degenerate) boxes: the object is absent in this frame
52 | continue
53 | bbox_x, bbox_y, bbox_x_max, bbox_y_max = (
54 | bbox_x * size_w,
55 | bbox_y * size_h,
56 | bbox_x_max * size_w,
57 | bbox_y_max * size_h,
58 | )
59 | poly = [
60 | [bbox_x, bbox_y],
61 | [bbox_x, bbox_y_max],
62 | [bbox_x_max, bbox_y_max],
63 | [bbox_x_max, bbox_y],
64 | ]
65 | np_poly = np.array(poly).reshape((4, 2))
66 | polygons.append(Polygon(np_poly))
67 | color.append(c)
68 |
69 | # print(ann)
70 | ax.text(
71 | bbox_x,
72 | bbox_y,
73 | name,
74 | style="italic",
75 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
76 | )
77 |
78 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
79 | ax.add_collection(p)
80 |
81 |
82 | def show_boxes(
83 | condition, frame_index=None, ind=None, show=True, show_prompt=True, save=False
84 | ):
85 | """
86 | This draws the boxes in `frame_index`.
87 | """
88 | boxes, phrases = condition.boxes, condition.phrases
89 |
90 | if len(boxes) == 0:
91 | return
92 |
93 | # White background (to allow line to show on the edge)
94 | I = np.ones((size[0] + 4, size[1] + 4, 3), dtype=np.uint8) * 255
95 |
96 | plt.imshow(I)
97 | plt.axis("off")
98 |
99 | bg_prompt = getattr(condition, "prompt", None)
100 | neg_prompt = getattr(condition, "neg_prompt", None)
101 |
102 | ax = plt.gca()
103 | if show_prompt and bg_prompt is not None:
104 | ax.text(
105 | 0,
106 | 0,
107 | bg_prompt + f"(Neg: {neg_prompt})" if neg_prompt else bg_prompt,
108 | style="italic",
109 | bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
110 | )
111 | c = np.zeros((1, 3))
112 | [bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0])
113 | poly = [
114 | [bbox_x, bbox_y],
115 | [bbox_x, bbox_y + bbox_h],
116 | [bbox_x + bbox_w, bbox_y + bbox_h],
117 | [bbox_x + bbox_w, bbox_y],
118 | ]
119 | np_poly = np.array(poly).reshape((4, 2))
120 | polygons = [Polygon(np_poly)]
121 | color = [c]
122 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
123 | ax.add_collection(p)
124 |
125 | draw_boxes(condition, frame_index=frame_index)
126 | if show:
127 | plt.show()
128 |
129 | if save:
130 | print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
131 | plt.savefig(f"{img_dir}/boxes.png")
132 | if ind is not None:
133 | shutil.copy(f"{img_dir}/boxes.png", f"{img_dir}/boxes_{ind}.png")
134 |
135 |
136 | def show_video_boxes(
137 | condition,
138 | figsize=(4, 4),
139 | ind=None,
140 | show=False,
141 | save=False,
142 | save_each_frame=False,
143 | fps=8,
144 | save_name="boxes",
145 | **kwargs,
146 | ):
147 | boxes, phrases = condition.boxes, condition.phrases
148 |
149 | assert len(boxes) == len(phrases), f"{len(boxes)} != {len(phrases)}"
150 |
151 | if len(boxes) == 0:
152 | return
153 |
154 | num_frames = len(boxes[0])
155 |
156 | boxes_frames = []
157 |
158 | for frame_index in range(num_frames):
159 | fig = plt.figure(figsize=figsize)
160 | # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
161 | show_boxes(condition, frame_index=frame_index, show=False, save=False, **kwargs)
162 | # If we haven't already shown or saved the plot, then we need to
163 | # draw the figure first...
164 | fig.canvas.draw()
165 |
166 | # Now we can save it to a numpy array.
167 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
168 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
169 | plt.close()
170 |
171 | boxes_frames.append(data)
172 |
173 | if show:
174 | video = imageio.mimsave(
175 | imageio.RETURN_BYTES,
176 | boxes_frames,
177 | format="gif",
178 | loop=0,
179 | duration=1000 * 1 / fps,
180 | )
181 | from IPython.display import display, Image as IPyImage
182 |
183 | display(IPyImage(data=video, format="gif"))
184 |
185 | if save:
186 | imageio.mimsave(
187 | f"{img_dir}/{save_name}.gif",
188 | boxes_frames,
189 | format="gif",
190 | loop=0,
191 | duration=1000 * 1 / fps,
192 | )
193 | if ind is not None:
194 | shutil.copy(
195 | f"{img_dir}/{save_name}.gif", f"{img_dir}/{save_name}_{ind}.gif"
196 | )
197 | print(f'Saved to "{img_dir}/{save_name}.gif"', f"ind: {ind}")
198 |
199 | if save_each_frame:
200 | os.makedirs(f"{img_dir}/{save_name}", exist_ok=True)
201 | for frame_ind, frame in enumerate(boxes_frames):
202 | imageio.imsave(
203 | f"{img_dir}/{save_name}/{frame_ind}.png", frame, format="png"
204 | )
205 | print(f'Saved frames to "{img_dir}/{save_name}"', f"ind: {ind}")
206 |
207 |
208 | def show_masks(masks):
209 | masks_to_show = np.zeros((*size, 3), dtype=np.float32)
210 | for mask in masks:
211 | c = np.random.random((3,)) * 0.6 + 0.4
212 |
213 | masks_to_show += mask[..., None] * c[None, None, :]
214 | plt.imshow(masks_to_show)
215 | plt.savefig(f"{img_dir}/masks.png")
216 | plt.show()
217 | plt.close()
218 |
219 |
220 | def convert_box(box, height, width):
221 |     # box: x, y, w, h (in the layout canvas scale, e.g., 512) -> normalized x_min, y_min, x_max, y_max
222 | x_min, y_min = box[0] / width, box[1] / height
223 | w_box, h_box = box[2] / width, box[3] / height
224 |
225 | x_max, y_max = x_min + w_box, y_min + h_box
226 |
227 | return x_min, y_min, x_max, y_max
228 |
229 |
230 | def interpolate_box(box, num_input_frames=6, num_output_frames=24, repeat=1):
231 | output_boxes = np.zeros((num_output_frames, 4))
232 | box_time_indices = np.sort(list(box.keys()))
233 | xs = np.concatenate(
234 | [box_time_indices / (num_input_frames - 1) + i for i in range(repeat)]
235 | )
236 |     # Subtract a small epsilon so the last query point does not wrap around under the modulus below.
237 | xs_query = np.linspace(0, repeat - 1e-5, num_output_frames)
238 | mask = np.isin(np.floor((xs_query % 1.0) * num_input_frames), box_time_indices)
239 |
240 | # 4: x_min, y_min, x_max, y_max
241 | for i in range(4):
242 | ys = np.array(
243 | [box[box_time_index][i] for box_time_index in box_time_indices] * repeat
244 | )
245 |         # If the mask is False (the object does not exist at this timestep), the box is all zeros
246 | output_boxes[:, i] = np.interp(xs_query, xs, ys) * mask
247 |
248 | return output_boxes.tolist()
249 |
250 |
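
A small worked example for the two helpers above (import path assumed; keyframe values are illustrative):

```python
# from utils.parse import convert_box, interpolate_box  # assumed import path

# x, y, w, h in the 512x512 layout canvas -> normalized x_min, y_min, x_max, y_max
print(convert_box([128, 128, 256, 128], height=512, width=512))
# (0.25, 0.25, 0.75, 0.5)

# Six keyframe boxes (keys are layout frame indices) that slide to the right.
keyframes = {i: [0.1 + 0.1 * i, 0.4, 0.3 + 0.1 * i, 0.6] for i in range(6)}
dense = interpolate_box(keyframes, num_input_frames=6, num_output_frames=24)
print(len(dense))           # 24 densified boxes
print(dense[0], dense[-1])  # endpoints match frames 0 and 5 (up to the 1e-5 offset)
```
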
251 | def parsed_layout_to_condition(
252 | parsed_layout,
253 | height,
254 | width,
255 | num_parsed_layout_frames=6,
256 | num_condition_frames=24,
257 | interpolate_boxes=True,
258 | tokenizer=None,
259 | output_phrase_per_timestep=False,
260 | add_background_to_prompt=True,
261 | strip_phrases=False,
262 | verbose=False,
263 | ):
264 | """
265 | Infer condition from parsed layout.
266 | Boxes can appear or disappear.
267 | """
268 |
269 | prompt = parsed_layout["Prompt"]
270 |
271 | if add_background_to_prompt and parsed_layout["Background keyword"]:
272 | prompt += f", {parsed_layout['Background keyword']} background"
273 |
274 | id_to_phrase, id_to_box = {}, {}
275 |
276 | box_ids = []
277 |
278 | for frame_ind in range(num_parsed_layout_frames):
279 | object_dicts = parsed_layout[f"Frame {frame_ind + 1}"]
280 | for object_dict in object_dicts:
281 | current_box_id = object_dict["id"]
282 | if current_box_id not in id_to_phrase:
283 | if output_phrase_per_timestep:
284 |                 # Phrases are stored per frame; when False (else branch), only the phrase at the first occurrence is used
285 | id_to_phrase[current_box_id] = {}
286 | else:
287 | id_to_phrase[current_box_id] = (
288 | object_dict["name"]
289 | if "name" in object_dict
290 | else object_dict["keyword"]
291 | )
292 |
293 | # Use `dict` to handle appearance and disappearance of objects
294 | id_to_box[current_box_id] = {}
295 |
296 | box_ids.append(current_box_id)
297 |
298 | box = object_dict["box"]
299 | converted_box = convert_box(box, height=height, width=width)
300 | id_to_box[current_box_id][frame_ind] = converted_box
301 |
302 | if output_phrase_per_timestep:
303 | id_to_phrase[current_box_id][frame_ind] = (
304 | object_dict["name"]
305 | if "name" in object_dict
306 | else object_dict["keyword"]
307 | )
308 |
309 | boxes = [id_to_box[box_id] for box_id in box_ids]
310 | phrases = [id_to_phrase[box_id] for box_id in box_ids]
311 |
312 | if verbose:
313 | boxes_before_interpolation = boxes
314 |
315 | # Frames in interpolated boxes are consecutive, but some boxes may have all coordinates as 0 to indicate disappearance
316 | if interpolate_boxes:
317 | assert (
318 | not output_phrase_per_timestep
319 | ), "box interpolation with phrase per timestep is not implemented"
320 | boxes = [
321 | interpolate_box(
322 | box,
323 | num_parsed_layout_frames,
324 | num_condition_frames,
325 | repeat=parsed_layout.get("Repeat", 1),
326 | )
327 | for box in boxes
328 | ]
329 |
330 | if tokenizer is not None:
331 | for phrase in phrases:
332 | found, _ = guidance.refine_phrase(prompt, phrase, verbose=True)
333 |
334 | if not found:
335 | # Suffix the prompt with object name (before the refinement) for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
336 | prompt += "| " + phrase
337 |
338 | print(f'**Adding {phrase} to the prompt. Using prompt: "{prompt}"')
339 |
340 |         # `phrases` might not correspond to the first occurrence in the prompt; this case is not currently handled.
341 | token_map = guidance.get_token_map(
342 | tokenizer, prompt=prompt, verbose=verbose, padding="do_not_pad"
343 | )
344 | object_positions = guidance.get_phrase_indices(
345 | tokenizer, prompt, phrases, token_map=token_map, verbose=verbose
346 | )
347 | else:
348 | token_map = None
349 | object_positions = None
350 |
351 | if verbose:
352 | print("prompt:", prompt)
353 | print("boxes (before interpolation):", boxes_before_interpolation)
354 | if verbose >= 2:
355 | print("boxes (after interpolation):", np.round(np.array(boxes), 2))
356 | print("phrases:", phrases)
357 | if object_positions is not None:
358 | print("object_positions:", object_positions)
359 |
360 | if strip_phrases:
361 | phrases = [phrase.strip("1234567890 ") for phrase in phrases]
362 |
363 | return Condition(prompt, boxes, phrases, object_positions, token_map)
364 |
--------------------------------------------------------------------------------
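
To tie the pieces of `parse.py` together, here is a hand-written layout in the shape `parsed_layout_to_condition` expects; the real schema comes from the LLM templates in `prompt.py`, and the import path and values below are assumptions:

```python
# from utils.parse import parsed_layout_to_condition  # assumed import path

parsed_layout = {
    "Prompt": "A cat running from the left to the right",
    "Background keyword": "grassland",
}
for i in range(6):
    # x, y, w, h in the 512-scale layout canvas
    parsed_layout[f"Frame {i + 1}"] = [{"id": 0, "name": "a cat", "box": [40 + 70 * i, 300, 120, 100]}]

condition = parsed_layout_to_condition(
    parsed_layout,
    height=512,
    width=512,
    num_parsed_layout_frames=6,
    num_condition_frames=24,
    tokenizer=None,  # skip token-map / phrase-index computation
)
print(condition.prompt)         # "..., grassland background" is appended
print(len(condition.boxes[0]))  # 24 interpolated boxes for the cat
```
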
/utils/schedule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 |
4 |
5 | def get_fast_schedule(original_timesteps, fast_after_steps, fast_rate):
6 |     if fast_after_steps >= len(original_timesteps) - 1:
7 |         return original_timesteps
8 |     new_timesteps = torch.cat(
9 |         (
10 |             original_timesteps[:fast_after_steps],
11 |             original_timesteps[fast_after_steps + 1 :: fast_rate],
12 | ),
13 | dim=0,
14 | )
15 | return new_timesteps
16 |
17 |
18 | def dynamically_adjust_inference_steps(scheduler, index, t):
19 | prev_t = (
20 | scheduler.timesteps[index + 1] if index + 1 < len(scheduler.timesteps) else -1
21 | )
22 | scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t)
23 | if index + 1 < len(scheduler.timesteps):
24 | if (
25 | scheduler.config.num_train_timesteps // scheduler.num_inference_steps
26 | != t - prev_t
27 | ):
28 | warnings.warn(
29 | f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate"
30 | )
31 | else:
32 |         # As long as we hit the final cumulative product (alphas_cumprod), it should be fine.
33 | if (
34 | scheduler.config.num_train_timesteps // scheduler.num_inference_steps
35 | > t - prev_t
36 | ):
37 | warnings.warn(
38 | f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate"
39 | )
40 |
--------------------------------------------------------------------------------
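
For `get_fast_schedule` above, a quick sketch of what the truncation does (import path assumed):

```python
import torch
# from utils.schedule import get_fast_schedule  # assumed import path

timesteps = torch.arange(900, -1, -100)  # tensor([900, 800, ..., 100, 0])
fast = get_fast_schedule(timesteps, fast_after_steps=4, fast_rate=2)
print(fast)  # tensor([900, 800, 700, 600, 400, 200, 0]): dense early steps, every 2nd step afterwards
```
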
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import ImageDraw
3 | import numpy as np
4 | import os
5 | import gc
6 | import math
7 | from typing import List
8 | import cv2
9 | import skvideo.io
10 |
11 | torch_device = "cuda"
12 |
13 |
14 | def draw_box(pil_img, bboxes, phrases, ignore_all_zeros=True):
15 | W, H = pil_img.size
16 | draw = ImageDraw.Draw(pil_img)
17 |
18 | for obj_bbox, phrase in zip(bboxes, phrases):
19 | x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3]
20 | if ignore_all_zeros and x_0 == 0 and y_0 == 0 and x_1 == 0 and y_1 == 0:
21 | continue
22 | draw.rectangle(
23 | [int(x_0 * W), int(y_0 * H), int(x_1 * W), int(y_1 * H)],
24 | outline="red",
25 | width=5,
26 | )
27 | draw.text(
28 | (int(x_0 * W) + 5, int(y_0 * H) + 5), phrase, font=None, fill=(255, 0, 0)
29 | )
30 |
31 | return pil_img
32 |
33 |
34 | def get_centered_box(
35 | box,
36 | horizontal_center_only=True,
37 | vertical_placement="centered",
38 | vertical_center=0.5,
39 | floor_padding=None,
40 | ):
41 | x_min, y_min, x_max, y_max = box
42 | w = x_max - x_min
43 |
44 | x_min_new = 0.5 - w / 2
45 | x_max_new = 0.5 + w / 2
46 |
47 | if horizontal_center_only:
48 | return [x_min_new, y_min, x_max_new, y_max]
49 |
50 | h = y_max - y_min
51 |
52 | if vertical_placement == "centered":
53 | assert (
54 | floor_padding is None
55 | ), "Set vertical_placement to floor_padding to use floor padding"
56 |
57 | y_min_new = vertical_center - h / 2
58 | y_max_new = vertical_center + h / 2
59 | elif vertical_placement == "floor_padding":
60 | # Ignores `vertical_center`
61 |
62 | y_max_new = 1 - floor_padding
63 | y_min_new = y_max_new - h
64 | else:
65 | raise ValueError(f"Unknown vertical placement: {vertical_placement}")
66 |
67 | return [x_min_new, y_min_new, x_max_new, y_max_new]
68 |
69 |
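
A quick illustration of `get_centered_box` (import path assumed; outputs shown rounded for readability):

```python
# from utils.utils import get_centered_box  # assumed import path

box = [0.1, 0.5, 0.3, 0.9]  # normalized x_min, y_min, x_max, y_max

# Horizontal recentering only: the 0.2 width is kept, the x range becomes [0.4, 0.6].
print(get_centered_box(box))  # ~[0.4, 0.5, 0.6, 0.9]

# Also place the box 0.05 above the bottom edge.
print(get_centered_box(box, horizontal_center_only=False,
                       vertical_placement="floor_padding", floor_padding=0.05))
# ~[0.4, 0.55, 0.6, 0.95]
```
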
70 | # NOTE: this changes the behavior of the function
71 | def proportion_to_mask(obj_box, H, W, use_legacy=False, return_np=False):
72 | x_min, y_min, x_max, y_max = scale_proportion(obj_box, H, W, use_legacy)
73 | if return_np:
74 | mask = np.zeros((H, W))
75 | else:
76 | mask = torch.zeros(H, W).to(torch_device)
77 | mask[y_min:y_max, x_min:x_max] = 1.0
78 |
79 | return mask
80 |
81 |
82 | def scale_proportion(obj_box, H, W, use_legacy=False):
83 | if use_legacy:
84 | # Bias towards the top-left corner
85 | x_min, y_min, x_max, y_max = (
86 | int(obj_box[0] * W),
87 | int(obj_box[1] * H),
88 | int(obj_box[2] * W),
89 | int(obj_box[3] * H),
90 | )
91 | else:
92 |         # Round box_w and box_h separately so box sizes are shift-invariant. Otherwise, the box size may change when both coordinates end in ".5" before rounding.
93 | x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)
94 | box_w, box_h = (
95 | round((obj_box[2] - obj_box[0]) * W),
96 | round((obj_box[3] - obj_box[1]) * H),
97 | )
98 | x_max, y_max = x_min + box_w, y_min + box_h
99 |
100 | x_min, y_min = max(x_min, 0), max(y_min, 0)
101 | x_max, y_max = min(x_max, W), min(y_max, H)
102 |
103 | return x_min, y_min, x_max, y_max
104 |
105 |
106 | def binary_mask_to_box(mask, enlarge_box_by_one=True, w_scale=1, h_scale=1):
107 | if isinstance(mask, torch.Tensor):
108 | mask_loc = torch.where(mask)
109 | else:
110 | mask_loc = np.where(mask)
111 | height, width = mask.shape
112 |     if len(mask_loc[0]) == 0:
113 | raise ValueError("The mask is empty")
114 | if enlarge_box_by_one:
115 | ymin, ymax = max(min(mask_loc[0]) - 1, 0), min(max(mask_loc[0]) + 1, height)
116 | xmin, xmax = max(min(mask_loc[1]) - 1, 0), min(max(mask_loc[1]) + 1, width)
117 | else:
118 | ymin, ymax = min(mask_loc[0]), max(mask_loc[0])
119 | xmin, xmax = min(mask_loc[1]), max(mask_loc[1])
120 | box = [xmin * w_scale, ymin * h_scale, xmax * w_scale, ymax * h_scale]
121 |
122 | return box
123 |
124 |
125 | def binary_mask_to_box_mask(mask, to_device=True):
126 | box = binary_mask_to_box(mask)
127 | x_min, y_min, x_max, y_max = box
128 |
129 | H, W = mask.shape
130 | mask = torch.zeros(H, W)
131 | if to_device:
132 | mask = mask.to(torch_device)
133 | mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0
134 |
135 | return mask
136 |
137 |
138 | def binary_mask_to_center(mask, normalize=False):
139 | """
140 | This computes the mass center of the mask.
141 | normalize: the coords range from 0 to 1
142 |
143 | Reference: https://stackoverflow.com/a/66184125
144 | """
145 | h, w = mask.shape
146 |
147 | total = mask.sum()
148 | if isinstance(mask, torch.Tensor):
149 | x_coord = ((mask.sum(dim=0) @ torch.arange(w)) / total).item()
150 | y_coord = ((mask.sum(dim=1) @ torch.arange(h)) / total).item()
151 | else:
152 | x_coord = (mask.sum(axis=0) @ np.arange(w)) / total
153 | y_coord = (mask.sum(axis=1) @ np.arange(h)) / total
154 |
155 | if normalize:
156 | x_coord, y_coord = x_coord / w, y_coord / h
157 | return x_coord, y_coord
158 |
159 |
160 | def iou(mask, masks, eps=1e-6):
161 | # mask: [h, w], masks: [n, h, w]
162 | mask = mask[None].astype(bool)
163 | masks = masks.astype(bool)
164 | i = (mask & masks).sum(axis=(1, 2))
165 | u = (mask | masks).sum(axis=(1, 2))
166 |
167 | return i / (u + eps)
168 |
169 |
170 | def free_memory():
171 | gc.collect()
172 | torch.cuda.empty_cache()
173 |
174 |
175 | def expand_overall_bboxes(overall_bboxes):
176 | """
177 | Expand overall bboxes from a 3d list to 2d list:
178 | Input: [[box 1 for phrase 1, box 2 for phrase 1], ...]
179 | Output: [box 1, box 2, ...]
180 | """
181 | return sum(overall_bboxes, start=[])
182 |
183 |
184 | def shift_tensor(
185 | tensor,
186 | x_offset,
187 | y_offset,
188 | base_w=8,
189 | base_h=8,
190 | offset_normalized=False,
191 | ignore_last_dim=False,
192 | ):
193 | """base_w and base_h: make sure the shift is aligned in the latent and multiple levels of cross attention"""
194 | if ignore_last_dim:
195 | tensor_h, tensor_w = tensor.shape[-3:-1]
196 | else:
197 | tensor_h, tensor_w = tensor.shape[-2:]
198 | if offset_normalized:
199 | assert (
200 | tensor_h % base_h == 0 and tensor_w % base_w == 0
201 | ), f"{tensor_h, tensor_w} is not a multiple of {base_h, base_w}"
202 | scale_from_base_h, scale_from_base_w = tensor_h // base_h, tensor_w // base_w
203 | x_offset, y_offset = (
204 | round(x_offset * base_w) * scale_from_base_w,
205 | round(y_offset * base_h) * scale_from_base_h,
206 | )
207 | new_tensor = torch.zeros_like(tensor)
208 |
209 | overlap_w = tensor_w - abs(x_offset)
210 | overlap_h = tensor_h - abs(y_offset)
211 |
212 | if y_offset >= 0:
213 | y_src_start = 0
214 | y_dest_start = y_offset
215 | else:
216 | y_src_start = -y_offset
217 | y_dest_start = 0
218 |
219 | if x_offset >= 0:
220 | x_src_start = 0
221 | x_dest_start = x_offset
222 | else:
223 | x_src_start = -x_offset
224 | x_dest_start = 0
225 |
226 | if ignore_last_dim:
227 | # For cross attention maps, the third to last and the second to last are the 2D dimensions after unflatten.
228 | new_tensor[
229 | ...,
230 | y_dest_start : y_dest_start + overlap_h,
231 | x_dest_start : x_dest_start + overlap_w,
232 | :,
233 | ] = tensor[
234 | ...,
235 | y_src_start : y_src_start + overlap_h,
236 | x_src_start : x_src_start + overlap_w,
237 | :,
238 | ]
239 | else:
240 | new_tensor[
241 | ...,
242 | y_dest_start : y_dest_start + overlap_h,
243 | x_dest_start : x_dest_start + overlap_w,
244 | ] = tensor[
245 | ...,
246 | y_src_start : y_src_start + overlap_h,
247 | x_src_start : x_src_start + overlap_w,
248 | ]
249 |
250 | return new_tensor
251 |
252 |
253 | def get_hw_from_attn_dim(attn_dim, base_attn_dim):
254 | # base_attn_dim: (40, 72) for zeroscope (width 576, height 320)
255 | scale = int(math.sqrt((base_attn_dim[0] * base_attn_dim[1]) / attn_dim))
256 | return base_attn_dim[0] // scale, base_attn_dim[1] // scale
257 |
258 |
259 | # Reference: https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/utils/testing_utils.py#L400
260 | def export_to_video(
261 | video_frames: List[np.ndarray],
262 | output_video_path: str,
263 | fps: int = 8,
264 | fourcc: str = "mp4v",
265 | use_opencv=False,
266 | crf=17,
267 | ) -> str:
268 | if use_opencv:
269 | # This requires a cv2 installation that has video encoder support.
270 |
271 | fourcc = cv2.VideoWriter_fourcc(*fourcc)
272 | h, w, c = video_frames[0].shape
273 | video_writer = cv2.VideoWriter(
274 | output_video_path, fourcc, fps=fps, frameSize=(w, h)
275 | )
276 | for i in range(len(video_frames)):
277 | img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
278 | video_writer.write(img)
279 | else:
280 | skvideo.io.vwrite(
281 | output_video_path,
282 | video_frames,
283 | inputdict={"-framerate": str(fps)},
284 | outputdict={"-vcodec": "libx264", "-pix_fmt": "yuv420p", "-crf": str(crf)},
285 | )
286 | return output_video_path
287 |
288 |
289 | def multiline_input(prompt, return_on_empty_lines=True):
290 | # Adapted from https://stackoverflow.com/questions/30239092/how-to-get-multiline-input-from-the-user
291 |
292 | print(prompt, end="", flush=True)
293 | contents = ""
294 | while True:
295 | try:
296 | line = input()
297 | except EOFError:
298 | break
299 | if line == "" and return_on_empty_lines:
300 | break
301 | contents += line + "\n"
302 |
303 | return contents
304 |
305 |
306 | def find_gen_dir(gen_name, create_dir=True):
307 | base_save_dir = f"img_generations/{gen_name}"
308 | run_ind = 0
309 |
310 | while True:
311 | gen_dir = f"{base_save_dir}/run{run_ind}"
312 | if not os.path.exists(gen_dir):
313 | break
314 | run_ind += 1
315 |
316 | print(f"Save results at {gen_dir}")
317 | if create_dir:
318 | os.makedirs(gen_dir, exist_ok=False)
319 |
320 | return gen_dir
321 |
--------------------------------------------------------------------------------
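
A minimal sketch for `export_to_video` (import path assumed; the default path shells out to ffmpeg via scikit-video, so ffmpeg must be installed):

```python
import numpy as np
# from utils.utils import export_to_video  # assumed import path

# 24 dummy RGB frames at zeroscope's 576x320 resolution.
frames = [np.full((320, 576, 3), i * 10, dtype=np.uint8) for i in range(24)]

export_to_video(frames, "demo.mp4", fps=8)  # scikit-video + libx264
# export_to_video(frames, "demo.mp4", fps=8, use_opencv=True)  # needs a cv2 build with encoder support
```
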
/utils/vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import utils
4 | from . import parse
5 | import imageio
6 | import joblib
7 |
8 | save_ind = 0  # index appended to filenames saved by `save_image` below
9 | def visualize(image, title, colorbar=False, show_plot=True, **kwargs):
10 | plt.title(title)
11 | plt.imshow(image, **kwargs)
12 | if colorbar:
13 | plt.colorbar()
14 | if show_plot:
15 | plt.show()
16 |
17 |
18 | def visualize_arrays(
19 | image_title_pairs,
20 | colorbar_index=-1,
21 | show_plot=True,
22 | figsize=None,
23 | no_axis=False,
24 | **kwargs,
25 | ):
26 | if figsize is not None:
27 | plt.figure(figsize=figsize)
28 | num_subplots = len(image_title_pairs)
29 | for idx, image_title_pair in enumerate(image_title_pairs):
30 | plt.subplot(1, num_subplots, idx + 1)
31 | if isinstance(image_title_pair, (list, tuple)):
32 | image, title = image_title_pair
33 | else:
34 | image, title = image_title_pair, None
35 |
36 | if title is not None:
37 | plt.title(title)
38 |
39 | plt.imshow(image, **kwargs)
40 | if no_axis:
41 | plt.axis("off")
42 | if idx == colorbar_index:
43 | plt.colorbar()
44 |
45 | if show_plot:
46 | plt.show()
47 |
48 |
49 | def visualize_masked_latents(
50 | latents_all, masked_latents, timestep_T=False, timestep_0=True
51 | ):
52 | if timestep_T:
53 | # from T to 0
54 | latent_idx = 0
55 |
56 | plt.subplot(1, 2, 1)
57 | plt.title("latents_all (t=T)")
58 | plt.imshow(
59 | (
60 | latents_all[latent_idx, 0, :3]
61 | .cpu()
62 | .permute(1, 2, 0)
63 | .numpy()
64 | .astype(float)
65 | / 1.5
66 | ).clip(0.0, 1.0),
67 | cmap="gray",
68 | )
69 |
70 | plt.subplot(1, 2, 2)
71 | plt.title("mask latents (t=T)")
72 | plt.imshow(
73 | (
74 | masked_latents[latent_idx, 0, :3]
75 | .cpu()
76 | .permute(1, 2, 0)
77 | .numpy()
78 | .astype(float)
79 | / 1.5
80 | ).clip(0.0, 1.0),
81 | cmap="gray",
82 | )
83 |
84 | plt.show()
85 |
86 | if timestep_0:
87 | latent_idx = -1
88 | plt.subplot(1, 2, 1)
89 | plt.title("latents_all (t=0)")
90 | plt.imshow(
91 | (
92 | latents_all[latent_idx, 0, :3]
93 | .cpu()
94 | .permute(1, 2, 0)
95 | .numpy()
96 | .astype(float)
97 | / 1.5
98 | ).clip(0.0, 1.0),
99 | cmap="gray",
100 | )
101 |
102 | plt.subplot(1, 2, 2)
103 | plt.title("mask latents (t=0)")
104 | plt.imshow(
105 | (
106 | masked_latents[latent_idx, 0, :3]
107 | .cpu()
108 | .permute(1, 2, 0)
109 | .numpy()
110 | .astype(float)
111 | / 1.5
112 | ).clip(0.0, 1.0),
113 | cmap="gray",
114 | )
115 |
116 | plt.show()
117 |
118 |
119 | def visualize_bboxes(bboxes, H, W):
120 | num_boxes = len(bboxes)
121 | for ind, bbox in enumerate(bboxes):
122 | plt.subplot(1, num_boxes, ind + 1)
123 | fg_mask = utils.proportion_to_mask(bbox, H, W)
124 | plt.title(f"transformed bbox ({ind})")
125 | plt.imshow(fg_mask.cpu().numpy())
126 | plt.show()
127 |
128 |
129 | def save_image(image, save_prefix="", ind=None):
130 | global save_ind
131 | if save_prefix != "":
132 | save_prefix = save_prefix + "_"
133 | ind = f"{ind}_" if ind is not None else ""
134 | path = f"{parse.img_dir}/{save_prefix}{ind}{save_ind}.png"
135 |
136 | print(f"Saved to {path}")
137 |
138 | image.save(path)
139 | save_ind = save_ind + 1
140 |
141 |
142 | def save_frames(path, frames, formats="gif", fps=8):
143 | if isinstance(formats, (list, tuple)):
144 | for format in formats:
145 | save_frames(path, frames, format, fps)
146 | return
147 |
148 | if formats == "gif":
149 | imageio.mimsave(
150 | f"{path}.gif", frames, format="gif", loop=0, duration=1000 * 1 / fps
151 | )
152 | elif formats == "mp4":
153 | utils.export_to_video(
154 | video_frames=frames, output_video_path=f"{path}.mp4", fps=fps
155 | )
156 | elif formats == "npz":
157 | np.savez_compressed(f"{path}.npz", frames)
158 | elif formats == "joblib":
159 | joblib.dump(frames, f"{path}.joblib", compress=("bz2", 3))
160 | else:
161 | raise ValueError(f"Unknown format: {formats}")
162 |
--------------------------------------------------------------------------------
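
Finally, a usage sketch for `save_frames` (import path assumed; the output directory must already exist):

```python
import numpy as np
# from utils import vis  # assumed import path

frames = [np.zeros((320, 576, 3), dtype=np.uint8) for _ in range(24)]
# Writes outputs/demo.gif and outputs/demo.mp4 (the latter goes through utils.export_to_video).
vis.save_frames("outputs/demo", frames, formats=("gif", "mp4"), fps=8)
```
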