├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── advance │ ├── backyard-7_0.jpg │ ├── backyard-7_1.jpg │ ├── backyard-7_2.jpg │ ├── backyard-7_3.jpg │ ├── backyard-7_4.jpg │ ├── backyard-7_5.jpg │ ├── backyard-7_6.jpg │ ├── blue-car.jpg │ ├── garden-4_0.jpg │ ├── garden-4_1.jpg │ ├── garden-4_2.jpg │ ├── garden-4_3.jpg │ ├── telebooth-2_0.jpg │ ├── telebooth-2_1.jpg │ ├── vgg-lab-4_0.png │ ├── vgg-lab-4_1.png │ ├── vgg-lab-4_2.png │ └── vgg-lab-4_3.png ├── basic │ ├── blue-car.jpg │ ├── hilly-countryside.jpg │ ├── lily-dragon.png │ ├── llff-room.jpg │ ├── mountain-lake.jpg │ ├── vasedeck.jpg │ └── vgg-lab-4_0.png ├── benchmark.png └── spiral.gif ├── benchmark ├── README.md └── export_reconfusion_example.py ├── demo.py ├── demo_gr.py ├── docs ├── CLI_USAGE.md ├── GR_USAGE.md └── INSTALL.md ├── pyproject.toml └── seva ├── __init__.py ├── data_io.py ├── eval.py ├── geometry.py ├── gui.py ├── model.py ├── modules ├── __init__.py ├── autoencoder.py ├── conditioner.py ├── layers.py ├── preprocessor.py └── transformer.py ├── sampling.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .envrc 2 | .venv/ 3 | .gradio/ 4 | work_dirs* 5 | 6 | # Byte-compiled files 7 | __pycache__/ 8 | *.py[cod] 9 | 10 | # Virtual environments 11 | env/ 12 | venv/ 13 | ENV/ 14 | .VENV/ 15 | 16 | # Distribution files 17 | build/ 18 | dist/ 19 | *.egg-info/ 20 | 21 | # Logs and temporary files 22 | *.log 23 | *.tmp 24 | *.bak 25 | *.swp 26 | 27 | # IDE files 28 | .idea/ 29 | .vscode/ 30 | *.sublime-workspace 31 | *.sublime-project 32 | 33 | # OS files 34 | .DS_Store 35 | Thumbs.db 36 | 37 | # Testing and coverage 38 | htmlcov/ 39 | .coverage 40 | *.cover 41 | *.py,cover 42 | .cache/ 43 | 44 | # Jupyter Notebook checkpoints 45 | .ipynb_checkpoints/ 46 | 47 | # Pre-commit hooks 48 | .pre-commit-config.yaml~ 49 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/dust3r"] 2 | path = third_party/dust3r 3 | url = https://github.com/jensenstability/dust3r 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | default_stages: [pre-commit] 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - repo: https://github.com/charliermarsh/ruff-pre-commit 11 | rev: v0.8.3 12 | hooks: 13 | - id: ruff 14 | types_or: [python, pyi, jupyter] 15 | args: [--fix, --extend-ignore=E402] 16 | - id: ruff-format 17 | types_or: [python, pyi, jupyter] 18 | - repo: https://github.com/pre-commit/mirrors-prettier 19 | rev: v3.1.0 20 | hooks: 21 | - id: prettier 22 | types_or: [markdown] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Stability AI Non-Commercial License Agreement 2 | Last Updated: February 20, 2025 3 | 4 | I. 
INTRODUCTION 5 | 6 | This Stability AI Non-Commercial License Agreement (the “Agreement”) applies to any individual person or entity 7 | (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or 8 | Derivative Works thereof for any Research & Non-Commercial use. Capitalized terms not otherwise defined herein 9 | are defined in Section IV below. 10 | 11 | This Agreement is intended to allow research and non-commercial uses of the Model free of charge. 12 | 13 | By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials 14 | or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. 15 | 16 | If You are acting on behalf of a company, organization, or other entity, then “You” includes you and that entity, 17 | and You agree that You: 18 | (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and 19 | (ii) You agree to the terms of this Agreement on that entity’s behalf. 20 | 21 | --- 22 | 23 | II. RESEARCH & NON-COMMERCIAL USE LICENSE 24 | 25 | Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, 26 | non-sublicensable, revocable, and royalty-free limited license under Stability AI’s intellectual property or other 27 | rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create 28 | Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. 29 | 30 | - **“Research Purpose”** means academic or scientific advancement, and in each case, is not primarily intended 31 | for commercial advantage or monetary compensation to You or others. 32 | - **“Non-Commercial Purpose”** means any purpose other than a Research Purpose that is not primarily intended 33 | for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) 34 | or evaluation and testing. 35 | 36 | --- 37 | 38 | III. GENERAL TERMS 39 | 40 | Your Research or Non-Commercial license under this Agreement is subject to the following terms. 41 | 42 | ### a. Distribution & Attribution 43 | If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product 44 | or service that uses any portion of them, You shall: 45 | 1. Provide a copy of this Agreement to that third party. 46 | 2. Retain the following attribution notice within a **"Notice"** text file distributed as a part of such copies: 47 | 48 | **"This Stability AI Model is licensed under the Stability AI Non-Commercial License, 49 | Copyright © Stability AI Ltd. All Rights Reserved."** 50 | 51 | 3. Prominently display **“Powered by Stability AI”** on a related website, user interface, blog post, 52 | about page, or product documentation. 53 | 4. If You create a Derivative Work, You may add your own attribution notice(s) to the **"Notice"** text file 54 | included with that Derivative Work, provided that You clearly indicate which attributions apply to the 55 | Stability AI Materials and state in the **"Notice"** text file that You changed the Stability AI Materials 56 | and how it was modified. 57 | 58 | ### b. 
Use Restrictions 59 | Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability 60 | AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control 61 | Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby 62 | incorporated by reference. 63 | 64 | Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the 65 | Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model 66 | (excluding the Model or Derivative Works). 67 | 68 | ### c. Intellectual Property 69 | 70 | #### (i) Trademark License 71 | No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials 72 | or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of 73 | its Affiliates, except as required under Section IV(a) herein. 74 | 75 | #### (ii) Ownership of Derivative Works 76 | As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s 77 | ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. 78 | 79 | #### (iii) Ownership of Outputs 80 | As between You and Stability AI, You own any outputs generated from the Model or Derivative Works to the extent 81 | permitted by applicable law. 82 | 83 | #### (iv) Disputes 84 | If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works, or 86 | associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual 87 | property or other rights owned or licensable by You, then any licenses granted to You under this Agreement 88 | shall terminate as of the date such litigation or claim is filed or instituted. 89 | 90 | You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out 91 | of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of 92 | this Agreement. 93 | 94 | #### (v) Feedback 95 | From time to time, You may provide Stability AI with verbal and/or written suggestions, comments, or other 96 | feedback related to Stability AI’s existing or prospective technology, products, or services (collectively, 97 | “Feedback”). 98 | 99 | You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant 100 | Stability AI a **perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, 101 | worldwide right and license** to exploit the Feedback in any manner without restriction. 102 | 103 | Your Feedback is provided **“AS IS”** and You make no warranties whatsoever about any Feedback. 104 | 105 | --- 106 | 107 | IV. DEFINITIONS 108 | 109 | - **“Affiliate(s)”** means any entity that directly or indirectly controls, is controlled by, or is under common 110 | control with the subject entity. For purposes of this definition, “control” means direct or indirect ownership 111 | or control of more than 50% of the voting interests of the subject entity. 112 | - **“AUP”** means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may 113 | be updated from time to time. 
114 | - **"Derivative Work(s)"** means: 115 | (a) Any derivative work of the Stability AI Materials as recognized by U.S. copyright laws. 116 | (b) Any modifications to a Model, and any other model created which is based on or derived from the Model or 117 | the Model’s output, including **fine-tune** and **low-rank adaptation** models derived from a Model or 118 | a Model’s output, but does not include the output of any Model. 119 | - **“Model”** means Stability AI’s Stable Virtual Camera model. 120 | - **"Stability AI" or "we"** means Stability AI Ltd. and its Affiliates. 121 | - **"Software"** means Stability AI’s proprietary software made available under this Agreement now or in the future. 122 | - **“Stability AI Materials”** means, collectively, Stability’s proprietary Model, Software, and Documentation 123 | (and any portion or combination thereof) made available under this Agreement. 124 | - **“Trade Control Laws”** means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations. 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Stable Virtual Camera: Generative View Synthesis with Diffusion Models

[Jensen (Jinghao) Zhou](https://shallowtoil.github.io/)\*, [Hang Gao](https://hangg7.com/)\*
[Vikram Voleti](https://voletiv.github.io/), [Aaryaman Vasishta](https://www.aaryaman.net/), [Chun-Han Yao](https://chhankyao.github.io/), [Mark Boss](https://markboss.me/)
[Philip Torr](https://eng.ox.ac.uk/people/philip-torr/), [Christian Rupprecht](https://chrirupp.github.io/), [Varun Jampani](https://varunjampani.github.io/)

[Stability AI](https://stability.ai/), [University of Oxford](https://www.robots.ox.ac.uk/~vgg/), [UC Berkeley](https://bair.berkeley.edu/)

(Figure: Teaser)

(Figure: teaser_page1)
29 | 30 | # Overview 31 | 32 | `Stable Virtual Camera (SEVA)` is a generalist diffusion model for Novel View Synthesis (NVS), generating 3D consistent novel views of a scene, given any number of input views and target cameras. 33 | 34 | # :tada: News 35 | 36 | - June 2025 - Release v`1.1` model checkpoint. 37 | - March 2025 - `Stable Virtual Camera` is out everywhere. 38 | 39 | # :gear: Versions 40 | 41 | | Model Version | \#Parameter | Resolution | Download Link | Update Notes | 42 | | :-----------: | :---------: | :--------: | :--------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------: | 43 | | `1.1` | 1.3B | 576P | 🤗 [Huggingface](https://huggingface.co/stabilityai/stable-virtual-camera/blob/main/modelv1.1.safetensors) | Fixing known issues of foreground objects sometimes being detached from the background in v`1.0` | 44 | | `1.0` | 1.3B | 576P | 🤗 [Huggingface](https://huggingface.co/stabilityai/stable-virtual-camera/blob/main/model.safetensors) | Initial release | 45 | 46 | You can specify the version via, for example, `load_model(..., model_version=1.1)` in the script. 47 | 48 | # :wrench: Installation 49 | 50 | ```bash 51 | git clone --recursive https://github.com/Stability-AI/stable-virtual-camera 52 | cd stable-virtual-camera 53 | pip install -e . 54 | ``` 55 | 56 | Please note that you will need `python>=3.10` and `torch>=2.6.0`. 57 | 58 | Check [INSTALL.md](docs/INSTALL.md) for other dependencies if you want to use our demos or develop from this repo. 59 | For windows users, please use WSL as flash attention isn't supported on native Windows [yet](https://github.com/pytorch/pytorch/issues/108175). 60 | 61 | # :open_book: Usage 62 | 63 | You need to properly authenticate with Hugging Face to download our model weights. Once set up, our code will handle it automatically at your first run. You can authenticate by running 64 | 65 | ```bash 66 | # This will prompt you to enter your Hugging Face credentials. 67 | huggingface-cli login 68 | ``` 69 | 70 | Once authenticated, go to our model card [here](https://huggingface.co/stabilityai/stable-virtual-camera) and enter your information for access. 71 | 72 | We provide two demos for you to interact with `Stable Virtual Camera`. 73 | 74 | ### :rocket: Gradio demo 75 | 76 | This gradio demo is a GUI interface that requires no expert knowledge, suitable for general users. Simply run 77 | 78 | ```bash 79 | python demo_gr.py 80 | ``` 81 | 82 | For a more detailed guide, follow [GR_USAGE.md](docs/GR_USAGE.md). 83 | 84 | ### :computer: CLI demo 85 | 86 | This cli demo allows you to pass in more options and control the model in a fine-grained way, suitable for power users and academic researchers. An example command line looks as simple as 87 | 88 | ```bash 89 | python demo.py --data_path [additional arguments] 90 | ``` 91 | 92 | For a more detailed guide, follow [CLI_USAGE.md](docs/CLI_USAGE.md). 93 | 94 | For users interested in benchmarking NVS models using command lines, check [`benchmark`](benchmark/) containing the details about scenes, splits, and input/target views we reported in the paper. 95 | 96 | # :question: Q&A 97 | 98 | - Training script? See issue https://github.com/Stability-AI/stable-virtual-camera/issues/27, https://github.com/Stability-AI/stable-virtual-camera/issues/42. 
[@nviolante25](https://www.github.com/nviolante25) has made a pull request (https://github.com/Stability-AI/stable-virtual-camera/pull/51) based on the dicussions. 99 | - License for the output? See issue https://github.com/Stability-AI/stable-virtual-camera/issues/26. The output follows the same non-commercial license. 100 | 101 | # :books: Citing 102 | 103 | If you find this repository useful, please consider giving a star :star: and citation. 104 | 105 | ``` 106 | @article{zhou2025stable, 107 | title={Stable Virtual Camera: Generative View Synthesis with Diffusion Models}, 108 | author={Jensen (Jinghao) Zhou and Hang Gao and Vikram Voleti and Aaryaman Vasishta and Chun-Han Yao and Mark Boss and 109 | Philip Torr and Christian Rupprecht and Varun Jampani 110 | }, 111 | journal={arXiv preprint arXiv:2503.14489}, 112 | year={2025} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /assets/advance/backyard-7_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_0.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_1.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_2.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_3.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_4.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_5.jpg -------------------------------------------------------------------------------- /assets/advance/backyard-7_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/backyard-7_6.jpg -------------------------------------------------------------------------------- /assets/advance/blue-car.jpg: -------------------------------------------------------------------------------- 1 | ../basic/blue-car.jpg -------------------------------------------------------------------------------- 
/assets/advance/garden-4_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/garden-4_0.jpg -------------------------------------------------------------------------------- /assets/advance/garden-4_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/garden-4_1.jpg -------------------------------------------------------------------------------- /assets/advance/garden-4_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/garden-4_2.jpg -------------------------------------------------------------------------------- /assets/advance/garden-4_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/garden-4_3.jpg -------------------------------------------------------------------------------- /assets/advance/telebooth-2_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/telebooth-2_0.jpg -------------------------------------------------------------------------------- /assets/advance/telebooth-2_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/telebooth-2_1.jpg -------------------------------------------------------------------------------- /assets/advance/vgg-lab-4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/vgg-lab-4_0.png -------------------------------------------------------------------------------- /assets/advance/vgg-lab-4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/vgg-lab-4_1.png -------------------------------------------------------------------------------- /assets/advance/vgg-lab-4_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/vgg-lab-4_2.png -------------------------------------------------------------------------------- /assets/advance/vgg-lab-4_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/advance/vgg-lab-4_3.png -------------------------------------------------------------------------------- /assets/basic/blue-car.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/blue-car.jpg -------------------------------------------------------------------------------- /assets/basic/hilly-countryside.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/hilly-countryside.jpg -------------------------------------------------------------------------------- /assets/basic/lily-dragon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/lily-dragon.png -------------------------------------------------------------------------------- /assets/basic/llff-room.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/llff-room.jpg -------------------------------------------------------------------------------- /assets/basic/mountain-lake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/mountain-lake.jpg -------------------------------------------------------------------------------- /assets/basic/vasedeck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/basic/vasedeck.jpg -------------------------------------------------------------------------------- /assets/basic/vgg-lab-4_0.png: -------------------------------------------------------------------------------- 1 | ../advance/vgg-lab-4_0.png -------------------------------------------------------------------------------- /assets/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/benchmark.png -------------------------------------------------------------------------------- /assets/spiral.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/assets/spiral.gif -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | # :bar_chart: Benchmark 2 | 3 | We provide in this release (`benchmark.zip`) with the following 17 entries as a benchmark to evaluate NVS models. 4 | We hope this will help standardize the evaluation of NVS models and facilitate fair comparison between different methods. 
| Dataset | Split | Path | Content | Image Preprocessing | Image Postprocessing |
| :--- | :--- | :--- | :--- | :--- | :--- |
| OmniObject3D | S (SV3D), O (Ours) | `omniobject3d` | `train_test_split_*.json` | center crop to 576 | \ |
| GSO | S (SV3D), O (Ours) | `gso` | `train_test_split_*.json` | center crop to 576 | \ |
| RealEstate10K | D (4DiM) | `re10k-4dim` | `train_test_split_*.json` | center crop to 576 | resize to 256 |
| | R (ReconFusion) | `re10k` | `train_test_split_*.json` | center crop to 576 | \ |
| | P (pixelSplat) | `re10k-pixelsplat` | `train_test_split_*.json` | center crop to 576 | resize to 256 |
| | V (ViewCrafter) | `re10k-viewcrafter` | `images/*.png`, `transforms.json`, `train_test_split_*.json` | resize the shortest side to 576 (`--L_short 576`) | center crop |
| LLFF | R (ReconFusion) | `llff` | `train_test_split_*.json` | center crop to 576 | \ |
| DTU | R (ReconFusion) | `dtu` | `train_test_split_*.json` | center crop to 576 | \ |
| CO3D | R (ReconFusion) | `co3d` | `train_test_split_*.json` | center crop to 576 | \ |
| | V (ViewCrafter) | `co3d-viewcrafter` | `images/*.png`, `transforms.json`, `train_test_split_*.json` | resize the shortest side to 576 (`--L_short 576`) | center crop |
| WildRGB-D | Oₑ (Ours, easy) | `wildgbd/easy` | `train_test_split_*.json` | center crop to 576 | \ |
| | Oₕ (Ours, hard) | `wildgbd/hard` | `train_test_split_*.json` | center crop to 576 | \ |
| Mip-NeRF360 | R (ReconFusion) | `mipnerf360` | `train_test_split_*.json` | center crop to 576 | \ |
| DL3DV-140 | O (Ours) | `dl3dv10` | `train_test_split_*.json` | center crop to 576 | \ |
| | L (Long-LRM) | `dl3dv140` | `train_test_split_*.json` | center crop to 576 | \ |
| Tanks and Temples | V (ViewCrafter) | `tnt-viewcrafter` | `images/*.png`, `transforms.json`, `train_test_split_*.json` | resize the shortest side to 576 (`--L_short 576`) | center crop |
| | L (Long-LRM) | `tnt-longlrm` | `train_test_split_*.json` | center crop to 576 | \ |
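Each `train_test_split_*.json` above simply pairs the indices of the input (train) and target (test) views of a scene, in the format written out by `benchmark/export_reconfusion_example.py` (`{"train_ids": ..., "test_ids": ...}`). As a sketch only, with made-up indices, a `train_test_split_3.json` would look like:

```json
{
  "train_ids": [2, 17, 40],
  "test_ids": [0, 5, 11, 23, 31, 38, 44]
}
```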
149 | 150 | - For entries without `images/*.png` and `transforms.json`, we use the images from the original dataset after converting them into the `reconfusion` format, which is then parsable by `ReconfusionParser` (`seva/data_io.py`). 151 | Please note that during this conversion, you should sort the images by `sorted(image_paths)`, which is then directly indexable by our train/test ids. We provide in `benchmark/export_reconfusion_example.py` an example script converting an existing academic dataset into the the scene folders. 152 | - For evaluation and benchmarking, we first conduct operations in the `Image Preprocessing` column to the model input and then operations in the `Image Postprocessing` column to the model output. The final processed samples are used for metric computation. 153 | 154 | ## Acknowledgment 155 | 156 | We would like to thank Wangbo Yu, Aleksander Hołyński, Saurabh Saxena, and Ziwen Chen for their kind clarification on experiment settings. 157 | -------------------------------------------------------------------------------- /benchmark/export_reconfusion_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | try: 9 | from sklearn.cluster import KMeans # type: ignore[import] 10 | except ImportError: 11 | print("Please install sklearn to use this script.") 12 | exit(1) 13 | 14 | # Define the folder containing the image and JSON files 15 | subfolder = "/path/to/your/dataset" 16 | output_file = os.path.join(subfolder, "transforms.json") 17 | 18 | # List to hold the frames 19 | frames = [] 20 | 21 | # Iterate over the files in the folder 22 | for file in sorted(os.listdir(subfolder)): 23 | if file.endswith(".json"): 24 | # Read the JSON file containing camera extrinsics and intrinsics 25 | json_path = os.path.join(subfolder, file) 26 | with open(json_path, "r") as f: 27 | data = json.load(f) 28 | 29 | # Read the corresponding image file 30 | image_file = file.replace(".json", ".png") 31 | image_path = os.path.join(subfolder, image_file) 32 | if not os.path.exists(image_path): 33 | print(f"Image file not found for {file}, skipping...") 34 | continue 35 | with Image.open(image_path) as img: 36 | w, h = img.size 37 | 38 | # Extract and normalize intrinsic matrix K 39 | K = data["K"] 40 | fx = K[0][0] * w 41 | fy = K[1][1] * h 42 | cx = K[0][2] * w 43 | cy = K[1][2] * h 44 | 45 | # Extract the transformation matrix 46 | transform_matrix = np.array(data["c2w"]) 47 | # Adjust for OpenGL convention 48 | transform_matrix[..., [1, 2]] *= -1 49 | 50 | # Add the frame data 51 | frames.append( 52 | { 53 | "fl_x": fx, 54 | "fl_y": fy, 55 | "cx": cx, 56 | "cy": cy, 57 | "w": w, 58 | "h": h, 59 | "file_path": f"./{os.path.relpath(image_path, subfolder)}", 60 | "transform_matrix": transform_matrix.tolist(), 61 | } 62 | ) 63 | 64 | # Create the output dictionary 65 | transforms_data = {"orientation_override": "none", "frames": frames} 66 | 67 | # Write to the transforms.json file 68 | with open(output_file, "w") as f: 69 | json.dump(transforms_data, f, indent=4) 70 | 71 | print(f"transforms.json generated at {output_file}") 72 | 73 | 74 | # Train-test split function using K-means clustering with stride 75 | def create_train_test_split(frames, n, output_path, stride): 76 | # Prepare the data for K-means 77 | positions = [] 78 | for frame in frames: 79 | transform_matrix = np.array(frame["transform_matrix"]) 80 | position = 
transform_matrix[:3, 3] # 3D camera position 81 | direction = transform_matrix[:3, 2] / np.linalg.norm( 82 | transform_matrix[:3, 2] 83 | ) # Normalized 3D direction 84 | positions.append(np.concatenate([position, direction])) 85 | 86 | positions = np.array(positions) 87 | 88 | # Apply K-means clustering 89 | kmeans = KMeans(n_clusters=n, random_state=42) 90 | kmeans.fit(positions) 91 | centers = kmeans.cluster_centers_ 92 | 93 | # Find the index closest to each cluster center 94 | train_ids = [] 95 | for center in centers: 96 | distances = np.linalg.norm(positions - center, axis=1) 97 | train_ids.append(int(np.argmin(distances))) # Convert to Python int 98 | 99 | # Remaining indices as test_ids, applying stride 100 | all_indices = set(range(len(frames))) 101 | remaining_indices = sorted(all_indices - set(train_ids)) 102 | test_ids = [ 103 | int(idx) for idx in remaining_indices[::stride] 104 | ] # Convert to Python int 105 | 106 | # Create the split data 107 | split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids} 108 | 109 | with open(output_path, "w") as f: 110 | json.dump(split_data, f, indent=4) 111 | 112 | print(f"Train-test split file generated at {output_path}") 113 | 114 | 115 | # Parse arguments 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser( 118 | description="Generate train-test split JSON file using K-means clustering." 119 | ) 120 | parser.add_argument( 121 | "--n", 122 | type=int, 123 | required=True, 124 | help="Number of frames to include in the training set.", 125 | ) 126 | parser.add_argument( 127 | "--stride", 128 | type=int, 129 | default=1, 130 | help="Stride for selecting test frames (not used with K-means).", 131 | ) 132 | 133 | args = parser.parse_args() 134 | 135 | # Create train-test split 136 | train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json") 137 | create_train_test_split(frames, args.n, train_test_split_path, args.stride) 138 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | import fire 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from seva.data_io import get_parser 13 | from seva.eval import ( 14 | IS_TORCH_NIGHTLY, 15 | compute_relative_inds, 16 | create_transforms_simple, 17 | infer_prior_inds, 18 | infer_prior_stats, 19 | run_one_scene, 20 | ) 21 | from seva.geometry import ( 22 | generate_interpolated_path, 23 | generate_spiral_path, 24 | get_arc_horizontal_w2cs, 25 | get_default_intrinsics, 26 | get_lookat, 27 | get_preset_pose_fov, 28 | ) 29 | from seva.model import SGMWrapper 30 | from seva.modules.autoencoder import AutoEncoder 31 | from seva.modules.conditioner import CLIPConditioner 32 | from seva.sampling import DiscreteDenoiser 33 | from seva.utils import load_model 34 | 35 | device = "cuda:0" 36 | 37 | 38 | # Constants. 
39 | WORK_DIR = "work_dirs/demo" 40 | 41 | if IS_TORCH_NIGHTLY: 42 | COMPILE = True 43 | os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1" 44 | os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" 45 | else: 46 | COMPILE = False 47 | 48 | AE = AutoEncoder(chunk_size=1).to(device) 49 | CONDITIONER = CLIPConditioner().to(device) 50 | DENOISER = DiscreteDenoiser(num_idx=1000, device=device) 51 | 52 | if COMPILE: 53 | CONDITIONER = torch.compile(CONDITIONER, dynamic=False) 54 | AE = torch.compile(AE, dynamic=False) 55 | 56 | 57 | def parse_task( 58 | task, 59 | scene, 60 | num_inputs, 61 | T, 62 | version_dict, 63 | ): 64 | options = version_dict["options"] 65 | 66 | anchor_indices = None 67 | anchor_c2ws = None 68 | anchor_Ks = None 69 | 70 | if task == "img2trajvid_s-prob": 71 | if num_inputs is not None: 72 | assert ( 73 | num_inputs == 1 74 | ), "Task `img2trajvid_s-prob` only support 1-view conditioning..." 75 | else: 76 | num_inputs = 1 77 | num_targets = options.get("num_targets", T - 1) 78 | num_anchors = infer_prior_stats( 79 | T, 80 | num_inputs, 81 | num_total_frames=num_targets, 82 | version_dict=version_dict, 83 | ) 84 | 85 | input_indices = [0] 86 | anchor_indices = np.linspace(1, num_targets, num_anchors).tolist() 87 | 88 | all_imgs_path = [scene] + [None] * num_targets 89 | 90 | c2ws, fovs = get_preset_pose_fov( 91 | option=options.get("traj_prior", "orbit"), 92 | num_frames=num_targets + 1, 93 | start_w2c=torch.eye(4), 94 | look_at=torch.Tensor([0, 0, 10]), 95 | ) 96 | 97 | with Image.open(scene) as img: 98 | W, H = img.size 99 | aspect_ratio = W / H 100 | Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio) # unormalized 101 | Ks[:, :2] *= ( 102 | torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1) 103 | ) # normalized 104 | Ks = Ks.numpy() 105 | 106 | anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]] 107 | anchor_Ks = Ks[[round(ind) for ind in anchor_indices]] 108 | 109 | else: 110 | parser = get_parser( 111 | parser_type="reconfusion", 112 | data_dir=scene, 113 | normalize=False, 114 | ) 115 | all_imgs_path = parser.image_paths 116 | c2ws = parser.camtoworlds 117 | camera_ids = parser.camera_ids 118 | Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0) 119 | 120 | if num_inputs is None: 121 | assert len(parser.splits_per_num_input_frames.keys()) == 1 122 | num_inputs = list(parser.splits_per_num_input_frames.keys())[0] 123 | split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore 124 | elif isinstance(num_inputs, str): 125 | split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore 126 | num_inputs = int(num_inputs.split("-")[0]) # for example 1_from32 127 | else: 128 | split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore 129 | 130 | num_targets = len(split_dict["test_ids"]) 131 | 132 | if task == "img2img": 133 | # Note in this setting, we should refrain from using all the other camera 134 | # info except ones from sampled_indices, and most importantly, the order. 
135 | num_anchors = infer_prior_stats( 136 | T, 137 | num_inputs, 138 | num_total_frames=num_targets, 139 | version_dict=version_dict, 140 | ) 141 | 142 | sampled_indices = np.sort( 143 | np.array(split_dict["train_ids"] + split_dict["test_ids"]) 144 | ) # we always sort all indices first 145 | 146 | traj_prior = options.get("traj_prior", None) 147 | if traj_prior == "spiral": 148 | assert parser.bounds is not None 149 | anchor_c2ws = generate_spiral_path( 150 | c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]), 151 | parser.bounds[sampled_indices], 152 | n_frames=num_anchors + 1, 153 | n_rots=2, 154 | zrate=0.5, 155 | endpoint=False, 156 | )[1:] @ np.diagflat([1, -1, -1, 1]) 157 | elif traj_prior == "interpolated": 158 | assert num_inputs > 1 159 | anchor_c2ws = generate_interpolated_path( 160 | c2ws[split_dict["train_ids"], :3], 161 | round((num_anchors + 1) / (num_inputs - 1)), 162 | endpoint=False, 163 | )[1 : num_anchors + 1] 164 | elif traj_prior == "orbit": 165 | c2ws_th = torch.as_tensor(c2ws) 166 | lookat = get_lookat( 167 | c2ws_th[sampled_indices, :3, 3], 168 | c2ws_th[sampled_indices, :3, 2], 169 | ) 170 | anchor_c2ws = torch.linalg.inv( 171 | get_arc_horizontal_w2cs( 172 | torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]), 173 | lookat, 174 | -F.normalize( 175 | c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0), 176 | dim=-1, 177 | ), 178 | num_frames=num_anchors + 1, 179 | endpoint=False, 180 | ) 181 | ).numpy()[1:, :3] 182 | else: 183 | anchor_c2ws = None 184 | # anchor_Ks is default to be the first from target_Ks 185 | 186 | all_imgs_path = [all_imgs_path[i] for i in sampled_indices] 187 | c2ws = c2ws[sampled_indices] 188 | Ks = Ks[sampled_indices] 189 | 190 | # absolute to relative indices 191 | input_indices = compute_relative_inds( 192 | sampled_indices, 193 | np.array(split_dict["train_ids"]), 194 | ) 195 | anchor_indices = np.arange( 196 | sampled_indices.shape[0], 197 | sampled_indices.shape[0] + num_anchors, 198 | ).tolist() # the order has no meaning here 199 | 200 | elif task == "img2vid": 201 | num_targets = len(all_imgs_path) - num_inputs 202 | num_anchors = infer_prior_stats( 203 | T, 204 | num_inputs, 205 | num_total_frames=num_targets, 206 | version_dict=version_dict, 207 | ) 208 | 209 | input_indices = split_dict["train_ids"] 210 | anchor_indices = infer_prior_inds( 211 | c2ws, 212 | num_prior_frames=num_anchors, 213 | input_frame_indices=input_indices, 214 | options=options, 215 | ).tolist() 216 | num_anchors = len(anchor_indices) 217 | anchor_c2ws = c2ws[anchor_indices, :3] 218 | anchor_Ks = Ks[anchor_indices] 219 | 220 | elif task == "img2trajvid": 221 | num_anchors = infer_prior_stats( 222 | T, 223 | num_inputs, 224 | num_total_frames=num_targets, 225 | version_dict=version_dict, 226 | ) 227 | 228 | target_c2ws = c2ws[split_dict["test_ids"], :3] 229 | target_Ks = Ks[split_dict["test_ids"]] 230 | anchor_c2ws = target_c2ws[ 231 | np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64) 232 | ] 233 | anchor_Ks = target_Ks[ 234 | np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64) 235 | ] 236 | 237 | sampled_indices = split_dict["train_ids"] + split_dict["test_ids"] 238 | all_imgs_path = [all_imgs_path[i] for i in sampled_indices] 239 | c2ws = c2ws[sampled_indices] 240 | Ks = Ks[sampled_indices] 241 | 242 | input_indices = np.arange(num_inputs).tolist() 243 | anchor_indices = np.linspace( 244 | num_inputs, num_inputs + num_targets - 1, num_anchors 245 | ).tolist() 246 | 247 | else: 248 | raise ValueError(f"Unknown task: 
{task}") 249 | 250 | return ( 251 | all_imgs_path, 252 | num_inputs, 253 | num_targets, 254 | input_indices, 255 | anchor_indices, 256 | torch.tensor(c2ws[:, :3]).float(), 257 | torch.tensor(Ks).float(), 258 | (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None), 259 | (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None), 260 | ) 261 | 262 | 263 | def main( 264 | data_path, 265 | data_items=None, 266 | version=1.1, 267 | task="img2img", 268 | save_subdir="", 269 | H=None, 270 | W=None, 271 | T=None, 272 | use_traj_prior=False, 273 | pretrained_model_name_or_path="stabilityai/stable-virtual-camera", 274 | weight_name="model.safetensors", 275 | seed=23, 276 | **overwrite_options, 277 | ): 278 | MODEL = SGMWrapper( 279 | load_model( 280 | model_version=version, 281 | pretrained_model_name_or_path=pretrained_model_name_or_path, 282 | weight_name=weight_name, 283 | device="cpu", 284 | verbose=True, 285 | ).eval() 286 | ).to(device) 287 | 288 | if COMPILE: 289 | MODEL = torch.compile(MODEL, dynamic=False) 290 | 291 | VERSION_DICT = { 292 | "H": H or 576, 293 | "W": W or 576, 294 | "T": ([int(t) for t in T.split(",")] if isinstance(T, str) else T) or 21, 295 | "C": 4, 296 | "f": 8, 297 | "options": { 298 | "chunk_strategy": "nearest-gt", 299 | "video_save_fps": 30.0, 300 | "beta_linear_start": 5e-6, 301 | "log_snr_shift": 2.4, 302 | "guider_types": 1, 303 | "cfg": 2.0, 304 | "camera_scale": 2.0, 305 | "num_steps": 50, 306 | "cfg_min": 1.2, 307 | "encoding_t": 1, 308 | "decoding_t": 1, 309 | }, 310 | } 311 | 312 | options = VERSION_DICT["options"] 313 | options.update(overwrite_options) 314 | 315 | if data_items is not None: 316 | if not isinstance(data_items, (list, tuple)): 317 | data_items = data_items.split(",") 318 | scenes = [os.path.join(data_path, item) for item in data_items] 319 | else: 320 | scenes = [ 321 | item for item in glob.glob(osp.join(data_path, "*")) if os.path.isfile(item) 322 | ] 323 | 324 | for scene in tqdm(scenes): 325 | num_inputs = options.get("num_inputs", None) 326 | save_path_scene = os.path.join( 327 | WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0] 328 | ) 329 | if options.get("skip_saved", False) and os.path.exists( 330 | os.path.join(save_path_scene, "transforms.json") 331 | ): 332 | print(f"Skipping {scene} as it is already sampled.") 333 | continue 334 | 335 | # parse_task -> infer_prior_stats modifies VERSION_DICT["T"] in-place. 336 | ( 337 | all_imgs_path, 338 | num_inputs, 339 | num_targets, 340 | input_indices, 341 | anchor_indices, 342 | c2ws, 343 | Ks, 344 | anchor_c2ws, 345 | anchor_Ks, 346 | ) = parse_task( 347 | task, 348 | scene, 349 | num_inputs, 350 | VERSION_DICT["T"], 351 | VERSION_DICT, 352 | ) 353 | assert num_inputs is not None 354 | # Create image conditioning. 355 | image_cond = { 356 | "img": all_imgs_path, 357 | "input_indices": input_indices, 358 | "prior_indices": anchor_indices, 359 | } 360 | # Create camera conditioning. 361 | camera_cond = { 362 | "c2w": c2ws.clone(), 363 | "K": Ks.clone(), 364 | "input_indices": list(range(num_inputs + num_targets)), 365 | } 366 | # run_one_scene -> transform_img_and_K modifies VERSION_DICT["H"] and VERSION_DICT["W"] in-place. 
367 | video_path_generator = run_one_scene( 368 | task, 369 | VERSION_DICT, # H, W maybe updated in run_one_scene 370 | model=MODEL, 371 | ae=AE, 372 | conditioner=CONDITIONER, 373 | denoiser=DENOISER, 374 | image_cond=image_cond, 375 | camera_cond=camera_cond, 376 | save_path=save_path_scene, 377 | use_traj_prior=use_traj_prior, 378 | traj_prior_Ks=anchor_Ks, 379 | traj_prior_c2ws=anchor_c2ws, 380 | seed=seed, 381 | ) 382 | for _ in video_path_generator: 383 | pass 384 | 385 | # Convert from OpenCV to OpenGL camera format. 386 | c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float() 387 | img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png"))) 388 | if len(img_paths) != len(c2ws): 389 | input_img_paths = sorted( 390 | glob.glob(osp.join(save_path_scene, "input", "*.png")) 391 | ) 392 | assert len(img_paths) == num_targets 393 | assert len(input_img_paths) == num_inputs 394 | assert c2ws.shape[0] == num_inputs + num_targets 395 | target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices] 396 | img_paths = [ 397 | input_img_paths[input_indices.index(i)] 398 | if i in input_indices 399 | else img_paths[target_indices.index(i)] 400 | for i in range(c2ws.shape[0]) 401 | ] 402 | create_transforms_simple( 403 | save_path=save_path_scene, 404 | img_paths=img_paths, 405 | img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat( 406 | num_inputs + num_targets, 0 407 | ), 408 | c2ws=c2ws, 409 | Ks=Ks, 410 | ) 411 | 412 | 413 | if __name__ == "__main__": 414 | fire.Fire(main) 415 | -------------------------------------------------------------------------------- /docs/CLI_USAGE.md: -------------------------------------------------------------------------------- 1 | # :computer: CLI Demo 2 | 3 | This cli demo allows you to pass in more options and control the model in a fine-grained way, suitable for power users and academic researchers. An examplar command line looks as simple as 4 | 5 | ```bash 6 | python demo.py --data_path [additional arguments] 7 | ``` 8 | 9 | We discuss here first some key attributes: 10 | 11 | - `Procedural Two-Pass Sampling`: We recommend enabling procedural sampling by setting `--use_traj_prior True --chunk_strategy ` with `` set according to the type of the task. 12 | - `Resolution and Aspect-Ratio`: Default image preprocessing include center cropping. All input and output are square images of size $576\times 576$. To overwrite, the code support to pass in `--W --H ` directly. We recommend passing in `--L_short 576` such that the aspect-ratio of original image is kept while the shortest side will be resized to $576$. 13 | 14 | ## Task 15 | 16 | Before diving into the command lines, we introduce `Task` (specified by `--task `) to bucket different usage cases depending on the data constraints in input and output domains (e.g., if the ordering is available). 17 | 18 | | Task | Type of NVS | Format of `` | Target Views Sorted? | Input and Target Views Sorted? 
| Recommended Usage | 19 | | :------------------: | :------------: | :--------------------------------------: | :------------------: | :----------------------------: | :----------------------: | 20 | | `img2img` | set NVS | folder (parsable by `ReconfusionParser`) | :x: | :x: | evaluation, benchmarking | 21 | | `img2vid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :white_check_mark: | evaluation, benchmarking | 22 | | `img2trajvid_s-prob` | trajectory NVS | single image | :white_check_mark: | :white_check_mark: | general | 23 | | `img2trajvid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :x: | general | 24 | 25 | ### Format of `` 26 | 27 | For `img2trajvid_s-prob` task, we are generating a trajectory video following preset camera motions or effects given only one input image, the data format as simple as 28 | 29 | ```bash 30 | / 31 | ├── scene_1.png 32 | ├── scene_2.png 33 | └── scene_3.png 34 | ``` 35 | 36 | For all the other tasks, we use a folder for each scene that is parsable by `ReconfusionParser` (see `seva/data_io.py`). It contains (1) a subdirectory containing all views; (2) `transforms.json` defining the intrinsics and extrinsics (OpenGL convention) for each image; and (3) `train_test_split_*.json` file splitting the input and target views, with `*` indicating the number of the input views. 37 | 38 | We provide in this release (`assets_demo_cli.zip`) several examplar scenes for you to take reference from. Target views is available if you the data are from academic sources, but in the case where target views is unavailble, we will create dummy black images as placeholders (e.g., the `garden_flythrough` scene). The general data structure follows 39 | 40 | ```bash 41 | / 42 | ├── scene_1/ 43 | ├── train_test_split_1.json # for single-view regime 44 | ├── train_test_split_6.json # for sparse-veiw regime 45 | ├── train_test_split_32.json # for semi-dense-view regime 46 | ├── transforms.json 47 | └── images/ 48 | ├── image_0.png 49 | ├── image_1.png 50 | ├── ... 51 | └── image_1000.png 52 | ├── scene_2 53 | └── scene_3 54 | ``` 55 | 56 | You can specify which scene to run by passing in `--data_items scene_1,scene_2` to run, for example, `scene_1` and `scene_2`. 57 | 58 | ### Recommended Usage 59 | 60 | - `img2img` and `img2vid` are recommended to be used for evaluation and benchmarking. These two tasks are used for the quantitative evalution in our paper. The data is converted from academic datasets so the groundtruth target views are available for metric computation. Check the [`benchmark`](../benchmark/) folder for detailed splits we organize to benchmark different NVS models. 61 | - `img2vid` requries both the input and target views to be sorted, which is usually not guaranteed in general usage. 62 | - `img2trajvid_s-prob` is for general usage but only for single-view regime and fixed preset camera control. 63 | - `img2trajvid` is the task designed for general usage since it does not need the ordering of the input views. This is the task used in the gradio demo. 64 | 65 | Next we go over all tasks and provide for each task an examplar command line. 66 | 67 | ## `img2img` 68 | 69 | ```bash 70 | python demo.py \ 71 | --data_path \ 72 | --num_inputs
<P>
\ 73 | --video_save_fps 10 74 | ``` 75 | 76 | - `--num_inputs
<P>
` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder. 77 | - The above command works for the dataset without trajectory prior (e.g., DL3DV-140). When the trajectory prior is available given a benchmarking dataset, for example, `orbit` trajectory prior for the CO3D dataset, we use the `nearest-gt` chunking strategy by setting `--use_traj_prior True --traj_prior orbit --chunking_strategy nearest-gt`. We find this leads to more 3D consistent results. 78 | - For all the single-view conditioning test scenarios: we set `--camera_scale ` with `` sweeping 20 different camera scales `0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0`. 79 | - In single-view regime for the RealEstate10K dataset, we find increasing `cfg` is helpful: we additionally set `--cfg 6.0` (`cfg` is `2.0` by default). 80 | - For the evaluation in semi-dense-view regime (i.e., DL3DV-140 and Tanks and Temples dataset) with `32` input views, we zero-shot extend `T` to fit all input and target views in one forward. Specifically, we set `--T 90` for the DL3DV-140 dataset and `--T 80` for the Tanks and Temples dataset. 81 | - For the evaluation on ViewCrafter split (including the RealEastate10K, CO3D, and Tanks and Temples dataset), we find zero-shot extending `T` to `25` to fit all input and target views in one forward is better. Also, the V split uses the original image resolutions: we therefore set `--T 25 --L_short 576`. 82 | 83 | For example, you can run the following command on the example `dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557` with 3 input views: 84 | 85 | ```bash 86 | python demo.py \ 87 | --data_path /path/to/assets_demo_cli/ \ 88 | --data_items dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557 \ 89 | --num_inputs 3 \ 90 | --video_save_fps 10 91 | ``` 92 | 93 | ## `img2vid` 94 | 95 | ```bash 96 | python demo.py \ 97 | --data_path \ 98 | --task img2vid \ 99 | --replace_or_include_input True \ 100 | --num_inputs
<P>
\ 101 | --use_traj_prior True \ 102 | --chunk_strategy interp \ 103 | ``` 104 | 105 | - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive, forming a trajectory together in this task, so we need to append back the input views to the generated target views. 106 | - `--num_inputs
<P>
` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder. 107 | - We use `interp` chunking strategy by default. 108 | - For the evaluation on ViewCrafter split (including the RealEastate10K, CO3D, and Tanks and Temples dataset), we find zero-shot extending `T` to `25` to fit all input and target views in one forward is better. Also, the V split uses the original image resolutions: we therefore set `--T 25 --L_short 576`. 109 | 110 | ## `img2trajvid_s-prob` 111 | 112 | ```bash 113 | python demo.py \ 114 | --data_path \ 115 | --task img2trajvid_s-prob \ 116 | --replace_or_include_input True \ 117 | --traj_prior orbit \ 118 | --cfg 4.0,2.0 \ 119 | --guider 1,2 \ 120 | --num_targets 111 \ 121 | --L_short 576 \ 122 | --use_traj_prior True \ 123 | --chunk_strategy interp 124 | ``` 125 | 126 | - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive, forming a trajectory together in this task, so we need to append back the input views to the generated target views. 127 | - Default `cfg` should be adusted according to `traj_prior`. 128 | - Default chunking strategy is `interp`. 129 | - Default guider is `--guider 1,2` (instead of `1`, `1` still works but `1,2` is slightly better). 130 | - `camera_scale` (default is `2.0`) can be adjusted according to `traj_prior`. The model has scale ambiguity with single-view input, especially for panning motions. We encourage to tune up `camera_scale` to `10.0` for all panning motions (`--traj_prior pan-*/dolly*`) if you expect a larger camera motion. 131 | 132 | ## `img2trajvid` 133 | 134 | ### Sparse-view regime ($P\leq 8$) 135 | 136 | ```bash 137 | python demo.py \ 138 | --data_path \ 139 | --task img2trajvid \ 140 | --num_inputs
<P>
\ 141 | --cfg 3.0,2.0 \ 142 | --use_traj_prior True \ 143 | --chunk_strategy interp-gt 144 | ``` 145 | 146 | - `--num_inputs
<P>
` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder. 147 | - Default `cfg` should be set to `3,2` (`3` being `cfg` for the first pass, and `2` being the `cfg` for the second pass). Try to increase the `cfg` for the first pass from `3` to higher values if you observe blurry areas (usually happens for harder scenes with a fair amount of unseen regions). 148 | - Default chunking strategy should be set to `interp-gt` (instead of `interp`, `interp` can work but usually a bit worse). 149 | - The `--chunk_strategy_first_pass` is set as `gt-nearest` by default. So it can automatically adapt when $P$ is large (up to a thousand frames). 150 | 151 | ### Semi-dense-view regime ($P>9$) 152 | 153 | ```bash 154 | python demo.py \ 155 | --data_path \ 156 | --task img2trajvid \ 157 | --num_inputs
<P>
\ 158 | --cfg 3.0 \ 159 | --L_short 576 \ 160 | --use_traj_prior True \ 161 | --chunk_strategy interp 162 | ``` 163 | 164 | - `--num_inputs
<P>
` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder. 165 | - Default `cfg` should be set to `3`. 166 | - Default chunking strategy should be set to `interp` (instead of `interp-gt`, `interp-gt` is also supported but the results do not look good). 167 | - `T` can be overwritten by `--T ,21` (X being extended `T` for the first pass, and `21` being the default `T` for the second pass). `` is dynamically decided now in the code but can also be manually updated. This is useful when you observe that there exist two very dissimilar adjacent anchors which make the interpolation in the second pass impossible. There exist two ways: 168 | - `--T 96,21`: this overwrites the `T` in the first pass to be exactly `96`. 169 | - `--num_prior_frames_ratio 1.2`: this enlarges T in the first pass dynamically to be `1.2`$\times$ larger. 170 | -------------------------------------------------------------------------------- /docs/GR_USAGE.md: -------------------------------------------------------------------------------- 1 | # :rocket: Gradio Demo 2 | 3 | This gradio demo is the simplest starting point for you play with our project. 4 | 5 | You can either visit it at our huggingface space [here](https://huggingface.co/spaces/stabilityai/stable-virtual-camera) or run it locally yourself by 6 | 7 | ```bash 8 | python demo_gr.py 9 | ``` 10 | 11 | We provide two ways to use our demo: 12 | 13 | 1. `Basic` mode, where user can upload a single image, and set a target camera trajectory from our preset options. This is the most straightforward way to use our model, and is suitable for most users. 14 | 2. `Advanced` mode, where user can upload one or multiple images, and set a target camera trajectory by interacting with a 3D viewport (powered by [viser](https://viser.studio/latest)). This is suitable for power users and academic researchers. 15 | 16 | ### `Basic` 17 | 18 | This is the default mode when entering our demo (given its simplicity). 19 | 20 | User can upload a single image, and set a target camera trajectory from our preset options. This is the most straightforward way to use our model, and is suitable for most users. 21 | 22 | Here is a video walkthrough: 23 | 24 | https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705 25 | 26 | You can choose from 13 preset trajectories that are common for NVS (`move-forward/backward` are omitted for visualization purpose): 27 | 28 | https://github.com/user-attachments/assets/b2cf8700-3d85-44b9-8d52-248e82f1fb55 29 | 30 | More formally: 31 | 32 | - `orbit/spiral/lemniscate` are good for showing the "3D-ness" of the scene. 33 | - `zoom-in/out` keep the camera position the same while increasing/decreasing the focal length. 34 | - `dolly zoom-in/out` move camera position backward/forward while increasing/decreasing the focal length. 35 | - `move-forward/backward/up/down/left/right` move camera position in different directions. 36 | 37 | Notes: 38 | 39 | - For a 80 frame video at `786x576` resolution, it takes around 20 seconds for the first pass generation, and around 2 minutes for the second pass generation, tested with a single H100 GPU. 40 | - Please expect around ~2-3x more times on HF space. 41 | 42 | ### `Advanced` 43 | 44 | This is the power mode where you can have very fine-grained control over camera trajectories. 45 | 46 | User can upload one or multiple images, and set a target camera trajectory by interacting with a 3D viewport. This is suitable for power users and academic researchers. 
47 | 48 | Here is a video walkthrough: 49 | 50 | https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba 51 | 52 | Notes: 53 | 54 | - For a 134-frame video at `576x576` resolution, it takes around 16 seconds for the first-pass generation and around 4 minutes for the second-pass generation, tested on a single H100 GPU. 55 | - Please expect around 2-3x longer on the HF space. 56 | 57 | ### Pro tips 58 | 59 | - If the first-pass sampling result is bad, click the "Abort rendering" button in the GUI to avoid getting stuck in the second-pass sampling, so that you can try something else. 60 | 61 | ### Performance benchmark 62 | 63 | We have tested our Gradio demo in both a local environment and the HF space environment, across different modes and compilation settings. Here are our results: 64 | | Total time (s) | `Basic` first pass | `Basic` second pass | `Advanced` first pass | `Advanced` second pass | 65 | |:------------------------:|:-----------------:|:------------------:|:--------------------:|:---------------------:| 66 | | HF (L40S, w/o comp.) | 68 | 484 | 48 | 780 | 67 | | HF (L40S, w/ comp.) | 51 | 362 | 36 | 587 | 68 | | Local (H100, w/o comp.) | 35 | 204 | 20 | 313 | 69 | | Local (H100, w/ comp.) | 21 | 144 | 16 | 234 | 70 | 71 | Notes: 72 | 73 | - The HF space uses an L40S GPU, and our local environment uses an H100 GPU. 74 | - "comp." refers to opting in to compilation via `torch.compile`. 75 | - `Basic` mode is tested by generating 80 frames at `768x576` resolution. 76 | - `Advanced` mode is tested by generating 134 frames at `576x576` resolution. 77 | -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # :wrench: Installation 2 | 3 | ### Model Dependencies 4 | 5 | ```bash 6 | # Install seva model dependencies. 7 | pip install -e . 8 | ``` 9 | 10 | ### Demo Dependencies 11 | 12 | To use the cli demo (`demo.py`) or the gradio demo (`demo_gr.py`), do the following: 13 | 14 | ```bash 15 | # Initialize and update submodules for demo. 16 | git submodule update --init --recursive 17 | 18 | # Install pycolmap dependencies for cli and gradio demo (our model is not dependent on it). 19 | echo "Installing pycolmap (for both cli and gradio demo)..." 20 | pip install git+https://github.com/jensenz-sai/pycolmap@543266bc316df2fe407b3a33d454b310b1641042 21 | 22 | # Install dust3r dependencies for gradio demo (our model is not dependent on it). 23 | echo "Installing dust3r dependencies (only for gradio demo)..." 24 | pushd third_party/dust3r 25 | pip install -r requirements.txt 26 | popd 27 | ``` 28 | 29 | ### Dev and Speeding Up (Optional) 30 | 31 | ```bash 32 | # [OPTIONAL] Install seva dependencies for development. 33 | pip install -e ".[dev]" 34 | pre-commit install 35 | 36 | # [OPTIONAL] Install the torch nightly version for faster JIT via torch.compile (speeds up sampling by 2x in our testing). 37 | # Please adjust to your own CUDA version. For example, if you have CUDA 11.8, use the following command.
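# If you are unsure of your CUDA version, nvidia-smi reports the driver's CUDA version and
# nvcc --version reports the installed toolkit; replace cu118 in the index URL below accordingly.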
38 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 39 | ``` 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=65.5.3"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "seva" 7 | version = "0.0.0" 8 | requires-python = ">=3.10" 9 | dependencies = [ 10 | "torch", 11 | "roma", 12 | "viser", 13 | "tyro", 14 | "fire", 15 | "ninja", 16 | "gradio==5.17.0", 17 | "einops", 18 | "colorama", 19 | "splines", 20 | "kornia", 21 | "open-clip-torch", 22 | "diffusers", 23 | "numpy==1.24.4", 24 | "imageio[ffmpeg]", 25 | "huggingface-hub", 26 | "opencv-python", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | dev = ["ruff", "ipdb", "pytest", "line_profiler", "pre-commit"] 31 | 32 | [tool.setuptools.packages.find] 33 | include = ["seva"] 34 | 35 | [tool.pyright] 36 | extraPaths = ["third_party/dust3r"] 37 | 38 | [tool.ruff] 39 | lint.ignore = ["E741"] 40 | -------------------------------------------------------------------------------- /seva/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/seva/__init__.py -------------------------------------------------------------------------------- /seva/data_io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | from glob import glob 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import cv2 8 | import imageio.v3 as iio 9 | import numpy as np 10 | import torch 11 | 12 | from seva.geometry import ( 13 | align_principle_axes, 14 | similarity_from_cameras, 15 | transform_cameras, 16 | transform_points, 17 | ) 18 | 19 | 20 | def _get_rel_paths(path_dir: str) -> List[str]: 21 | """Recursively get relative paths of files in a directory.""" 22 | paths = [] 23 | for dp, _, fn in os.walk(path_dir): 24 | for f in fn: 25 | paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) 26 | return paths 27 | 28 | 29 | class BaseParser(object): 30 | def __init__( 31 | self, 32 | data_dir: str, 33 | factor: int = 1, 34 | normalize: bool = False, 35 | test_every: Optional[int] = 8, 36 | ): 37 | self.data_dir = data_dir 38 | self.factor = factor 39 | self.normalize = normalize 40 | self.test_every = test_every 41 | 42 | self.image_names: List[str] = [] # (num_images,) 43 | self.image_paths: List[str] = [] # (num_images,) 44 | self.camtoworlds: np.ndarray = np.zeros((0, 4, 4)) # (num_images, 4, 4) 45 | self.camera_ids: List[int] = [] # (num_images,) 46 | self.Ks_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> K 47 | self.params_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> params 48 | self.imsize_dict: Dict[ 49 | int, Tuple[int, int] 50 | ] = {} # Dict of camera_id -> (width, height) 51 | self.points: np.ndarray = np.zeros((0, 3)) # (num_points, 3) 52 | self.points_err: np.ndarray = np.zeros((0,)) # (num_points,) 53 | self.points_rgb: np.ndarray = np.zeros((0, 3)) # (num_points, 3) 54 | self.point_indices: Dict[str, np.ndarray] = {} # Dict of image_name -> (M,) 55 | self.transform: np.ndarray = np.zeros((4, 4)) # (4, 4) 56 | 57 | self.mapx_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W) 58 | self.mapy_dict: Dict[int, np.ndarray] = {} # Dict of camera_id 
-> (H, W) 59 | self.roi_undist_dict: Dict[int, Tuple[int, int, int, int]] = ( 60 | dict() 61 | ) # Dict of camera_id -> (x, y, w, h) 62 | self.scene_scale: float = 1.0 63 | 64 | 65 | class DirectParser(BaseParser): 66 | def __init__( 67 | self, 68 | imgs: List[np.ndarray], 69 | c2ws: np.ndarray, 70 | Ks: np.ndarray, 71 | points: Optional[np.ndarray] = None, 72 | points_rgb: Optional[np.ndarray] = None, # uint8 73 | mono_disps: Optional[List[np.ndarray]] = None, 74 | normalize: bool = False, 75 | test_every: Optional[int] = None, 76 | ): 77 | super().__init__("", 1, normalize, test_every) 78 | 79 | self.image_names = [f"{i:06d}" for i in range(len(imgs))] 80 | self.image_paths = ["null" for _ in range(len(imgs))] 81 | self.camtoworlds = c2ws 82 | self.camera_ids = [i for i in range(len(imgs))] 83 | self.Ks_dict = {i: K for i, K in enumerate(Ks)} 84 | self.imsize_dict = { 85 | i: (img.shape[1], img.shape[0]) for i, img in enumerate(imgs) 86 | } 87 | if points is not None: 88 | self.points = points 89 | assert points_rgb is not None 90 | self.points_rgb = points_rgb 91 | self.points_err = np.zeros((len(points),)) 92 | 93 | self.imgs = imgs 94 | self.mono_disps = mono_disps 95 | 96 | # Normalize the world space. 97 | if normalize: 98 | T1 = similarity_from_cameras(self.camtoworlds) 99 | self.camtoworlds = transform_cameras(T1, self.camtoworlds) 100 | 101 | if points is not None: 102 | self.points = transform_points(T1, self.points) 103 | T2 = align_principle_axes(self.points) 104 | self.camtoworlds = transform_cameras(T2, self.camtoworlds) 105 | self.points = transform_points(T2, self.points) 106 | else: 107 | T2 = np.eye(4) 108 | 109 | self.transform = T2 @ T1 110 | else: 111 | self.transform = np.eye(4) 112 | 113 | # size of the scene measured by cameras 114 | camera_locations = self.camtoworlds[:, :3, 3] 115 | scene_center = np.mean(camera_locations, axis=0) 116 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 117 | self.scene_scale = np.max(dists) 118 | 119 | 120 | class COLMAPParser(BaseParser): 121 | """COLMAP parser.""" 122 | 123 | def __init__( 124 | self, 125 | data_dir: str, 126 | factor: int = 1, 127 | normalize: bool = False, 128 | test_every: Optional[int] = 8, 129 | image_folder: str = "images", 130 | colmap_folder: str = "sparse/0", 131 | ): 132 | super().__init__(data_dir, factor, normalize, test_every) 133 | 134 | colmap_dir = os.path.join(data_dir, colmap_folder) 135 | assert os.path.exists( 136 | colmap_dir 137 | ), f"COLMAP directory {colmap_dir} does not exist." 138 | 139 | try: 140 | from pycolmap import SceneManager 141 | except ImportError: 142 | raise ImportError( 143 | "Please install pycolmap to use the data parsers: " 144 | " `pip install git+https://github.com/jensenz-sai/pycolmap.git@543266bc316df2fe407b3a33d454b310b1641042`" 145 | ) 146 | 147 | manager = SceneManager(colmap_dir) 148 | manager.load_cameras() 149 | manager.load_images() 150 | manager.load_points3D() 151 | 152 | # Extract extrinsic matrices in world-to-camera format. 
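        # COLMAP stores, per image, a rotation R and translation t that map world
        # coordinates into the camera frame; stacking [R | t] on a [0, 0, 0, 1] row
        # gives the 4x4 world-to-camera matrix, which is inverted further below to
        # obtain camera-to-world poses.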
153 | imdata = manager.images 154 | w2c_mats = [] 155 | camera_ids = [] 156 | Ks_dict = dict() 157 | params_dict = dict() 158 | imsize_dict = dict() # width, height 159 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 160 | for k in imdata: 161 | im = imdata[k] 162 | rot = im.R() 163 | trans = im.tvec.reshape(3, 1) 164 | w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) 165 | w2c_mats.append(w2c) 166 | 167 | # support different camera intrinsics 168 | camera_id = im.camera_id 169 | camera_ids.append(camera_id) 170 | 171 | # camera intrinsics 172 | cam = manager.cameras[camera_id] 173 | fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy 174 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 175 | K[:2, :] /= factor 176 | Ks_dict[camera_id] = K 177 | 178 | # Get distortion parameters. 179 | type_ = cam.camera_type 180 | if type_ == 0 or type_ == "SIMPLE_PINHOLE": 181 | params = np.empty(0, dtype=np.float32) 182 | camtype = "perspective" 183 | elif type_ == 1 or type_ == "PINHOLE": 184 | params = np.empty(0, dtype=np.float32) 185 | camtype = "perspective" 186 | if type_ == 2 or type_ == "SIMPLE_RADIAL": 187 | params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) 188 | camtype = "perspective" 189 | elif type_ == 3 or type_ == "RADIAL": 190 | params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) 191 | camtype = "perspective" 192 | elif type_ == 4 or type_ == "OPENCV": 193 | params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) 194 | camtype = "perspective" 195 | elif type_ == 5 or type_ == "OPENCV_FISHEYE": 196 | params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) 197 | camtype = "fisheye" 198 | assert ( 199 | camtype == "perspective" # type: ignore 200 | ), f"Only support perspective camera model, got {type_}" 201 | 202 | params_dict[camera_id] = params # type: ignore 203 | 204 | # image size 205 | imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) 206 | 207 | print( 208 | f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." 209 | ) 210 | 211 | if len(imdata) == 0: 212 | raise ValueError("No images found in COLMAP.") 213 | if not (type_ == 0 or type_ == 1): # type: ignore 214 | print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") 215 | 216 | w2c_mats = np.stack(w2c_mats, axis=0) 217 | 218 | # Convert extrinsics to camera-to-world. 219 | camtoworlds = np.linalg.inv(w2c_mats) 220 | 221 | # Image names from COLMAP. No need for permuting the poses according to 222 | # image names anymore. 223 | image_names = [imdata[k].name for k in imdata] 224 | 225 | # Previous Nerf results were generated with images sorted by filename, 226 | # ensure metrics are reported on the same test set. 227 | inds = np.argsort(image_names) 228 | image_names = [image_names[i] for i in inds] 229 | camtoworlds = camtoworlds[inds] 230 | camera_ids = [camera_ids[i] for i in inds] 231 | 232 | # Load images. 233 | if factor > 1: 234 | image_dir_suffix = f"_{factor}" 235 | else: 236 | image_dir_suffix = "" 237 | colmap_image_dir = os.path.join(data_dir, image_folder) 238 | image_dir = os.path.join(data_dir, image_folder + image_dir_suffix) 239 | for d in [image_dir, colmap_image_dir]: 240 | if not os.path.exists(d): 241 | raise ValueError(f"Image folder {d} does not exist.") 242 | 243 | # Downsampled images may have different names vs images used for COLMAP, 244 | # so we need to map between the two sorted lists of files. 
245 | colmap_files = sorted(_get_rel_paths(colmap_image_dir)) 246 | image_files = sorted(_get_rel_paths(image_dir)) 247 | colmap_to_image = dict(zip(colmap_files, image_files)) 248 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 249 | 250 | # 3D points and {image_name -> [point_idx]} 251 | points = manager.points3D.astype(np.float32) # type: ignore 252 | points_err = manager.point3D_errors.astype(np.float32) # type: ignore 253 | points_rgb = manager.point3D_colors.astype(np.uint8) # type: ignore 254 | point_indices = dict() 255 | 256 | image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} 257 | for point_id, data in manager.point3D_id_to_images.items(): 258 | for image_id, _ in data: 259 | image_name = image_id_to_name[image_id] 260 | point_idx = manager.point3D_id_to_point3D_idx[point_id] 261 | point_indices.setdefault(image_name, []).append(point_idx) 262 | point_indices = { 263 | k: np.array(v).astype(np.int32) for k, v in point_indices.items() 264 | } 265 | 266 | # Normalize the world space. 267 | if normalize: 268 | T1 = similarity_from_cameras(camtoworlds) 269 | camtoworlds = transform_cameras(T1, camtoworlds) 270 | points = transform_points(T1, points) 271 | 272 | T2 = align_principle_axes(points) 273 | camtoworlds = transform_cameras(T2, camtoworlds) 274 | points = transform_points(T2, points) 275 | 276 | transform = T2 @ T1 277 | else: 278 | transform = np.eye(4) 279 | 280 | self.image_names = image_names # List[str], (num_images,) 281 | self.image_paths = image_paths # List[str], (num_images,) 282 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) 283 | self.camera_ids = camera_ids # List[int], (num_images,) 284 | self.Ks_dict = Ks_dict # Dict of camera_id -> K 285 | self.params_dict = params_dict # Dict of camera_id -> params 286 | self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) 287 | self.points = points # np.ndarray, (num_points, 3) 288 | self.points_err = points_err # np.ndarray, (num_points,) 289 | self.points_rgb = points_rgb # np.ndarray, (num_points, 3) 290 | self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] 291 | self.transform = transform # np.ndarray, (4, 4) 292 | 293 | # undistortion 294 | self.mapx_dict = dict() 295 | self.mapy_dict = dict() 296 | self.roi_undist_dict = dict() 297 | for camera_id in self.params_dict.keys(): 298 | params = self.params_dict[camera_id] 299 | if len(params) == 0: 300 | continue # no distortion 301 | assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" 302 | assert ( 303 | camera_id in self.params_dict 304 | ), f"Missing params for camera {camera_id}" 305 | K = self.Ks_dict[camera_id] 306 | width, height = self.imsize_dict[camera_id] 307 | K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( 308 | K, params, (width, height), 0 309 | ) 310 | mapx, mapy = cv2.initUndistortRectifyMap( 311 | K, 312 | params, 313 | None, 314 | K_undist, 315 | (width, height), 316 | cv2.CV_32FC1, # type: ignore 317 | ) 318 | self.Ks_dict[camera_id] = K_undist 319 | self.mapx_dict[camera_id] = mapx 320 | self.mapy_dict[camera_id] = mapy 321 | self.roi_undist_dict[camera_id] = roi_undist # type: ignore 322 | 323 | # size of the scene measured by cameras 324 | camera_locations = camtoworlds[:, :3, 3] 325 | scene_center = np.mean(camera_locations, axis=0) 326 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 327 | self.scene_scale = np.max(dists) 328 | 329 | 330 | class ReconfusionParser(BaseParser): 331 | def 
__init__(self, data_dir: str, normalize: bool = False): 332 | super().__init__(data_dir, 1, normalize, test_every=None) 333 | 334 | def get_num(p): 335 | return p.split("_")[-1].removesuffix(".json") 336 | 337 | splits_per_num_input_frames = {} 338 | num_input_frames = [ 339 | int(get_num(p)) if get_num(p).isdigit() else get_num(p) 340 | for p in sorted(glob(osp.join(data_dir, "train_test_split_*.json"))) 341 | ] 342 | for num_input_frames in num_input_frames: 343 | with open( 344 | osp.join( 345 | data_dir, 346 | f"train_test_split_{num_input_frames}.json", 347 | ) 348 | ) as f: 349 | splits_per_num_input_frames[num_input_frames] = json.load(f) 350 | self.splits_per_num_input_frames = splits_per_num_input_frames 351 | 352 | with open(osp.join(data_dir, "transforms.json")) as f: 353 | metadata = json.load(f) 354 | 355 | image_names, image_paths, camtoworlds = [], [], [] 356 | for frame in metadata["frames"]: 357 | if frame["file_path"] is None: 358 | image_path = image_name = None 359 | else: 360 | image_path = osp.join(data_dir, frame["file_path"]) 361 | image_name = osp.basename(image_path) 362 | image_paths.append(image_path) 363 | image_names.append(image_name) 364 | camtoworld = np.array(frame["transform_matrix"]) 365 | if "applied_transform" in metadata: 366 | applied_transform = np.concatenate( 367 | [metadata["applied_transform"], [[0, 0, 0, 1]]], axis=0 368 | ) 369 | camtoworld = np.linalg.inv(applied_transform) @ camtoworld 370 | camtoworlds.append(camtoworld) 371 | camtoworlds = np.array(camtoworlds) 372 | camtoworlds[:, :, [1, 2]] *= -1 373 | 374 | # Normalize the world space. 375 | if normalize: 376 | T1 = similarity_from_cameras(camtoworlds) 377 | camtoworlds = transform_cameras(T1, camtoworlds) 378 | self.transform = T1 379 | else: 380 | self.transform = np.eye(4) 381 | 382 | self.image_names = image_names 383 | self.image_paths = image_paths 384 | self.camtoworlds = camtoworlds 385 | self.camera_ids = list(range(len(image_paths))) 386 | self.Ks_dict = { 387 | i: np.array( 388 | [ 389 | [ 390 | metadata.get("fl_x", frame.get("fl_x", None)), 391 | 0.0, 392 | metadata.get("cx", frame.get("cx", None)), 393 | ], 394 | [ 395 | 0.0, 396 | metadata.get("fl_y", frame.get("fl_y", None)), 397 | metadata.get("cy", frame.get("cy", None)), 398 | ], 399 | [0.0, 0.0, 1.0], 400 | ] 401 | ) 402 | for i, frame in enumerate(metadata["frames"]) 403 | } 404 | self.imsize_dict = { 405 | i: ( 406 | metadata.get("w", frame.get("w", None)), 407 | metadata.get("h", frame.get("h", None)), 408 | ) 409 | for i, frame in enumerate(metadata["frames"]) 410 | } 411 | # When num_input_frames is None, use all frames for both training and 412 | # testing. 
413 | # self.splits_per_num_input_frames[None] = { 414 | # "train_ids": list(range(len(image_paths))), 415 | # "test_ids": list(range(len(image_paths))), 416 | # } 417 | 418 | # size of the scene measured by cameras 419 | camera_locations = camtoworlds[:, :3, 3] 420 | scene_center = np.mean(camera_locations, axis=0) 421 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 422 | self.scene_scale = np.max(dists) 423 | 424 | self.bounds = None 425 | if osp.exists(osp.join(data_dir, "bounds.npy")): 426 | self.bounds = np.load(osp.join(data_dir, "bounds.npy")) 427 | scaling = np.linalg.norm(self.transform[0, :3]) 428 | self.bounds = self.bounds / scaling 429 | 430 | 431 | class Dataset(torch.utils.data.Dataset): 432 | """A simple dataset class.""" 433 | 434 | def __init__( 435 | self, 436 | parser: BaseParser, 437 | split: str = "train", 438 | num_input_frames: Optional[int] = None, 439 | patch_size: Optional[int] = None, 440 | load_depths: bool = False, 441 | load_mono_disps: bool = False, 442 | ): 443 | self.parser = parser 444 | self.split = split 445 | self.num_input_frames = num_input_frames 446 | self.patch_size = patch_size 447 | self.load_depths = load_depths 448 | self.load_mono_disps = load_mono_disps 449 | if load_mono_disps: 450 | assert isinstance(parser, DirectParser) 451 | assert parser.mono_disps is not None 452 | if isinstance(parser, ReconfusionParser): 453 | ids_per_split = parser.splits_per_num_input_frames[num_input_frames] 454 | self.indices = ids_per_split[ 455 | "train_ids" if split == "train" else "test_ids" 456 | ] 457 | else: 458 | indices = np.arange(len(self.parser.image_names)) 459 | if split == "train": 460 | self.indices = ( 461 | indices[indices % self.parser.test_every != 0] 462 | if self.parser.test_every is not None 463 | else indices 464 | ) 465 | else: 466 | self.indices = ( 467 | indices[indices % self.parser.test_every == 0] 468 | if self.parser.test_every is not None 469 | else indices 470 | ) 471 | 472 | def __len__(self): 473 | return len(self.indices) 474 | 475 | def __getitem__(self, item: int) -> Dict[str, Any]: 476 | index = self.indices[item] 477 | if isinstance(self.parser, DirectParser): 478 | image = self.parser.imgs[index] 479 | else: 480 | image = iio.imread(self.parser.image_paths[index])[..., :3] 481 | camera_id = self.parser.camera_ids[index] 482 | K = self.parser.Ks_dict[camera_id].copy() # undistorted K 483 | params = self.parser.params_dict.get(camera_id, None) 484 | camtoworlds = self.parser.camtoworlds[index] 485 | 486 | x, y, w, h = 0, 0, image.shape[1], image.shape[0] 487 | if params is not None and len(params) > 0: 488 | # Images are distorted. Undistort them. 489 | mapx, mapy = ( 490 | self.parser.mapx_dict[camera_id], 491 | self.parser.mapy_dict[camera_id], 492 | ) 493 | image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) 494 | x, y, w, h = self.parser.roi_undist_dict[camera_id] 495 | image = image[y : y + h, x : x + w] 496 | 497 | if self.patch_size is not None: 498 | # Random crop. 
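            # The crop offsets are also subtracted from the principal point below
            # (K[0, 2] -= x, K[1, 2] -= y) so the intrinsics stay consistent with
            # the cropped image.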
499 | h, w = image.shape[:2] 500 | x = np.random.randint(0, max(w - self.patch_size, 1)) 501 | y = np.random.randint(0, max(h - self.patch_size, 1)) 502 | image = image[y : y + self.patch_size, x : x + self.patch_size] 503 | K[0, 2] -= x 504 | K[1, 2] -= y 505 | 506 | data = { 507 | "K": torch.from_numpy(K).float(), 508 | "camtoworld": torch.from_numpy(camtoworlds).float(), 509 | "image": torch.from_numpy(image).float(), 510 | "image_id": item, # the index of the image in the dataset 511 | } 512 | 513 | if self.load_depths: 514 | # projected points to image plane to get depths 515 | worldtocams = np.linalg.inv(camtoworlds) 516 | image_name = self.parser.image_names[index] 517 | point_indices = self.parser.point_indices[image_name] 518 | points_world = self.parser.points[point_indices] 519 | points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T 520 | points_proj = (K @ points_cam.T).T 521 | points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2) 522 | depths = points_cam[:, 2] # (M,) 523 | if self.patch_size is not None: 524 | points[:, 0] -= x 525 | points[:, 1] -= y 526 | # filter out points outside the image 527 | selector = ( 528 | (points[:, 0] >= 0) 529 | & (points[:, 0] < image.shape[1]) 530 | & (points[:, 1] >= 0) 531 | & (points[:, 1] < image.shape[0]) 532 | & (depths > 0) 533 | ) 534 | points = points[selector] 535 | depths = depths[selector] 536 | data["points"] = torch.from_numpy(points).float() 537 | data["depths"] = torch.from_numpy(depths).float() 538 | if self.load_mono_disps: 539 | data["mono_disps"] = torch.from_numpy(self.parser.mono_disps[index]).float() # type: ignore 540 | 541 | return data 542 | 543 | 544 | def get_parser(parser_type: str, **kwargs) -> BaseParser: 545 | if parser_type == "colmap": 546 | parser = COLMAPParser(**kwargs) 547 | elif parser_type == "direct": 548 | parser = DirectParser(**kwargs) 549 | elif parser_type == "reconfusion": 550 | parser = ReconfusionParser(**kwargs) 551 | else: 552 | raise ValueError(f"Unknown parser type: {parser_type}") 553 | return parser 554 | -------------------------------------------------------------------------------- /seva/geometry.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import numpy as np 4 | import roma 5 | import scipy.interpolate 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default 10 | 11 | 12 | def get_camera_dist( 13 | source_c2ws: torch.Tensor, # N x 3 x 4 14 | target_c2ws: torch.Tensor, # M x 3 x 4 15 | mode: str = "translation", 16 | ): 17 | if mode == "rotation": 18 | dists = torch.acos( 19 | ( 20 | ( 21 | torch.matmul( 22 | source_c2ws[:, None, :3, :3], 23 | target_c2ws[None, :, :3, :3].transpose(-1, -2), 24 | ) 25 | .diagonal(offset=0, dim1=-2, dim2=-1) 26 | .sum(-1) 27 | - 1 28 | ) 29 | / 2 30 | ).clamp(-1, 1) 31 | ) * (180 / torch.pi) 32 | elif mode == "translation": 33 | dists = torch.norm( 34 | source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1 35 | ) 36 | else: 37 | raise NotImplementedError( 38 | f"Mode {mode} is not implemented for finding nearest source indices." 
39 | ) 40 | return dists 41 | 42 | 43 | def to_hom(X): 44 | # get homogeneous coordinates of the input 45 | X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1) 46 | return X_hom 47 | 48 | 49 | def to_hom_pose(pose): 50 | # get homogeneous coordinates of the input pose 51 | if pose.shape[-2:] == (3, 4): 52 | pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1) 53 | pose_hom[:, :3, :] = pose 54 | return pose_hom 55 | return pose 56 | 57 | 58 | def get_default_intrinsics( 59 | fov_rad=DEFAULT_FOV_RAD, 60 | aspect_ratio=1.0, 61 | ): 62 | if not isinstance(fov_rad, torch.Tensor): 63 | fov_rad = torch.tensor( 64 | [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad 65 | ) 66 | if aspect_ratio >= 1.0: # W >= H 67 | focal_x = 0.5 / torch.tan(0.5 * fov_rad) 68 | focal_y = focal_x * aspect_ratio 69 | else: # W < H 70 | focal_y = 0.5 / torch.tan(0.5 * fov_rad) 71 | focal_x = focal_y / aspect_ratio 72 | intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3)) 73 | intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack( 74 | [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1 75 | ) 76 | intrinsics[:, :, -1] = torch.tensor( 77 | [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype 78 | ) 79 | return intrinsics 80 | 81 | 82 | def get_image_grid(img_h, img_w): 83 | # add 0.5 is VERY important especially when your img_h and img_w 84 | # is not very large (e.g., 72)!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 85 | y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5) 86 | x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5) 87 | Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W] 88 | xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2] 89 | return to_hom(xy_grid) # [HW,3] 90 | 91 | 92 | def img2cam(X, cam_intr): 93 | return X @ cam_intr.inverse().transpose(-1, -2) 94 | 95 | 96 | def cam2world(X, pose): 97 | X_hom = to_hom(X) 98 | pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4] 99 | return X_hom @ pose_inv.transpose(-1, -2) 100 | 101 | 102 | def get_center_and_ray(img_h, img_w, pose, intr): # [HW,2] 103 | # given the intrinsic/extrinsic matrices, get the camera center and ray directions] 104 | # assert(opt.camera.model=="perspective") 105 | 106 | # compute center and ray 107 | grid_img = get_image_grid(img_h, img_w) # [HW,3] 108 | grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3] 109 | center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3] 110 | 111 | # transform from camera to world coordinates 112 | grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3] 113 | center_3D = cam2world(center_3D_cam, pose) # [B,HW,3] 114 | ray = grid_3D - center_3D # [B,HW,3] 115 | 116 | return center_3D, ray, grid_3D_cam 117 | 118 | 119 | def get_plucker_coordinates( 120 | extrinsics_src, 121 | extrinsics, 122 | intrinsics=None, 123 | fov_rad=DEFAULT_FOV_RAD, 124 | target_size=[72, 72], 125 | ): 126 | if intrinsics is None: 127 | intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device) 128 | else: 129 | if not ( 130 | torch.all(intrinsics[:, :2, -1] >= 0) 131 | and torch.all(intrinsics[:, :2, -1] <= 1) 132 | ): 133 | intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8 134 | # you should ensure the intrisics are expressed in 135 | # resolution-independent normalized image coordinates just performing a 136 | # very simple verification here checking if principal points are 137 | # between 0 and 1 138 | assert ( 139 | torch.all(intrinsics[:, :2, -1] >= 0) 140 | and 
torch.all(intrinsics[:, :2, -1] <= 1) 141 | ), "Intrinsics should be expressed in resolution-independent normalized image coordinates." 142 | 143 | c2w_src = torch.linalg.inv(extrinsics_src) 144 | # transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera 145 | extrinsics_rel = torch.einsum( 146 | "vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1) 147 | ) 148 | 149 | intrinsics[:, :2] *= extrinsics.new_tensor( 150 | [ 151 | target_size[1], # w 152 | target_size[0], # h 153 | ] 154 | ).view(1, -1, 1) 155 | centers, rays, grid_cam = get_center_and_ray( 156 | img_h=target_size[0], 157 | img_w=target_size[1], 158 | pose=extrinsics_rel[:, :3, :], 159 | intr=intrinsics, 160 | ) 161 | 162 | rays = torch.nn.functional.normalize(rays, dim=-1) 163 | plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1) 164 | plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size) 165 | return plucker 166 | 167 | 168 | def rt_to_mat4( 169 | R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None 170 | ) -> torch.Tensor: 171 | """ 172 | Args: 173 | R (torch.Tensor): (..., 3, 3). 174 | t (torch.Tensor): (..., 3). 175 | s (torch.Tensor): (...,). 176 | 177 | Returns: 178 | torch.Tensor: (..., 4, 4) 179 | """ 180 | mat34 = torch.cat([R, t[..., None]], dim=-1) 181 | if s is None: 182 | bottom = ( 183 | mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]]) 184 | .reshape((1,) * (mat34.dim() - 2) + (1, 4)) 185 | .expand(mat34.shape[:-2] + (1, 4)) 186 | ) 187 | else: 188 | bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0) 189 | mat4 = torch.cat([mat34, bottom], dim=-2) 190 | return mat4 191 | 192 | 193 | def get_preset_pose_fov( 194 | option: Literal[ 195 | "orbit", 196 | "spiral", 197 | "lemniscate", 198 | "zoom-in", 199 | "zoom-out", 200 | "dolly zoom-in", 201 | "dolly zoom-out", 202 | "move-forward", 203 | "move-backward", 204 | "move-up", 205 | "move-down", 206 | "move-left", 207 | "move-right", 208 | "roll", 209 | ], 210 | num_frames: int, 211 | start_w2c: torch.Tensor, 212 | look_at: torch.Tensor, 213 | up_direction: torch.Tensor | None = None, 214 | fov: float = DEFAULT_FOV_RAD, 215 | spiral_radii: list[float] = [0.5, 0.5, 0.2], 216 | zoom_factor: float | None = None, 217 | ): 218 | poses = fovs = None 219 | if option == "orbit": 220 | poses = torch.linalg.inv( 221 | get_arc_horizontal_w2cs( 222 | start_w2c, 223 | look_at, 224 | up_direction, 225 | num_frames=num_frames, 226 | endpoint=False, 227 | ) 228 | ).numpy() 229 | fovs = np.full((num_frames,), fov) 230 | elif option == "spiral": 231 | poses = generate_spiral_path( 232 | torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]), 233 | np.array([1, 5]), 234 | n_frames=num_frames, 235 | n_rots=2, 236 | zrate=0.5, 237 | radii=spiral_radii, 238 | endpoint=False, 239 | ) @ np.diagflat([1, -1, -1, 1]) 240 | poses = np.concatenate( 241 | [ 242 | poses, 243 | np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0), 244 | ], 245 | 1, 246 | ) 247 | # We want the spiral trajectory to always start from start_w2c. Thus we 248 | # apply the relative pose to get the final trajectory. 
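        # inv(poses[:1]) @ poses expresses every spiral pose relative to the first
        # generated one; left-multiplying by the starting camera-to-world matrix
        # (inv(start_w2c)) then re-anchors the whole spiral at start_w2c.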
249 | poses = ( 250 | np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses 251 | ) 252 | fovs = np.full((num_frames,), fov) 253 | elif option == "lemniscate": 254 | poses = torch.linalg.inv( 255 | get_lemniscate_w2cs( 256 | start_w2c, 257 | look_at, 258 | up_direction, 259 | num_frames, 260 | degree=60.0, 261 | endpoint=False, 262 | ) 263 | ).numpy() 264 | fovs = np.full((num_frames,), fov) 265 | elif option == "roll": 266 | poses = torch.linalg.inv( 267 | get_roll_w2cs( 268 | start_w2c, 269 | look_at, 270 | None, 271 | num_frames, 272 | degree=360.0, 273 | endpoint=False, 274 | ) 275 | ).numpy() 276 | fovs = np.full((num_frames,), fov) 277 | elif option in [ 278 | "dolly zoom-in", 279 | "dolly zoom-out", 280 | "zoom-in", 281 | "zoom-out", 282 | ]: 283 | if option.startswith("dolly"): 284 | direction = "backward" if option == "dolly zoom-in" else "forward" 285 | poses = torch.linalg.inv( 286 | get_moving_w2cs( 287 | start_w2c, 288 | look_at, 289 | up_direction, 290 | num_frames, 291 | endpoint=True, 292 | direction=direction, 293 | ) 294 | ).numpy() 295 | else: 296 | poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy() 297 | fov_rad_start = fov 298 | if zoom_factor is None: 299 | zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5 300 | fov_rad_end = zoom_factor * fov 301 | fovs = ( 302 | np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start) 303 | + fov_rad_start 304 | ) 305 | elif option in [ 306 | "move-forward", 307 | "move-backward", 308 | "move-up", 309 | "move-down", 310 | "move-left", 311 | "move-right", 312 | ]: 313 | poses = torch.linalg.inv( 314 | get_moving_w2cs( 315 | start_w2c, 316 | look_at, 317 | up_direction, 318 | num_frames, 319 | endpoint=True, 320 | direction=option.removeprefix("move-"), 321 | ) 322 | ).numpy() 323 | fovs = np.full((num_frames,), fov) 324 | else: 325 | raise ValueError(f"Unknown preset option {option}.") 326 | 327 | return poses, fovs 328 | 329 | 330 | def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor: 331 | """Triangulate a set of rays to find a single lookat point. 332 | 333 | Args: 334 | origins (torch.Tensor): A (N, 3) array of ray origins. 335 | viewdirs (torch.Tensor): A (N, 3) array of ray view directions. 336 | 337 | Returns: 338 | torch.Tensor: A (3,) lookat point. 339 | """ 340 | 341 | viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1) 342 | eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None] 343 | # Calculate projection matrix I - rr^T 344 | I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :]) 345 | # Compute sum of projections 346 | sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3) 347 | # Solve for the intersection point using least squares 348 | lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] 349 | # Check NaNs. 
350 | assert not torch.any(torch.isnan(lookat)) 351 | return lookat 352 | 353 | 354 | def get_lookat_w2cs( 355 | positions: torch.Tensor, 356 | lookat: torch.Tensor, 357 | up: torch.Tensor, 358 | face_off: bool = False, 359 | ): 360 | """ 361 | Args: 362 | positions: (N, 3) tensor of camera positions 363 | lookat: (3,) tensor of lookat point 364 | up: (3,) or (N, 3) tensor of up vector 365 | 366 | Returns: 367 | w2cs: (N, 3, 3) tensor of world to camera rotation matrices 368 | """ 369 | forward_vectors = F.normalize(lookat - positions, dim=-1) 370 | if face_off: 371 | forward_vectors = -forward_vectors 372 | if up.dim() == 1: 373 | up = up[None] 374 | right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1) 375 | down_vectors = F.normalize( 376 | torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1 377 | ) 378 | Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1) 379 | w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions)) 380 | return w2cs 381 | 382 | 383 | def get_arc_horizontal_w2cs( 384 | ref_w2c: torch.Tensor, 385 | lookat: torch.Tensor, 386 | up: torch.Tensor | None, 387 | num_frames: int, 388 | clockwise: bool = True, 389 | face_off: bool = False, 390 | endpoint: bool = False, 391 | degree: float = 360.0, 392 | ref_up_shift: float = 0.0, 393 | ref_radius_scale: float = 1.0, 394 | **_, 395 | ) -> torch.Tensor: 396 | ref_c2w = torch.linalg.inv(ref_w2c) 397 | ref_position = ref_c2w[:3, 3] 398 | if up is None: 399 | up = -ref_c2w[:3, 1] 400 | assert up is not None 401 | ref_position += up * ref_up_shift 402 | ref_position *= ref_radius_scale 403 | thetas = ( 404 | torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device) 405 | if endpoint 406 | else torch.linspace( 407 | 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device 408 | )[:-1] 409 | ) 410 | if not clockwise: 411 | thetas = -thetas 412 | positions = ( 413 | torch.einsum( 414 | "nij,j->ni", 415 | roma.rotvec_to_rotmat(thetas[:, None] * up[None]), 416 | ref_position - lookat, 417 | ) 418 | + lookat 419 | ) 420 | return get_lookat_w2cs(positions, lookat, up, face_off=face_off) 421 | 422 | 423 | def get_lemniscate_w2cs( 424 | ref_w2c: torch.Tensor, 425 | lookat: torch.Tensor, 426 | up: torch.Tensor | None, 427 | num_frames: int, 428 | degree: float, 429 | endpoint: bool = False, 430 | **_, 431 | ) -> torch.Tensor: 432 | ref_c2w = torch.linalg.inv(ref_w2c) 433 | a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi) 434 | # Lemniscate curve in camera space. Starting at the origin. 435 | thetas = ( 436 | torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device) 437 | if endpoint 438 | else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1] 439 | ) + torch.pi / 2 440 | positions = torch.stack( 441 | [ 442 | a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2), 443 | a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2), 444 | torch.zeros(num_frames, device=ref_w2c.device), 445 | ], 446 | dim=-1, 447 | ) 448 | # Transform to world space. 
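    # F.pad(..., value=1.0) makes the positions homogeneous so that the 3x4 matrix
    # ref_c2w[:3] applies both the rotation and the translation in one einsum.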
449 | positions = torch.einsum( 450 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) 451 | ) 452 | if up is None: 453 | up = -ref_c2w[:3, 1] 454 | assert up is not None 455 | return get_lookat_w2cs(positions, lookat, up) 456 | 457 | 458 | def get_moving_w2cs( 459 | ref_w2c: torch.Tensor, 460 | lookat: torch.Tensor, 461 | up: torch.Tensor | None, 462 | num_frames: int, 463 | endpoint: bool = False, 464 | direction: str = "forward", 465 | tilt_xy: torch.Tensor = None, 466 | ): 467 | """ 468 | Args: 469 | ref_w2c: (4, 4) tensor of the reference wolrd-to-camera matrix 470 | lookat: (3,) tensor of lookat point 471 | up: (3,) tensor of up vector 472 | 473 | Returns: 474 | w2cs: (N, 3, 3) tensor of world to camera rotation matrices 475 | """ 476 | ref_c2w = torch.linalg.inv(ref_w2c) 477 | ref_position = ref_c2w[:3, -1] 478 | if up is None: 479 | up = -ref_c2w[:3, 1] 480 | 481 | direction_vectors = { 482 | "forward": (lookat - ref_position).clone(), 483 | "backward": -(lookat - ref_position).clone(), 484 | "up": up.clone(), 485 | "down": -up.clone(), 486 | "right": torch.cross((lookat - ref_position), up, dim=0), 487 | "left": -torch.cross((lookat - ref_position), up, dim=0), 488 | } 489 | if direction not in direction_vectors: 490 | raise ValueError( 491 | f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}" 492 | ) 493 | 494 | positions = ref_position + ( 495 | F.normalize(direction_vectors[direction], dim=0) 496 | * ( 497 | torch.linspace(0, 0.99, num_frames, device=ref_w2c.device) 498 | if endpoint 499 | else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1] 500 | )[:, None] 501 | ) 502 | 503 | if tilt_xy is not None: 504 | positions[:, :2] += tilt_xy 505 | 506 | return get_lookat_w2cs(positions, lookat, up) 507 | 508 | 509 | def get_roll_w2cs( 510 | ref_w2c: torch.Tensor, 511 | lookat: torch.Tensor, 512 | up: torch.Tensor | None, 513 | num_frames: int, 514 | endpoint: bool = False, 515 | degree: float = 360.0, 516 | **_, 517 | ) -> torch.Tensor: 518 | ref_c2w = torch.linalg.inv(ref_w2c) 519 | ref_position = ref_c2w[:3, 3] 520 | if up is None: 521 | up = -ref_c2w[:3, 1] # Infer the up vector from the reference. 
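    # The camera position stays fixed below; only the up vector is rotated about
    # the normalized lookat vector (Rodrigues' rotation formula), which rolls the
    # camera by the angles in thetas.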
522 | 523 | # Create vertical angles 524 | thetas = ( 525 | torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device) 526 | if endpoint 527 | else torch.linspace( 528 | 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device 529 | )[:-1] 530 | )[:, None] 531 | 532 | lookat_vector = F.normalize(lookat[None].float(), dim=-1) 533 | up = up[None] 534 | up = ( 535 | up * torch.cos(thetas) 536 | + torch.cross(lookat_vector, up) * torch.sin(thetas) 537 | + lookat_vector 538 | * torch.einsum("ij,ij->i", lookat_vector, up)[:, None] 539 | * (1 - torch.cos(thetas)) 540 | ) 541 | 542 | # Normalize the camera orientation 543 | return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up) 544 | 545 | 546 | def normalize(x): 547 | """Normalization helper function.""" 548 | return x / np.linalg.norm(x) 549 | 550 | 551 | def viewmatrix(lookdir, up, position, subtract_position=False): 552 | """Construct lookat view matrix.""" 553 | vec2 = normalize((lookdir - position) if subtract_position else lookdir) 554 | vec0 = normalize(np.cross(up, vec2)) 555 | vec1 = normalize(np.cross(vec2, vec0)) 556 | m = np.stack([vec0, vec1, vec2, position], axis=1) 557 | return m 558 | 559 | 560 | def poses_avg(poses): 561 | """New pose using average position, z-axis, and up vector of input poses.""" 562 | position = poses[:, :3, 3].mean(0) 563 | z_axis = poses[:, :3, 2].mean(0) 564 | up = poses[:, :3, 1].mean(0) 565 | cam2world = viewmatrix(z_axis, up, position) 566 | return cam2world 567 | 568 | 569 | def generate_spiral_path( 570 | poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None 571 | ): 572 | """Calculates a forward facing spiral path for rendering.""" 573 | # Find a reasonable 'focus depth' for this dataset as a weighted average 574 | # of near and far bounds in disparity space. 575 | close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0 576 | dt = 0.75 577 | focal = 1 / ((1 - dt) / close_depth + dt / inf_depth) 578 | 579 | # Get radii for spiral path using 90th percentile of camera positions. 580 | positions = poses[:, :3, 3] 581 | if radii is None: 582 | radii = np.percentile(np.abs(positions), 90, 0) 583 | radii = np.concatenate([radii, [1.0]]) 584 | 585 | # Generate poses for spiral path. 586 | render_poses = [] 587 | cam2world = poses_avg(poses) 588 | up = poses[:, :3, 1].mean(0) 589 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint): 590 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0] 591 | position = cam2world @ t 592 | lookat = cam2world @ [0, 0, -focal, 1.0] 593 | z_axis = position - lookat 594 | render_poses.append(viewmatrix(z_axis, up, position)) 595 | render_poses = np.stack(render_poses, axis=0) 596 | return render_poses 597 | 598 | 599 | def generate_interpolated_path( 600 | poses: np.ndarray, 601 | n_interp: int, 602 | spline_degree: int = 5, 603 | smoothness: float = 0.03, 604 | rot_weight: float = 0.1, 605 | endpoint: bool = False, 606 | ): 607 | """Creates a smooth spline path between input keyframe camera poses. 608 | 609 | Spline is calculated with poses in format (position, lookat-point, up-point). 610 | 611 | Args: 612 | poses: (n, 3, 4) array of input pose keyframes. 613 | n_interp: returned path will have n_interp * (n - 1) total poses. 614 | spline_degree: polynomial degree of B-spline. 615 | smoothness: parameter for spline smoothing, 0 forces exact interpolation. 616 | rot_weight: relative weighting of rotation/translation in spline solve. 
617 | 618 | Returns: 619 | Array of new camera poses with shape (n_interp * (n - 1), 3, 4). 620 | """ 621 | 622 | def poses_to_points(poses, dist): 623 | """Converts from pose matrices to (position, lookat, up) format.""" 624 | pos = poses[:, :3, -1] 625 | lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] 626 | up = poses[:, :3, -1] + dist * poses[:, :3, 1] 627 | return np.stack([pos, lookat, up], 1) 628 | 629 | def points_to_poses(points): 630 | """Converts from (position, lookat, up) format to pose matrices.""" 631 | return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) 632 | 633 | def interp(points, n, k, s): 634 | """Runs multidimensional B-spline interpolation on the input points.""" 635 | sh = points.shape 636 | pts = np.reshape(points, (sh[0], -1)) 637 | k = min(k, sh[0] - 1) 638 | tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) 639 | u = np.linspace(0, 1, n, endpoint=endpoint) 640 | new_points = np.array(scipy.interpolate.splev(u, tck)) 641 | new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) 642 | return new_points 643 | 644 | points = poses_to_points(poses, dist=rot_weight) 645 | new_points = interp( 646 | points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness 647 | ) 648 | return points_to_poses(new_points) 649 | 650 | 651 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): 652 | """ 653 | reference: nerf-factory 654 | Get a similarity transform to normalize dataset 655 | from c2w (OpenCV convention) cameras 656 | :param c2w: (N, 4) 657 | :return T (4,4) , scale (float) 658 | """ 659 | t = c2w[:, :3, 3] 660 | R = c2w[:, :3, :3] 661 | 662 | # (1) Rotate the world so that z+ is the up axis 663 | # we estimate the up axis by averaging the camera up axes 664 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 665 | world_up = np.mean(ups, axis=0) 666 | world_up /= np.linalg.norm(world_up) 667 | 668 | up_camspace = np.array([0.0, -1.0, 0.0]) 669 | c = (up_camspace * world_up).sum() 670 | cross = np.cross(world_up, up_camspace) 671 | skew = np.array( 672 | [ 673 | [0.0, -cross[2], cross[1]], 674 | [cross[2], 0.0, -cross[0]], 675 | [-cross[1], cross[0], 0.0], 676 | ] 677 | ) 678 | if c > -1: 679 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 680 | else: 681 | # In the unlikely case the original data has y+ up axis, 682 | # rotate 180-deg about x axis 683 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 684 | 685 | # R_align = np.eye(3) # DEBUG 686 | R = R_align @ R 687 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 688 | t = (R_align @ t[..., None])[..., 0] 689 | 690 | # (2) Recenter the scene. 
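    # "focus" centering finds, for each camera, the point on its viewing ray that
    # is closest to the origin and translates the median of those points to the
    # origin; "poses" centering simply uses the median camera position instead.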
691 | if center_method == "focus": 692 | # find the closest point to the origin for each camera's center ray 693 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 694 | translate = -np.median(nearest, axis=0) 695 | elif center_method == "poses": 696 | # use center of the camera positions 697 | translate = -np.median(t, axis=0) 698 | else: 699 | raise ValueError(f"Unknown center_method {center_method}") 700 | 701 | transform = np.eye(4) 702 | transform[:3, 3] = translate 703 | transform[:3, :3] = R_align 704 | 705 | # (3) Rescale the scene using camera distances 706 | scale_fn = np.max if strict_scaling else np.median 707 | inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1)) 708 | if inv_scale == 0: 709 | inv_scale = 1.0 710 | scale = 1.0 / inv_scale 711 | transform[:3, :] *= scale 712 | 713 | return transform 714 | 715 | 716 | def align_principle_axes(point_cloud): 717 | # Compute centroid 718 | centroid = np.median(point_cloud, axis=0) 719 | 720 | # Translate point cloud to centroid 721 | translated_point_cloud = point_cloud - centroid 722 | 723 | # Compute covariance matrix 724 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False) 725 | 726 | # Compute eigenvectors and eigenvalues 727 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) 728 | 729 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis 730 | # is the principal axis with the smallest eigenvalue. 731 | sort_indices = eigenvalues.argsort()[::-1] 732 | eigenvectors = eigenvectors[:, sort_indices] 733 | 734 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is 735 | # negative, then we need to flip the sign of one of the eigenvectors. 736 | if np.linalg.det(eigenvectors) < 0: 737 | eigenvectors[:, 0] *= -1 738 | 739 | # Create rotation matrix 740 | rotation_matrix = eigenvectors.T 741 | 742 | # Create SE(3) matrix (4x4 transformation matrix) 743 | transform = np.eye(4) 744 | transform[:3, :3] = rotation_matrix 745 | transform[:3, 3] = -rotation_matrix @ centroid 746 | 747 | return transform 748 | 749 | 750 | def transform_points(matrix, points): 751 | """Transform points using a SE(4) matrix. 752 | 753 | Args: 754 | matrix: 4x4 SE(4) matrix 755 | points: Nx3 array of points 756 | 757 | Returns: 758 | Nx3 array of transformed points 759 | """ 760 | assert matrix.shape == (4, 4) 761 | assert len(points.shape) == 2 and points.shape[1] == 3 762 | return points @ matrix[:3, :3].T + matrix[:3, 3] 763 | 764 | 765 | def transform_cameras(matrix, camtoworlds): 766 | """Transform cameras using a SE(4) matrix. 
767 | 768 | Args: 769 | matrix: 4x4 SE(4) matrix 770 | camtoworlds: Nx4x4 array of camera-to-world matrices 771 | 772 | Returns: 773 | Nx4x4 array of transformed camera-to-world matrices 774 | """ 775 | assert matrix.shape == (4, 4) 776 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) 777 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) 778 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) 779 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] 780 | return camtoworlds 781 | 782 | 783 | def normalize_scene(camtoworlds, points=None, camera_center_method="focus"): 784 | T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method) 785 | camtoworlds = transform_cameras(T1, camtoworlds) 786 | if points is not None: 787 | points = transform_points(T1, points) 788 | T2 = align_principle_axes(points) 789 | camtoworlds = transform_cameras(T2, camtoworlds) 790 | points = transform_points(T2, points) 791 | return camtoworlds, points, T2 @ T1 792 | else: 793 | return camtoworlds, T1 794 | -------------------------------------------------------------------------------- /seva/gui.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import dataclasses 3 | import threading 4 | import time 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import scipy 9 | import splines 10 | import splines.quaternion 11 | import torch 12 | import viser 13 | import viser.transforms as vt 14 | 15 | from seva.geometry import get_preset_pose_fov 16 | 17 | 18 | @dataclasses.dataclass 19 | class Keyframe(object): 20 | position: np.ndarray 21 | wxyz: np.ndarray 22 | override_fov_enabled: bool 23 | override_fov_rad: float 24 | aspect: float 25 | override_transition_enabled: bool 26 | override_transition_sec: float | None 27 | 28 | @staticmethod 29 | def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe": 30 | return Keyframe( 31 | camera.position, 32 | camera.wxyz, 33 | override_fov_enabled=False, 34 | override_fov_rad=camera.fov, 35 | aspect=aspect, 36 | override_transition_enabled=False, 37 | override_transition_sec=None, 38 | ) 39 | 40 | @staticmethod 41 | def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe": 42 | return Keyframe( 43 | se3.translation(), 44 | se3.rotation().wxyz, 45 | override_fov_enabled=False, 46 | override_fov_rad=fov, 47 | aspect=aspect, 48 | override_transition_enabled=False, 49 | override_transition_sec=None, 50 | ) 51 | 52 | 53 | class CameraTrajectory(object): 54 | def __init__( 55 | self, 56 | server: viser.ViserServer, 57 | duration_element: viser.GuiInputHandle[float], 58 | scene_scale: float, 59 | scene_node_prefix: str = "/", 60 | ): 61 | self._server = server 62 | self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {} 63 | self._keyframe_counter: int = 0 64 | self._spline_nodes: list[viser.SceneNodeHandle] = [] 65 | self._camera_edit_panel: viser.Gui3dContainerHandle | None = None 66 | 67 | self._orientation_spline: splines.quaternion.KochanekBartels | None = None 68 | self._position_spline: splines.KochanekBartels | None = None 69 | self._fov_spline: splines.KochanekBartels | None = None 70 | 71 | self._keyframes_visible: bool = True 72 | 73 | self._duration_element = duration_element 74 | self._scene_node_prefix = scene_node_prefix 75 | 76 | self.scene_scale = scene_scale 77 | # These parameters should be overridden externally. 
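        # The values below are placeholders (e.g. a default_fov of 0.0 would make
        # newly added keyframe frustums degenerate), so callers are expected to set
        # them before keyframes are added.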
78 | self.loop: bool = False 79 | self.framerate: float = 30.0 80 | self.tension: float = 0.0 # Tension / alpha term. 81 | self.default_fov: float = 0.0 82 | self.default_transition_sec: float = 0.0 83 | self.show_spline: bool = True 84 | 85 | def set_keyframes_visible(self, visible: bool) -> None: 86 | self._keyframes_visible = visible 87 | for keyframe in self._keyframes.values(): 88 | keyframe[1].visible = visible 89 | 90 | def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None: 91 | """Add a new camera, or replace an old one if `keyframe_index` is passed in.""" 92 | server = self._server 93 | 94 | # Add a keyframe if we aren't replacing an existing one. 95 | if keyframe_index is None: 96 | keyframe_index = self._keyframe_counter 97 | self._keyframe_counter += 1 98 | 99 | print( 100 | f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}" 101 | ) 102 | frustum_handle = server.scene.add_camera_frustum( 103 | str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"), 104 | fov=( 105 | keyframe.override_fov_rad 106 | if keyframe.override_fov_enabled 107 | else self.default_fov 108 | ), 109 | aspect=keyframe.aspect, 110 | scale=0.1 * self.scene_scale, 111 | color=(200, 10, 30), 112 | wxyz=keyframe.wxyz, 113 | position=keyframe.position, 114 | visible=self._keyframes_visible, 115 | ) 116 | self._server.scene.add_icosphere( 117 | str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"), 118 | radius=0.03, 119 | color=(200, 10, 30), 120 | ) 121 | 122 | @frustum_handle.on_click 123 | def _(_) -> None: 124 | if self._camera_edit_panel is not None: 125 | self._camera_edit_panel.remove() 126 | self._camera_edit_panel = None 127 | 128 | with server.scene.add_3d_gui_container( 129 | "/camera_edit_panel", 130 | position=keyframe.position, 131 | ) as camera_edit_panel: 132 | self._camera_edit_panel = camera_edit_panel 133 | override_fov = server.gui.add_checkbox( 134 | "Override FOV", initial_value=keyframe.override_fov_enabled 135 | ) 136 | override_fov_degrees = server.gui.add_slider( 137 | "Override FOV (degrees)", 138 | 5.0, 139 | 175.0, 140 | step=0.1, 141 | initial_value=keyframe.override_fov_rad * 180.0 / np.pi, 142 | disabled=not keyframe.override_fov_enabled, 143 | ) 144 | delete_button = server.gui.add_button( 145 | "Delete", color="red", icon=viser.Icon.TRASH 146 | ) 147 | go_to_button = server.gui.add_button("Go to") 148 | close_button = server.gui.add_button("Close") 149 | 150 | @override_fov.on_update 151 | def _(_) -> None: 152 | keyframe.override_fov_enabled = override_fov.value 153 | override_fov_degrees.disabled = not override_fov.value 154 | self.add_camera(keyframe, keyframe_index) 155 | 156 | @override_fov_degrees.on_update 157 | def _(_) -> None: 158 | keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi 159 | self.add_camera(keyframe, keyframe_index) 160 | 161 | @delete_button.on_click 162 | def _(event: viser.GuiEvent) -> None: 163 | assert event.client is not None 164 | with event.client.gui.add_modal("Confirm") as modal: 165 | event.client.gui.add_markdown("Delete keyframe?") 166 | confirm_button = event.client.gui.add_button( 167 | "Yes", color="red", icon=viser.Icon.TRASH 168 | ) 169 | exit_button = event.client.gui.add_button("Cancel") 170 | 171 | @confirm_button.on_click 172 | def _(_) -> None: 173 | assert camera_edit_panel is not None 174 | 175 | keyframe_id = None 176 | for i, keyframe_tuple in self._keyframes.items(): 177 | if keyframe_tuple[1] is frustum_handle: 178 | 
keyframe_id = i 179 | break 180 | assert keyframe_id is not None 181 | 182 | self._keyframes.pop(keyframe_id) 183 | frustum_handle.remove() 184 | camera_edit_panel.remove() 185 | self._camera_edit_panel = None 186 | modal.close() 187 | self.update_spline() 188 | 189 | @exit_button.on_click 190 | def _(_) -> None: 191 | modal.close() 192 | 193 | @go_to_button.on_click 194 | def _(event: viser.GuiEvent) -> None: 195 | assert event.client is not None 196 | client = event.client 197 | T_world_current = vt.SE3.from_rotation_and_translation( 198 | vt.SO3(client.camera.wxyz), client.camera.position 199 | ) 200 | T_world_target = vt.SE3.from_rotation_and_translation( 201 | vt.SO3(keyframe.wxyz), keyframe.position 202 | ) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5])) 203 | 204 | T_current_target = T_world_current.inverse() @ T_world_target 205 | 206 | for j in range(10): 207 | T_world_set = T_world_current @ vt.SE3.exp( 208 | T_current_target.log() * j / 9.0 209 | ) 210 | 211 | # Important bit: we atomically set both the orientation and 212 | # the position of the camera. 213 | with client.atomic(): 214 | client.camera.wxyz = T_world_set.rotation().wxyz 215 | client.camera.position = T_world_set.translation() 216 | time.sleep(1.0 / 30.0) 217 | 218 | @close_button.on_click 219 | def _(_) -> None: 220 | assert camera_edit_panel is not None 221 | camera_edit_panel.remove() 222 | self._camera_edit_panel = None 223 | 224 | self._keyframes[keyframe_index] = (keyframe, frustum_handle) 225 | 226 | def update_aspect(self, aspect: float) -> None: 227 | for keyframe_index, frame in self._keyframes.items(): 228 | frame = dataclasses.replace(frame[0], aspect=aspect) 229 | self.add_camera(frame, keyframe_index=keyframe_index) 230 | 231 | def get_aspect(self) -> float: 232 | """Get W/H aspect ratio, which is shared across all keyframes.""" 233 | assert len(self._keyframes) > 0 234 | return next(iter(self._keyframes.values()))[0].aspect 235 | 236 | def reset(self) -> None: 237 | for frame in self._keyframes.values(): 238 | print(f"removing {frame[1]}") 239 | frame[1].remove() 240 | self._keyframes.clear() 241 | self.update_spline() 242 | print("camera traj reset") 243 | 244 | def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray: 245 | """From a time value in seconds, compute a t value for our geometric 246 | spline interpolation. An increment of 1 for the latter will move the 247 | camera forward by one keyframe. 248 | 249 | We use a PCHIP spline here to guarantee monotonicity. 250 | """ 251 | transition_times_cumsum = self.compute_transition_times_cumsum() 252 | spline_indices = np.arange(transition_times_cumsum.shape[0]) 253 | 254 | if self.loop: 255 | # In the case of a loop, we pad the spline to match the start/end 256 | # slopes. 257 | interpolator = scipy.interpolate.PchipInterpolator( 258 | x=np.concatenate( 259 | [ 260 | [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])], 261 | transition_times_cumsum, 262 | transition_times_cumsum[-1:] + transition_times_cumsum[1:2], 263 | ], 264 | axis=0, 265 | ), 266 | y=np.concatenate( 267 | [[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore 268 | axis=0, 269 | ), 270 | ) 271 | else: 272 | interpolator = scipy.interpolate.PchipInterpolator( 273 | x=transition_times_cumsum, y=spline_indices 274 | ) 275 | 276 | # Clip to account for floating point error. 
277 | return np.clip(interpolator(time), 0, spline_indices[-1]) 278 | 279 | def interpolate_pose_and_fov_rad( 280 | self, normalized_t: float 281 | ) -> tuple[vt.SE3, float] | None: 282 | if len(self._keyframes) < 2: 283 | return None 284 | 285 | self._fov_spline = splines.KochanekBartels( 286 | [ 287 | ( 288 | keyframe[0].override_fov_rad 289 | if keyframe[0].override_fov_enabled 290 | else self.default_fov 291 | ) 292 | for keyframe in self._keyframes.values() 293 | ], 294 | tcb=(self.tension, 0.0, 0.0), 295 | endconditions="closed" if self.loop else "natural", 296 | ) 297 | 298 | assert self._orientation_spline is not None 299 | assert self._position_spline is not None 300 | assert self._fov_spline is not None 301 | 302 | max_t = self.compute_duration() 303 | t = max_t * normalized_t 304 | spline_t = float(self.spline_t_from_t_sec(np.array(t))) 305 | 306 | quat = self._orientation_spline.evaluate(spline_t) 307 | assert isinstance(quat, splines.quaternion.UnitQuaternion) 308 | return ( 309 | vt.SE3.from_rotation_and_translation( 310 | vt.SO3(np.array([quat.scalar, *quat.vector])), 311 | self._position_spline.evaluate(spline_t), 312 | ), 313 | float(self._fov_spline.evaluate(spline_t)), 314 | ) 315 | 316 | def update_spline(self) -> None: 317 | num_frames = int(self.compute_duration() * self.framerate) 318 | keyframes = list(self._keyframes.values()) 319 | 320 | if num_frames <= 0 or not self.show_spline or len(keyframes) < 2: 321 | for node in self._spline_nodes: 322 | node.remove() 323 | self._spline_nodes.clear() 324 | return 325 | 326 | transition_times_cumsum = self.compute_transition_times_cumsum() 327 | 328 | self._orientation_spline = splines.quaternion.KochanekBartels( 329 | [ 330 | splines.quaternion.UnitQuaternion.from_unit_xyzw( 331 | np.roll(keyframe[0].wxyz, shift=-1) 332 | ) 333 | for keyframe in keyframes 334 | ], 335 | tcb=(self.tension, 0.0, 0.0), 336 | endconditions="closed" if self.loop else "natural", 337 | ) 338 | self._position_spline = splines.KochanekBartels( 339 | [keyframe[0].position for keyframe in keyframes], 340 | tcb=(self.tension, 0.0, 0.0), 341 | endconditions="closed" if self.loop else "natural", 342 | ) 343 | 344 | # Update visualized spline. 345 | points_array = self._position_spline.evaluate( 346 | self.spline_t_from_t_sec( 347 | np.linspace(0, transition_times_cumsum[-1], num_frames) 348 | ) 349 | ) 350 | colors_array = np.array( 351 | [ 352 | colorsys.hls_to_rgb(h, 0.5, 1.0) 353 | for h in np.linspace(0.0, 1.0, len(points_array)) 354 | ] 355 | ) 356 | 357 | # Clear prior spline nodes. 
358 | for node in self._spline_nodes: 359 | node.remove() 360 | self._spline_nodes.clear() 361 | 362 | self._spline_nodes.append( 363 | self._server.scene.add_spline_catmull_rom( 364 | str(Path(self._scene_node_prefix) / "camera_spline"), 365 | positions=points_array, 366 | color=(220, 220, 220), 367 | closed=self.loop, 368 | line_width=1.0, 369 | segments=points_array.shape[0] + 1, 370 | ) 371 | ) 372 | self._spline_nodes.append( 373 | self._server.scene.add_point_cloud( 374 | str(Path(self._scene_node_prefix) / "camera_spline/points"), 375 | points=points_array, 376 | colors=colors_array, 377 | point_size=0.04, 378 | ) 379 | ) 380 | 381 | def make_transition_handle(i: int) -> None: 382 | assert self._position_spline is not None 383 | transition_pos = self._position_spline.evaluate( 384 | float( 385 | self.spline_t_from_t_sec( 386 | (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) 387 | / 2.0, 388 | ) 389 | ) 390 | ) 391 | transition_sphere = self._server.scene.add_icosphere( 392 | str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"), 393 | radius=0.04, 394 | color=(255, 0, 0), 395 | position=transition_pos, 396 | ) 397 | self._spline_nodes.append(transition_sphere) 398 | 399 | @transition_sphere.on_click 400 | def _(_) -> None: 401 | server = self._server 402 | 403 | if self._camera_edit_panel is not None: 404 | self._camera_edit_panel.remove() 405 | self._camera_edit_panel = None 406 | 407 | keyframe_index = (i + 1) % len(self._keyframes) 408 | keyframe = keyframes[keyframe_index][0] 409 | 410 | with server.scene.add_3d_gui_container( 411 | "/camera_edit_panel", 412 | position=transition_pos, 413 | ) as camera_edit_panel: 414 | self._camera_edit_panel = camera_edit_panel 415 | override_transition_enabled = server.gui.add_checkbox( 416 | "Override transition", 417 | initial_value=keyframe.override_transition_enabled, 418 | ) 419 | override_transition_sec = server.gui.add_number( 420 | "Override transition (sec)", 421 | initial_value=( 422 | keyframe.override_transition_sec 423 | if keyframe.override_transition_sec is not None 424 | else self.default_transition_sec 425 | ), 426 | min=0.001, 427 | max=30.0, 428 | step=0.001, 429 | disabled=not override_transition_enabled.value, 430 | ) 431 | close_button = server.gui.add_button("Close") 432 | 433 | @override_transition_enabled.on_update 434 | def _(_) -> None: 435 | keyframe.override_transition_enabled = ( 436 | override_transition_enabled.value 437 | ) 438 | override_transition_sec.disabled = ( 439 | not override_transition_enabled.value 440 | ) 441 | self._duration_element.value = self.compute_duration() 442 | 443 | @override_transition_sec.on_update 444 | def _(_) -> None: 445 | keyframe.override_transition_sec = override_transition_sec.value 446 | self._duration_element.value = self.compute_duration() 447 | 448 | @close_button.on_click 449 | def _(_) -> None: 450 | assert camera_edit_panel is not None 451 | camera_edit_panel.remove() 452 | self._camera_edit_panel = None 453 | 454 | (num_transitions_plus_1,) = transition_times_cumsum.shape 455 | for i in range(num_transitions_plus_1 - 1): 456 | make_transition_handle(i) 457 | 458 | def compute_duration(self) -> float: 459 | """Compute the total duration of the trajectory.""" 460 | total = 0.0 461 | for i, (keyframe, frustum) in enumerate(self._keyframes.values()): 462 | if i == 0 and not self.loop: 463 | continue 464 | del frustum 465 | total += ( 466 | keyframe.override_transition_sec 467 | if keyframe.override_transition_enabled 468 | and 
keyframe.override_transition_sec is not None 469 | else self.default_transition_sec 470 | ) 471 | return total 472 | 473 | def compute_transition_times_cumsum(self) -> np.ndarray: 474 | """Compute the total duration of the trajectory.""" 475 | total = 0.0 476 | out = [0.0] 477 | for i, (keyframe, frustum) in enumerate(self._keyframes.values()): 478 | if i == 0: 479 | continue 480 | del frustum 481 | total += ( 482 | keyframe.override_transition_sec 483 | if keyframe.override_transition_enabled 484 | and keyframe.override_transition_sec is not None 485 | else self.default_transition_sec 486 | ) 487 | out.append(total) 488 | 489 | if self.loop: 490 | keyframe = next(iter(self._keyframes.values()))[0] 491 | total += ( 492 | keyframe.override_transition_sec 493 | if keyframe.override_transition_enabled 494 | and keyframe.override_transition_sec is not None 495 | else self.default_transition_sec 496 | ) 497 | out.append(total) 498 | 499 | return np.array(out) 500 | 501 | 502 | @dataclasses.dataclass 503 | class GuiState: 504 | preview_render: bool 505 | preview_fov: float 506 | preview_aspect: float 507 | camera_traj_list: list | None 508 | active_input_index: int 509 | 510 | 511 | def define_gui( 512 | server: viser.ViserServer, 513 | init_fov: float = 75.0, 514 | img_wh: tuple[int, int] = (576, 576), 515 | **kwargs, 516 | ) -> GuiState: 517 | gui_state = GuiState( 518 | preview_render=False, 519 | preview_fov=0.0, 520 | preview_aspect=1.0, 521 | camera_traj_list=None, 522 | active_input_index=0, 523 | ) 524 | 525 | with server.gui.add_folder( 526 | "Preset camera trajectories", order=99, expand_by_default=False 527 | ): 528 | preset_traj_dropdown = server.gui.add_dropdown( 529 | "Options", 530 | [ 531 | "orbit", 532 | "spiral", 533 | "lemniscate", 534 | "zoom-out", 535 | "dolly zoom-out", 536 | ], 537 | initial_value="orbit", 538 | hint="Select a preset camera trajectory.", 539 | ) 540 | preset_duration_num = server.gui.add_number( 541 | "Duration (sec)", 542 | min=1.0, 543 | max=60.0, 544 | step=0.5, 545 | initial_value=2.0, 546 | ) 547 | preset_submit_button = server.gui.add_button( 548 | "Submit", 549 | icon=viser.Icon.PICK, 550 | hint="Add a new keyframe at the current pose.", 551 | ) 552 | 553 | @preset_submit_button.on_click 554 | def _(event: viser.GuiEvent) -> None: 555 | camera_traj.reset() 556 | gui_state.camera_traj_list = None 557 | 558 | duration = preset_duration_num.value 559 | fps = framerate_number.value 560 | num_frames = int(duration * fps) 561 | transition_sec = duration / num_frames 562 | transition_sec_number.value = transition_sec 563 | assert event.client_id is not None 564 | transition_sec_number.disabled = True 565 | loop_checkbox.disabled = True 566 | add_keyframe_button.disabled = True 567 | 568 | camera = server.get_clients()[event.client_id].camera 569 | start_w2c = torch.linalg.inv( 570 | torch.as_tensor( 571 | vt.SE3.from_rotation_and_translation( 572 | vt.SO3(camera.wxyz), camera.position 573 | ).as_matrix(), 574 | dtype=torch.float32, 575 | ) 576 | ) 577 | look_at = torch.as_tensor(camera.look_at, dtype=torch.float32) 578 | up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32) 579 | poses, fovs = get_preset_pose_fov( 580 | option=preset_traj_dropdown.value, # type: ignore 581 | num_frames=num_frames, 582 | start_w2c=start_w2c, 583 | look_at=look_at, 584 | up_direction=up_direction, 585 | fov=camera.fov, 586 | ) 587 | assert poses is not None and fovs is not None 588 | for pose, fov in zip(poses, fovs): 589 | camera_traj.add_camera( 590 | 
Keyframe.from_se3( 591 | vt.SE3.from_matrix(pose), 592 | fov=fov, 593 | aspect=img_wh[0] / img_wh[1], 594 | ) 595 | ) 596 | 597 | duration_number.value = camera_traj.compute_duration() 598 | camera_traj.update_spline() 599 | 600 | with server.gui.add_folder("Advanced", expand_by_default=False, order=100): 601 | transition_sec_number = server.gui.add_number( 602 | "Transition (sec)", 603 | min=0.001, 604 | max=30.0, 605 | step=0.001, 606 | initial_value=1.5, 607 | hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.", 608 | ) 609 | framerate_number = server.gui.add_number( 610 | "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0 611 | ) 612 | framerate_buttons = server.gui.add_button_group("", ("24", "30", "60")) 613 | duration_number = server.gui.add_number( 614 | "Duration (sec)", 615 | min=0.0, 616 | max=1e8, 617 | step=0.001, 618 | initial_value=0.0, 619 | disabled=True, 620 | ) 621 | 622 | @framerate_buttons.on_click 623 | def _(_) -> None: 624 | framerate_number.value = float(framerate_buttons.value) 625 | 626 | fov_degree_slider = server.gui.add_slider( 627 | "FOV", 628 | initial_value=init_fov, 629 | min=0.1, 630 | max=175.0, 631 | step=0.01, 632 | hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.", 633 | ) 634 | 635 | @fov_degree_slider.on_update 636 | def _(_) -> None: 637 | fov_radians = fov_degree_slider.value / 180.0 * np.pi 638 | for client in server.get_clients().values(): 639 | client.camera.fov = fov_radians 640 | camera_traj.default_fov = fov_radians 641 | 642 | # Updating the aspect ratio will also re-render the camera frustums. 643 | # Could rethink this. 644 | camera_traj.update_aspect(img_wh[0] / img_wh[1]) 645 | compute_and_update_preview_camera_state() 646 | 647 | scene_node_prefix = "/render_assets" 648 | base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False) 649 | add_keyframe_button = server.gui.add_button( 650 | "Add keyframe", 651 | icon=viser.Icon.PLUS, 652 | hint="Add a new keyframe at the current pose.", 653 | ) 654 | 655 | @add_keyframe_button.on_click 656 | def _(event: viser.GuiEvent) -> None: 657 | assert event.client_id is not None 658 | camera = server.get_clients()[event.client_id].camera 659 | pose = vt.SE3.from_rotation_and_translation( 660 | vt.SO3(camera.wxyz), camera.position 661 | ) 662 | print(f"client {event.client_id} at {camera.position} {camera.wxyz}") 663 | print(f"camera pose {pose.as_matrix()}") 664 | 665 | # Add this camera to the trajectory. 
666 | camera_traj.add_camera( 667 | Keyframe.from_camera( 668 | camera, 669 | aspect=img_wh[0] / img_wh[1], 670 | ), 671 | ) 672 | duration_number.value = camera_traj.compute_duration() 673 | camera_traj.update_spline() 674 | 675 | clear_keyframes_button = server.gui.add_button( 676 | "Clear keyframes", 677 | icon=viser.Icon.TRASH, 678 | hint="Remove all keyframes from the render trajectory.", 679 | ) 680 | 681 | @clear_keyframes_button.on_click 682 | def _(event: viser.GuiEvent) -> None: 683 | assert event.client_id is not None 684 | client = server.get_clients()[event.client_id] 685 | with client.atomic(), client.gui.add_modal("Confirm") as modal: 686 | client.gui.add_markdown("Clear all keyframes?") 687 | confirm_button = client.gui.add_button( 688 | "Yes", color="red", icon=viser.Icon.TRASH 689 | ) 690 | exit_button = client.gui.add_button("Cancel") 691 | 692 | @confirm_button.on_click 693 | def _(_) -> None: 694 | camera_traj.reset() 695 | modal.close() 696 | 697 | duration_number.value = camera_traj.compute_duration() 698 | add_keyframe_button.disabled = False 699 | transition_sec_number.disabled = False 700 | transition_sec_number.value = 1.5 701 | loop_checkbox.disabled = False 702 | 703 | nonlocal gui_state 704 | gui_state.camera_traj_list = None 705 | 706 | @exit_button.on_click 707 | def _(_) -> None: 708 | modal.close() 709 | 710 | play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY) 711 | pause_button = server.gui.add_button( 712 | "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False 713 | ) 714 | 715 | # Poll the play button to see if we should be playing endlessly. 716 | def play() -> None: 717 | while True: 718 | while not play_button.visible: 719 | max_frame = int(framerate_number.value * duration_number.value) 720 | if max_frame > 0: 721 | assert preview_frame_slider is not None 722 | preview_frame_slider.value = ( 723 | preview_frame_slider.value + 1 724 | ) % max_frame 725 | time.sleep(1.0 / framerate_number.value) 726 | time.sleep(0.1) 727 | 728 | threading.Thread(target=play).start() 729 | 730 | # Play the camera trajectory when the play button is pressed. 731 | @play_button.on_click 732 | def _(_) -> None: 733 | play_button.visible = False 734 | pause_button.visible = True 735 | 736 | # Play the camera trajectory when the play button is pressed. 737 | @pause_button.on_click 738 | def _(_) -> None: 739 | play_button.visible = True 740 | pause_button.visible = False 741 | 742 | preview_render_button = server.gui.add_button( 743 | "Preview render", 744 | hint="Show a preview of the render in the viewport.", 745 | icon=viser.Icon.CAMERA_CHECK, 746 | ) 747 | preview_render_stop_button = server.gui.add_button( 748 | "Exit render preview", 749 | color="red", 750 | icon=viser.Icon.CAMERA_CANCEL, 751 | visible=False, 752 | ) 753 | 754 | @preview_render_button.on_click 755 | def _(_) -> None: 756 | gui_state.preview_render = True 757 | preview_render_button.visible = False 758 | preview_render_stop_button.visible = True 759 | play_button.visible = False 760 | pause_button.visible = True 761 | preset_submit_button.disabled = True 762 | 763 | maybe_pose_and_fov_rad = compute_and_update_preview_camera_state() 764 | if maybe_pose_and_fov_rad is None: 765 | remove_preview_camera() 766 | return 767 | pose, fov = maybe_pose_and_fov_rad 768 | del fov 769 | 770 | # Hide all render assets when we're previewing the render. 771 | nonlocal base_scene_node 772 | base_scene_node.visible = False 773 | 774 | # Back up and then set camera poses. 
775 | for client in server.get_clients().values(): 776 | camera_pose_backup_from_id[client.client_id] = ( 777 | client.camera.position, 778 | client.camera.look_at, 779 | client.camera.up_direction, 780 | ) 781 | with client.atomic(): 782 | client.camera.wxyz = pose.rotation().wxyz 783 | client.camera.position = pose.translation() 784 | 785 | def stop_preview_render() -> None: 786 | gui_state.preview_render = False 787 | preview_render_button.visible = True 788 | preview_render_stop_button.visible = False 789 | play_button.visible = True 790 | pause_button.visible = False 791 | preset_submit_button.disabled = False 792 | 793 | # Revert camera poses. 794 | for client in server.get_clients().values(): 795 | if client.client_id not in camera_pose_backup_from_id: 796 | continue 797 | cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop( 798 | client.client_id 799 | ) 800 | with client.atomic(): 801 | client.camera.position = cam_position 802 | client.camera.look_at = cam_look_at 803 | client.camera.up_direction = cam_up 804 | client.flush() 805 | 806 | # Un-hide render assets. 807 | nonlocal base_scene_node 808 | base_scene_node.visible = True 809 | remove_preview_camera() 810 | 811 | @preview_render_stop_button.on_click 812 | def _(_) -> None: 813 | stop_preview_render() 814 | 815 | def get_max_frame_index() -> int: 816 | return max(1, int(framerate_number.value * duration_number.value) - 1) 817 | 818 | def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None: 819 | """Helper for creating the current frame # slider. This is removed and 820 | re-added anytime the `max` value changes.""" 821 | 822 | preview_frame_slider = server.gui.add_slider( 823 | "Preview frame", 824 | min=0, 825 | max=get_max_frame_index(), 826 | step=1, 827 | initial_value=0, 828 | order=set_traj_button.order + 0.01, 829 | disabled=get_max_frame_index() == 1, 830 | ) 831 | play_button.disabled = preview_frame_slider.disabled 832 | preview_render_button.disabled = preview_frame_slider.disabled 833 | set_traj_button.disabled = preview_frame_slider.disabled 834 | 835 | @preview_frame_slider.on_update 836 | def _(_) -> None: 837 | nonlocal preview_camera_handle 838 | maybe_pose_and_fov_rad = compute_and_update_preview_camera_state() 839 | if maybe_pose_and_fov_rad is None: 840 | return 841 | pose, fov_rad = maybe_pose_and_fov_rad 842 | 843 | preview_camera_handle = server.scene.add_camera_frustum( 844 | str(Path(scene_node_prefix) / "preview_camera"), 845 | fov=fov_rad, 846 | aspect=img_wh[0] / img_wh[1], 847 | scale=0.35, 848 | wxyz=pose.rotation().wxyz, 849 | position=pose.translation(), 850 | color=(10, 200, 30), 851 | ) 852 | if gui_state.preview_render: 853 | for client in server.get_clients().values(): 854 | with client.atomic(): 855 | client.camera.wxyz = pose.rotation().wxyz 856 | client.camera.position = pose.translation() 857 | 858 | return preview_frame_slider 859 | 860 | set_traj_button = server.gui.add_button( 861 | "Set camera trajectory", 862 | color="green", 863 | icon=viser.Icon.CHECK, 864 | hint="Save the camera trajectory for rendering.", 865 | ) 866 | 867 | @set_traj_button.on_click 868 | def _(event: viser.GuiEvent) -> None: 869 | assert event.client is not None 870 | num_frames = int(framerate_number.value * duration_number.value) 871 | 872 | def get_intrinsics(W, H, fov_rad): 873 | focal = 0.5 * H / np.tan(0.5 * fov_rad) 874 | return np.array( 875 | [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]] 876 | ) 877 | 878 | camera_traj_list = [] 879 | for i in 
range(num_frames): 880 | maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad( 881 | i / num_frames 882 | ) 883 | if maybe_pose_and_fov_rad is None: 884 | return 885 | pose, fov_rad = maybe_pose_and_fov_rad 886 | H = img_wh[1] 887 | W = img_wh[0] 888 | K = get_intrinsics(W, H, fov_rad) 889 | w2c = pose.inverse().as_matrix() 890 | camera_traj_list.append( 891 | { 892 | "w2c": w2c.flatten().tolist(), 893 | "K": K.flatten().tolist(), 894 | "img_wh": (W, H), 895 | } 896 | ) 897 | nonlocal gui_state 898 | gui_state.camera_traj_list = camera_traj_list 899 | print(f"Get camera_traj_list: {gui_state.camera_traj_list}") 900 | 901 | stop_preview_render() 902 | 903 | preview_frame_slider = add_preview_frame_slider() 904 | 905 | loop_checkbox = server.gui.add_checkbox( 906 | "Loop", False, hint="Add a segment between the first and last keyframes." 907 | ) 908 | 909 | @loop_checkbox.on_update 910 | def _(_) -> None: 911 | camera_traj.loop = loop_checkbox.value 912 | duration_number.value = camera_traj.compute_duration() 913 | 914 | @transition_sec_number.on_update 915 | def _(_) -> None: 916 | camera_traj.default_transition_sec = transition_sec_number.value 917 | duration_number.value = camera_traj.compute_duration() 918 | 919 | preview_camera_handle: viser.SceneNodeHandle | None = None 920 | 921 | def remove_preview_camera() -> None: 922 | nonlocal preview_camera_handle 923 | if preview_camera_handle is not None: 924 | preview_camera_handle.remove() 925 | preview_camera_handle = None 926 | 927 | def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None: 928 | """Update the render tab state with the current preview camera pose. 929 | Returns current camera pose + FOV if available.""" 930 | 931 | if preview_frame_slider is None: 932 | return None 933 | maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad( 934 | preview_frame_slider.value / get_max_frame_index() 935 | ) 936 | if maybe_pose_and_fov_rad is None: 937 | remove_preview_camera() 938 | return None 939 | pose, fov_rad = maybe_pose_and_fov_rad 940 | gui_state.preview_fov = fov_rad 941 | gui_state.preview_aspect = camera_traj.get_aspect() 942 | return pose, fov_rad 943 | 944 | # We back up the camera poses before and after we start previewing renders. 945 | camera_pose_backup_from_id: dict[int, tuple] = {} 946 | 947 | # Update the # of frames. 948 | @duration_number.on_update 949 | @framerate_number.on_update 950 | def _(_) -> None: 951 | remove_preview_camera() # Will be re-added when slider is updated. 
952 | 953 | nonlocal preview_frame_slider 954 | old = preview_frame_slider 955 | assert old is not None 956 | 957 | preview_frame_slider = add_preview_frame_slider() 958 | if preview_frame_slider is not None: 959 | old.remove() 960 | else: 961 | preview_frame_slider = old 962 | 963 | camera_traj.framerate = framerate_number.value 964 | camera_traj.update_spline() 965 | 966 | camera_traj = CameraTrajectory( 967 | server, 968 | duration_number, 969 | scene_node_prefix=scene_node_prefix, 970 | **kwargs, 971 | ) 972 | camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi 973 | camera_traj.default_transition_sec = transition_sec_number.value 974 | 975 | return gui_state 976 | -------------------------------------------------------------------------------- /seva/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from seva.modules.layers import ( 7 | Downsample, 8 | GroupNorm32, 9 | ResBlock, 10 | TimestepEmbedSequential, 11 | Upsample, 12 | timestep_embedding, 13 | ) 14 | from seva.modules.transformer import MultiviewTransformer 15 | 16 | 17 | @dataclass 18 | class SevaParams(object): 19 | in_channels: int = 11 20 | model_channels: int = 320 21 | out_channels: int = 4 22 | num_frames: int = 21 23 | num_res_blocks: int = 2 24 | attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1]) 25 | channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4]) 26 | num_head_channels: int = 64 27 | transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1]) 28 | context_dim: int = 1024 29 | dense_in_channels: int = 6 30 | dropout: float = 0.0 31 | unflatten_names: list[str] = field( 32 | default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"] 33 | ) 34 | 35 | def __post_init__(self): 36 | assert len(self.channel_mult) == len(self.transformer_depth) 37 | 38 | 39 | class Seva(nn.Module): 40 | def __init__(self, params: SevaParams) -> None: 41 | super().__init__() 42 | self.params = params 43 | self.model_channels = params.model_channels 44 | self.out_channels = params.out_channels 45 | self.num_head_channels = params.num_head_channels 46 | 47 | time_embed_dim = params.model_channels * 4 48 | self.time_embed = nn.Sequential( 49 | nn.Linear(params.model_channels, time_embed_dim), 50 | nn.SiLU(), 51 | nn.Linear(time_embed_dim, time_embed_dim), 52 | ) 53 | 54 | self.input_blocks = nn.ModuleList( 55 | [ 56 | TimestepEmbedSequential( 57 | nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1) 58 | ) 59 | ] 60 | ) 61 | self._feature_size = params.model_channels 62 | input_block_chans = [params.model_channels] 63 | ch = params.model_channels 64 | ds = 1 65 | for level, mult in enumerate(params.channel_mult): 66 | for _ in range(params.num_res_blocks): 67 | input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [ 68 | ResBlock( 69 | channels=ch, 70 | emb_channels=time_embed_dim, 71 | out_channels=mult * params.model_channels, 72 | dense_in_channels=params.dense_in_channels, 73 | dropout=params.dropout, 74 | ) 75 | ] 76 | ch = mult * params.model_channels 77 | if ds in params.attention_resolutions: 78 | num_heads = ch // params.num_head_channels 79 | dim_head = params.num_head_channels 80 | input_layers.append( 81 | MultiviewTransformer( 82 | ch, 83 | num_heads, 84 | dim_head, 85 | name=f"input_ds{ds}", 86 | depth=params.transformer_depth[level], 87 | context_dim=params.context_dim, 88 | 
unflatten_names=params.unflatten_names, 89 | ) 90 | ) 91 | self.input_blocks.append(TimestepEmbedSequential(*input_layers)) 92 | self._feature_size += ch 93 | input_block_chans.append(ch) 94 | if level != len(params.channel_mult) - 1: 95 | ds *= 2 96 | out_ch = ch 97 | self.input_blocks.append( 98 | TimestepEmbedSequential(Downsample(ch, out_channels=out_ch)) 99 | ) 100 | ch = out_ch 101 | input_block_chans.append(ch) 102 | self._feature_size += ch 103 | 104 | num_heads = ch // params.num_head_channels 105 | dim_head = params.num_head_channels 106 | 107 | self.middle_block = TimestepEmbedSequential( 108 | ResBlock( 109 | channels=ch, 110 | emb_channels=time_embed_dim, 111 | out_channels=None, 112 | dense_in_channels=params.dense_in_channels, 113 | dropout=params.dropout, 114 | ), 115 | MultiviewTransformer( 116 | ch, 117 | num_heads, 118 | dim_head, 119 | name=f"middle_ds{ds}", 120 | depth=params.transformer_depth[-1], 121 | context_dim=params.context_dim, 122 | unflatten_names=params.unflatten_names, 123 | ), 124 | ResBlock( 125 | channels=ch, 126 | emb_channels=time_embed_dim, 127 | out_channels=None, 128 | dense_in_channels=params.dense_in_channels, 129 | dropout=params.dropout, 130 | ), 131 | ) 132 | self._feature_size += ch 133 | 134 | self.output_blocks = nn.ModuleList([]) 135 | for level, mult in list(enumerate(params.channel_mult))[::-1]: 136 | for i in range(params.num_res_blocks + 1): 137 | ich = input_block_chans.pop() 138 | output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [ 139 | ResBlock( 140 | channels=ch + ich, 141 | emb_channels=time_embed_dim, 142 | out_channels=params.model_channels * mult, 143 | dense_in_channels=params.dense_in_channels, 144 | dropout=params.dropout, 145 | ) 146 | ] 147 | ch = params.model_channels * mult 148 | if ds in params.attention_resolutions: 149 | num_heads = ch // params.num_head_channels 150 | dim_head = params.num_head_channels 151 | 152 | output_layers.append( 153 | MultiviewTransformer( 154 | ch, 155 | num_heads, 156 | dim_head, 157 | name=f"output_ds{ds}", 158 | depth=params.transformer_depth[level], 159 | context_dim=params.context_dim, 160 | unflatten_names=params.unflatten_names, 161 | ) 162 | ) 163 | if level and i == params.num_res_blocks: 164 | out_ch = ch 165 | ds //= 2 166 | output_layers.append(Upsample(ch, out_ch)) 167 | self.output_blocks.append(TimestepEmbedSequential(*output_layers)) 168 | self._feature_size += ch 169 | 170 | self.out = nn.Sequential( 171 | GroupNorm32(32, ch), 172 | nn.SiLU(), 173 | nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1), 174 | ) 175 | 176 | def forward( 177 | self, 178 | x: torch.Tensor, 179 | t: torch.Tensor, 180 | y: torch.Tensor, 181 | dense_y: torch.Tensor, 182 | num_frames: int | None = None, 183 | ) -> torch.Tensor: 184 | num_frames = num_frames or self.params.num_frames 185 | t_emb = timestep_embedding(t, self.model_channels) 186 | t_emb = self.time_embed(t_emb) 187 | 188 | hs = [] 189 | h = x 190 | for module in self.input_blocks: 191 | h = module( 192 | h, 193 | emb=t_emb, 194 | context=y, 195 | dense_emb=dense_y, 196 | num_frames=num_frames, 197 | ) 198 | hs.append(h) 199 | h = self.middle_block( 200 | h, 201 | emb=t_emb, 202 | context=y, 203 | dense_emb=dense_y, 204 | num_frames=num_frames, 205 | ) 206 | for module in self.output_blocks: 207 | h = torch.cat([h, hs.pop()], dim=1) 208 | h = module( 209 | h, 210 | emb=t_emb, 211 | context=y, 212 | dense_emb=dense_y, 213 | num_frames=num_frames, 214 | ) 215 | h = h.type(x.dtype) 216 | return self.out(h) 217 
| 218 | 219 | class SGMWrapper(nn.Module): 220 | def __init__(self, module: Seva): 221 | super().__init__() 222 | self.module = module 223 | 224 | def forward( 225 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 226 | ) -> torch.Tensor: 227 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 228 | return self.module( 229 | x, 230 | t=t, 231 | y=c["crossattn"], 232 | dense_y=c["dense_vector"], 233 | **kwargs, 234 | ) 235 | -------------------------------------------------------------------------------- /seva/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/seva/modules/__init__.py -------------------------------------------------------------------------------- /seva/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.models import AutoencoderKL # type: ignore 3 | from torch import nn 4 | 5 | 6 | class AutoEncoder(nn.Module): 7 | scale_factor: float = 0.18215 8 | downsample: int = 8 9 | 10 | def __init__(self, chunk_size: int | None = None): 11 | super().__init__() 12 | self.module = AutoencoderKL.from_pretrained( 13 | "stabilityai/stable-diffusion-2-1-base", 14 | subfolder="vae", 15 | force_download=False, 16 | low_cpu_mem_usage=False, 17 | ) 18 | self.module.eval().requires_grad_(False) # type: ignore 19 | self.chunk_size = chunk_size 20 | 21 | def _encode(self, x: torch.Tensor) -> torch.Tensor: 22 | return ( 23 | self.module.encode(x).latent_dist.mean # type: ignore 24 | * self.scale_factor 25 | ) 26 | 27 | def encode(self, x: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor: 28 | chunk_size = chunk_size or self.chunk_size 29 | if chunk_size is not None: 30 | return torch.cat( 31 | [self._encode(x_chunk) for x_chunk in x.split(chunk_size)], 32 | dim=0, 33 | ) 34 | else: 35 | return self._encode(x) 36 | 37 | def _decode(self, z: torch.Tensor) -> torch.Tensor: 38 | return self.module.decode(z / self.scale_factor).sample # type: ignore 39 | 40 | def decode(self, z: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor: 41 | chunk_size = chunk_size or self.chunk_size 42 | if chunk_size is not None: 43 | return torch.cat( 44 | [self._decode(z_chunk) for z_chunk in z.split(chunk_size)], 45 | dim=0, 46 | ) 47 | else: 48 | return self._decode(z) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | return self.decode(self.encode(x)) 52 | -------------------------------------------------------------------------------- /seva/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import open_clip 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CLIPConditioner(nn.Module): 8 | mean: torch.Tensor 9 | std: torch.Tensor 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.module = open_clip.create_model_and_transforms( 14 | "ViT-H-14", pretrained="laion2b_s32b_b79k" 15 | )[0] 16 | self.module.eval().requires_grad_(False) # type: ignore 17 | self.register_buffer( 18 | "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False 19 | ) 20 | self.register_buffer( 21 | "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False 22 | ) 23 | 24 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 25 | x = kornia.geometry.resize( 26 | x, 27 | (224, 224), 28 | 
interpolation="bicubic", 29 | align_corners=True, 30 | antialias=True, 31 | ) 32 | x = (x + 1.0) / 2.0 33 | x = kornia.enhance.normalize(x, self.mean, self.std) 34 | return x 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | x = self.preprocess(x) 38 | x = self.module.encode_image(x) 39 | return x 40 | -------------------------------------------------------------------------------- /seva/modules/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import repeat 6 | from torch import nn 7 | 8 | from .transformer import MultiviewTransformer 9 | 10 | 11 | def timestep_embedding( 12 | timesteps: torch.Tensor, 13 | dim: int, 14 | max_period: int = 10000, 15 | repeat_only: bool = False, 16 | ) -> torch.Tensor: 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) 21 | * torch.arange(start=0, end=half, dtype=torch.float32) 22 | / half 23 | ).to(device=timesteps.device) 24 | args = timesteps[:, None].float() * freqs[None] 25 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 26 | if dim % 2: 27 | embedding = torch.cat( 28 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 29 | ) 30 | else: 31 | embedding = repeat(timesteps, "b -> b d", d=dim) 32 | return embedding 33 | 34 | 35 | class Upsample(nn.Module): 36 | def __init__(self, channels: int, out_channels: int | None = None): 37 | super().__init__() 38 | self.channels = channels 39 | self.out_channels = out_channels or channels 40 | self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | assert x.shape[1] == self.channels 44 | x = F.interpolate(x, scale_factor=2, mode="nearest") 45 | x = self.conv(x) 46 | return x 47 | 48 | 49 | class Downsample(nn.Module): 50 | def __init__(self, channels: int, out_channels: int | None = None): 51 | super().__init__() 52 | self.channels = channels 53 | self.out_channels = out_channels or channels 54 | self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | assert x.shape[1] == self.channels 58 | return self.op(x) 59 | 60 | 61 | class GroupNorm32(nn.GroupNorm): 62 | def forward(self, input: torch.Tensor) -> torch.Tensor: 63 | return super().forward(input.float()).type(input.dtype) 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential): 67 | def forward( # type: ignore[override] 68 | self, 69 | x: torch.Tensor, 70 | emb: torch.Tensor, 71 | context: torch.Tensor, 72 | dense_emb: torch.Tensor, 73 | num_frames: int, 74 | ) -> torch.Tensor: 75 | for layer in self: 76 | if isinstance(layer, MultiviewTransformer): 77 | assert num_frames is not None 78 | x = layer(x, context, num_frames) 79 | elif isinstance(layer, ResBlock): 80 | x = layer(x, emb, dense_emb) 81 | else: 82 | x = layer(x) 83 | return x 84 | 85 | 86 | class ResBlock(nn.Module): 87 | def __init__( 88 | self, 89 | channels: int, 90 | emb_channels: int, 91 | out_channels: int | None, 92 | dense_in_channels: int, 93 | dropout: float, 94 | ): 95 | super().__init__() 96 | out_channels = out_channels or channels 97 | 98 | self.in_layers = nn.Sequential( 99 | GroupNorm32(32, channels), 100 | nn.SiLU(), 101 | nn.Conv2d(channels, out_channels, 3, 1, 1), 102 | ) 103 | self.emb_layers = nn.Sequential( 104 | nn.SiLU(), nn.Linear(emb_channels, out_channels) 105 | ) 106 | self.dense_emb_layers = nn.Sequential( 
107 | nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0) 108 | ) 109 | self.out_layers = nn.Sequential( 110 | GroupNorm32(32, out_channels), 111 | nn.SiLU(), 112 | nn.Dropout(dropout), 113 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 114 | ) 115 | if out_channels == channels: 116 | self.skip_connection = nn.Identity() 117 | else: 118 | self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0) 119 | 120 | def forward( 121 | self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor 122 | ) -> torch.Tensor: 123 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 124 | h = in_rest(x) 125 | dense = self.dense_emb_layers( 126 | F.interpolate( 127 | dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True 128 | ) 129 | ).type(h.dtype) 130 | dense_scale, dense_shift = torch.chunk(dense, 2, dim=1) 131 | h = h * (1 + dense_scale) + dense_shift 132 | h = in_conv(h) 133 | emb_out = self.emb_layers(emb).type(h.dtype) 134 | while len(emb_out.shape) < len(h.shape): 135 | emb_out = emb_out[..., None] 136 | h = h + emb_out 137 | h = self.out_layers(h) 138 | h = self.skip_connection(x) + h 139 | return h 140 | -------------------------------------------------------------------------------- /seva/modules/preprocessor.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | import os.path as osp 4 | import sys 5 | from typing import cast 6 | 7 | import imageio.v3 as iio 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class Dust3rPipeline(object): 13 | def __init__(self, device: str | torch.device = "cuda"): 14 | submodule_path = osp.realpath( 15 | osp.join(osp.dirname(__file__), "../../third_party/dust3r/") 16 | ) 17 | if submodule_path not in sys.path: 18 | sys.path.insert(0, submodule_path) 19 | try: 20 | with open(os.devnull, "w") as f, contextlib.redirect_stdout(f): 21 | from dust3r.cloud_opt import ( # type: ignore[import] 22 | GlobalAlignerMode, 23 | global_aligner, 24 | ) 25 | from dust3r.image_pairs import make_pairs # type: ignore[import] 26 | from dust3r.inference import inference # type: ignore[import] 27 | from dust3r.model import AsymmetricCroCo3DStereo # type: ignore[import] 28 | from dust3r.utils.image import load_images # type: ignore[import] 29 | except ImportError: 30 | raise ImportError( 31 | "Missing required submodule: 'dust3r'. 
Please ensure that all submodules are properly set up.\n\n" 32 | "To initialize them, run the following command in the project root:\n" 33 | " git submodule update --init --recursive" 34 | ) 35 | 36 | self.device = torch.device(device) 37 | self.model = AsymmetricCroCo3DStereo.from_pretrained( 38 | "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt" 39 | ).to(self.device) 40 | 41 | self._GlobalAlignerMode = GlobalAlignerMode 42 | self._global_aligner = global_aligner 43 | self._make_pairs = make_pairs 44 | self._inference = inference 45 | self._load_images = load_images 46 | 47 | def infer_cameras_and_points( 48 | self, 49 | img_paths: list[str], 50 | Ks: list[list] = None, 51 | c2ws: list[list] = None, 52 | batch_size: int = 16, 53 | schedule: str = "cosine", 54 | lr: float = 0.01, 55 | niter: int = 500, 56 | min_conf_thr: int = 3, 57 | ) -> tuple[ 58 | list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray] 59 | ]: 60 | num_img = len(img_paths) 61 | if num_img == 1: 62 | print("Only one image found, duplicating it to create a stereo pair.") 63 | img_paths = img_paths * 2 64 | 65 | images = self._load_images(img_paths, size=512) 66 | pairs = self._make_pairs( 67 | images, 68 | scene_graph="complete", 69 | prefilter=None, 70 | symmetrize=True, 71 | ) 72 | output = self._inference(pairs, self.model, self.device, batch_size=batch_size) 73 | 74 | ori_imgs = [iio.imread(p) for p in img_paths] 75 | ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs]) 76 | img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0) 77 | 78 | scene = self._global_aligner( 79 | output, 80 | device=self.device, 81 | mode=self._GlobalAlignerMode.PointCloudOptimizer, 82 | same_focals=True, 83 | optimize_pp=False, # True, 84 | min_conf_thr=min_conf_thr, 85 | ) 86 | 87 | # if Ks is not None: 88 | # scene.preset_focal( 89 | # torch.tensor([[K[0, 0], K[1, 1]] for K in Ks]) 90 | # ) 91 | 92 | if c2ws is not None: 93 | scene.preset_pose(c2ws) 94 | 95 | _ = scene.compute_global_alignment( 96 | init="msp", niter=niter, schedule=schedule, lr=lr 97 | ) 98 | 99 | imgs = cast(list, scene.imgs) 100 | Ks = scene.get_intrinsics().detach().cpu().numpy().copy() 101 | c2ws = scene.get_im_poses().detach().cpu().numpy() # type: ignore 102 | pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()] # type: ignore 103 | if num_img > 1: 104 | masks = [x.detach().cpu().numpy() for x in scene.get_masks()] 105 | points = [p[m] for p, m in zip(pts3d, masks)] 106 | point_colors = [img[m] for img, m in zip(imgs, masks)] 107 | else: 108 | points = [p.reshape(-1, 3) for p in pts3d] 109 | point_colors = [img.reshape(-1, 3) for img in imgs] 110 | 111 | # Convert back to the original image size. 
112 | imgs = ori_imgs 113 | Ks[:, :2, -1] *= ori_img_whs / img_whs 114 | Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None] 115 | 116 | return imgs, Ks, c2ws, points, point_colors 117 | -------------------------------------------------------------------------------- /seva/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from torch import nn 5 | from torch.nn.attention import SDPBackend, sdpa_kernel 6 | 7 | 8 | class GEGLU(nn.Module): 9 | def __init__(self, dim_in: int, dim_out: int): 10 | super().__init__() 11 | self.proj = nn.Linear(dim_in, dim_out * 2) 12 | 13 | def forward(self, x: torch.Tensor) -> torch.Tensor: 14 | x, gate = self.proj(x).chunk(2, dim=-1) 15 | return x * F.gelu(gate) 16 | 17 | 18 | class FeedForward(nn.Module): 19 | def __init__( 20 | self, 21 | dim: int, 22 | dim_out: int | None = None, 23 | mult: int = 4, 24 | dropout: float = 0.0, 25 | ): 26 | super().__init__() 27 | inner_dim = int(dim * mult) 28 | dim_out = dim_out or dim 29 | self.net = nn.Sequential( 30 | GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 31 | ) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | return self.net(x) 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__( 39 | self, 40 | query_dim: int, 41 | context_dim: int | None = None, 42 | heads: int = 8, 43 | dim_head: int = 64, 44 | dropout: float = 0.0, 45 | ): 46 | super().__init__() 47 | self.heads = heads 48 | self.dim_head = dim_head 49 | inner_dim = dim_head * heads 50 | context_dim = context_dim or query_dim 51 | 52 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 53 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 54 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 55 | self.to_out = nn.Sequential( 56 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 57 | ) 58 | 59 | def forward( 60 | self, x: torch.Tensor, context: torch.Tensor | None = None 61 | ) -> torch.Tensor: 62 | q = self.to_q(x) 63 | context = context if context is not None else x 64 | k = self.to_k(context) 65 | v = self.to_v(context) 66 | q, k, v = map( 67 | lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads), 68 | (q, k, v), 69 | ) 70 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION): 71 | out = F.scaled_dot_product_attention(q, k, v) 72 | out = rearrange(out, "b h l d -> b l (h d)") 73 | out = self.to_out(out) 74 | return out 75 | 76 | 77 | class TransformerBlock(nn.Module): 78 | def __init__( 79 | self, 80 | dim: int, 81 | n_heads: int, 82 | d_head: int, 83 | context_dim: int, 84 | dropout: float = 0.0, 85 | ): 86 | super().__init__() 87 | self.attn1 = Attention( 88 | query_dim=dim, 89 | context_dim=None, 90 | heads=n_heads, 91 | dim_head=d_head, 92 | dropout=dropout, 93 | ) 94 | self.ff = FeedForward(dim, dropout=dropout) 95 | self.attn2 = Attention( 96 | query_dim=dim, 97 | context_dim=context_dim, 98 | heads=n_heads, 99 | dim_head=d_head, 100 | dropout=dropout, 101 | ) 102 | self.norm1 = nn.LayerNorm(dim) 103 | self.norm2 = nn.LayerNorm(dim) 104 | self.norm3 = nn.LayerNorm(dim) 105 | 106 | def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: 107 | x = self.attn1(self.norm1(x)) + x 108 | x = self.attn2(self.norm2(x), context=context) + x 109 | x = self.ff(self.norm3(x)) + x 110 | return x 111 | 112 | 113 | class TransformerBlockTimeMix(nn.Module): 114 | def __init__( 115 | self, 
116 | dim: int, 117 | n_heads: int, 118 | d_head: int, 119 | context_dim: int, 120 | dropout: float = 0.0, 121 | ): 122 | super().__init__() 123 | inner_dim = n_heads * d_head 124 | self.norm_in = nn.LayerNorm(dim) 125 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout) 126 | self.attn1 = Attention( 127 | query_dim=inner_dim, 128 | context_dim=None, 129 | heads=n_heads, 130 | dim_head=d_head, 131 | dropout=dropout, 132 | ) 133 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout) 134 | self.attn2 = Attention( 135 | query_dim=inner_dim, 136 | context_dim=context_dim, 137 | heads=n_heads, 138 | dim_head=d_head, 139 | dropout=dropout, 140 | ) 141 | self.norm1 = nn.LayerNorm(inner_dim) 142 | self.norm2 = nn.LayerNorm(inner_dim) 143 | self.norm3 = nn.LayerNorm(inner_dim) 144 | 145 | def forward( 146 | self, x: torch.Tensor, context: torch.Tensor, num_frames: int 147 | ) -> torch.Tensor: 148 | _, s, _ = x.shape 149 | x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames) 150 | x = self.ff_in(self.norm_in(x)) + x 151 | x = self.attn1(self.norm1(x), context=None) + x 152 | x = self.attn2(self.norm2(x), context=context) + x 153 | x = self.ff(self.norm3(x)) 154 | x = rearrange(x, "(b s) t c -> (b t) s c", s=s) 155 | return x 156 | 157 | 158 | class SkipConnect(nn.Module): 159 | def __init__(self): 160 | super().__init__() 161 | 162 | def forward( 163 | self, x_spatial: torch.Tensor, x_temporal: torch.Tensor 164 | ) -> torch.Tensor: 165 | return x_spatial + x_temporal 166 | 167 | 168 | class MultiviewTransformer(nn.Module): 169 | def __init__( 170 | self, 171 | in_channels: int, 172 | n_heads: int, 173 | d_head: int, 174 | name: str, 175 | unflatten_names: list[str] = [], 176 | depth: int = 1, 177 | context_dim: int = 1024, 178 | dropout: float = 0.0, 179 | ): 180 | super().__init__() 181 | self.in_channels = in_channels 182 | self.name = name 183 | self.unflatten_names = unflatten_names 184 | 185 | inner_dim = n_heads * d_head 186 | self.norm = nn.GroupNorm(32, in_channels, eps=1e-6) 187 | self.proj_in = nn.Linear(in_channels, inner_dim) 188 | self.transformer_blocks = nn.ModuleList( 189 | [ 190 | TransformerBlock( 191 | inner_dim, 192 | n_heads, 193 | d_head, 194 | context_dim=context_dim, 195 | dropout=dropout, 196 | ) 197 | for _ in range(depth) 198 | ] 199 | ) 200 | self.proj_out = nn.Linear(inner_dim, in_channels) 201 | self.time_mixer = SkipConnect() 202 | self.time_mix_blocks = nn.ModuleList( 203 | [ 204 | TransformerBlockTimeMix( 205 | inner_dim, 206 | n_heads, 207 | d_head, 208 | context_dim=context_dim, 209 | dropout=dropout, 210 | ) 211 | for _ in range(depth) 212 | ] 213 | ) 214 | 215 | def forward( 216 | self, x: torch.Tensor, context: torch.Tensor, num_frames: int 217 | ) -> torch.Tensor: 218 | assert context.ndim == 3 219 | _, _, h, w = x.shape 220 | x_in = x 221 | 222 | time_context = context 223 | time_context_first_timestep = time_context[::num_frames] 224 | time_context = repeat( 225 | time_context_first_timestep, "b ... 
-> (b n) ...", n=h * w 226 | ) 227 | 228 | if self.name in self.unflatten_names: 229 | context = context[::num_frames] 230 | 231 | x = self.norm(x) 232 | x = rearrange(x, "b c h w -> b (h w) c") 233 | x = self.proj_in(x) 234 | 235 | for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks): 236 | if self.name in self.unflatten_names: 237 | x = rearrange(x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w) 238 | x = block(x, context=context) 239 | if self.name in self.unflatten_names: 240 | x = rearrange(x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w) 241 | x_mix = mix_block(x, context=time_context, num_frames=num_frames) 242 | x = self.time_mixer(x_spatial=x, x_temporal=x_mix) 243 | 244 | x = self.proj_out(x) 245 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 246 | out = x + x_in 247 | return out 248 | -------------------------------------------------------------------------------- /seva/sampling.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import gradio as gr 6 | from einops import rearrange 7 | from tqdm import tqdm 8 | 9 | from seva.geometry import get_camera_dist 10 | 11 | 12 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: 13 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 14 | dims_to_append = target_dims - x.ndim 15 | if dims_to_append < 0: 16 | raise ValueError( 17 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 18 | ) 19 | return x[(...,) + (None,) * dims_to_append] 20 | 21 | 22 | def append_zero(x: torch.Tensor) -> torch.Tensor: 23 | return torch.cat([x, x.new_zeros([1])]) 24 | 25 | 26 | def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor: 27 | return (x - denoised) / append_dims(sigma, x.ndim) 28 | 29 | 30 | def make_betas( 31 | num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2 32 | ) -> np.ndarray: 33 | betas = ( 34 | torch.linspace( 35 | linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64 36 | ) 37 | ** 2 38 | ) 39 | return betas.numpy() 40 | 41 | 42 | def generate_roughly_equally_spaced_steps( 43 | num_substeps: int, max_step: int 44 | ) -> np.ndarray: 45 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 46 | 47 | 48 | ####################################################### 49 | # Discretization 50 | ####################################################### 51 | 52 | 53 | class Discretization(object): 54 | def __init__(self, num_timesteps: int = 1000): 55 | self.num_timesteps = num_timesteps 56 | 57 | def __call__( 58 | self, 59 | n: int, 60 | do_append_zero: bool = True, 61 | flip: bool = False, 62 | device: str | torch.device = "cpu", 63 | ) -> torch.Tensor: 64 | sigmas = self.get_sigmas(n, device=device) 65 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 66 | return sigmas if not flip else torch.flip(sigmas, (0,)) 67 | 68 | 69 | class DDPMDiscretization(Discretization): 70 | def __init__( 71 | self, 72 | linear_start: float = 5e-06, 73 | linear_end: float = 0.012, 74 | log_snr_shift: float | None = 2.4, 75 | **kwargs, 76 | ): 77 | super().__init__(**kwargs) 78 | betas = make_betas( 79 | self.num_timesteps, 80 | linear_start=linear_start, 81 | linear_end=linear_end, 82 | ) 83 | self.log_snr_shift = log_snr_shift 84 | 85 | alphas = 1.0 - betas # first alpha here is on data side 86 | 
self.alphas_cumprod = np.cumprod(alphas, axis=0) 87 | 88 | def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor: 89 | if n < self.num_timesteps: 90 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 91 | alphas_cumprod = self.alphas_cumprod[timesteps] 92 | elif n == self.num_timesteps: 93 | alphas_cumprod = self.alphas_cumprod 94 | else: 95 | raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.") 96 | 97 | sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 98 | if self.log_snr_shift is not None: 99 | sigmas = sigmas * np.exp(self.log_snr_shift) 100 | return torch.flip( 101 | torch.tensor(sigmas, dtype=torch.float32, device=device), (0,) 102 | ) 103 | 104 | 105 | ####################################################### 106 | # Denoiser 107 | ####################################################### 108 | 109 | 110 | class DiscreteDenoiser(object): 111 | discretization: Discretization = DDPMDiscretization() 112 | sigmas: torch.Tensor 113 | 114 | def __init__( 115 | self, 116 | num_idx: int = 1000, 117 | device: str | torch.device = "cpu", 118 | ): 119 | self.num_idx = num_idx 120 | self.device = device 121 | self.register_sigmas() 122 | 123 | def scaling( 124 | self, sigma: torch.Tensor 125 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 126 | c_skip = torch.ones_like(sigma, device=sigma.device) 127 | c_out = -sigma 128 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 129 | c_noise = sigma.clone() 130 | return c_skip, c_out, c_in, c_noise 131 | 132 | def register_sigmas(self): 133 | self.sigmas = self.discretization( 134 | self.num_idx, do_append_zero=False, flip=True, device=self.device 135 | ) 136 | 137 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 138 | dists = sigma - self.sigmas[:, None] 139 | return dists.abs().argmin(dim=0).view(sigma.shape) 140 | 141 | def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor: 142 | return self.sigmas[idx] 143 | 144 | def __call__( 145 | self, 146 | network: nn.Module, 147 | input: torch.Tensor, 148 | sigma: torch.Tensor, 149 | cond: dict, 150 | **additional_model_inputs, 151 | ) -> torch.Tensor: 152 | sigma = self.idx_to_sigma(self.sigma_to_idx(sigma)) 153 | sigma_shape = sigma.shape 154 | sigma = append_dims(sigma, input.ndim) 155 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 156 | c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape)) 157 | if "replace" in cond: 158 | x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1) 159 | input = input * (1 - mask) + x * mask 160 | return ( 161 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 162 | + input * c_skip 163 | ) 164 | 165 | 166 | ####################################################### 167 | # Scale rules and schedules 168 | ####################################################### 169 | 170 | 171 | class MultiviewScaleRule(object): 172 | def __init__(self, min_scale: float = 1.0): 173 | self.min_scale = min_scale 174 | 175 | def __call__( 176 | self, 177 | scale: float | torch.Tensor, 178 | c2w: torch.Tensor, 179 | K: torch.Tensor, 180 | input_frame_mask: torch.Tensor, 181 | ) -> torch.Tensor: 182 | c2w_input = c2w[input_frame_mask] 183 | rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values 184 | translation_diff = ( 185 | get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values 186 | ) 187 | K_diff = ( 188 | ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1) 189 | ) 190 | close_frame = 
(rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff 191 | if isinstance(scale, torch.Tensor): 192 | scale = scale.clone() 193 | scale[close_frame] = self.min_scale 194 | elif isinstance(scale, float): 195 | scale = torch.where(close_frame, self.min_scale, scale) 196 | else: 197 | raise ValueError(f"Invalid scale type {type(scale)}.") 198 | return scale 199 | 200 | 201 | class VanillaCFG(object): 202 | def __init__(self): 203 | self.scale_rule = lambda scale: scale 204 | 205 | def _expand_scale( 206 | self, sigma: float | torch.Tensor, scale: float | torch.Tensor 207 | ) -> float | torch.Tensor: 208 | if isinstance(sigma, float): 209 | return scale 210 | elif isinstance(sigma, torch.Tensor): 211 | if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor): 212 | sigma = append_dims(sigma, scale.ndim) 213 | return scale * torch.ones_like(sigma) 214 | else: 215 | raise ValueError(f"Invalid sigma type {type(sigma)}.") 216 | 217 | def guidance( 218 | self, 219 | uncond: torch.Tensor, 220 | cond: torch.Tensor, 221 | scale: float | torch.Tensor, 222 | ) -> torch.Tensor: 223 | if isinstance(scale, torch.Tensor) and len(scale.shape) == 1: 224 | scale = append_dims(scale, cond.ndim) 225 | return uncond + scale * (cond - uncond) 226 | 227 | def __call__( 228 | self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor 229 | ) -> torch.Tensor: 230 | x_u, x_c = x.chunk(2) 231 | scale = self.scale_rule(scale) 232 | x_pred = self.guidance(x_u, x_c, self._expand_scale(sigma, scale)) 233 | return x_pred 234 | 235 | def prepare_inputs( 236 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 237 | ) -> tuple[torch.Tensor, torch.Tensor, dict]: 238 | c_out = dict() 239 | 240 | for k in c: 241 | if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]: 242 | c_out[k] = torch.cat((uc[k], c[k]), 0) 243 | else: 244 | assert c[k] == uc[k] 245 | c_out[k] = c[k] 246 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 247 | 248 | 249 | class MultiviewCFG(VanillaCFG): 250 | def __init__(self, cfg_min: float = 1.0): 251 | self.scale_min = cfg_min 252 | self.scale_rule = MultiviewScaleRule(min_scale=cfg_min) 253 | 254 | def __call__( # type: ignore 255 | self, 256 | x: torch.Tensor, 257 | sigma: float | torch.Tensor, 258 | scale: float | torch.Tensor, 259 | c2w: torch.Tensor, 260 | K: torch.Tensor, 261 | input_frame_mask: torch.Tensor, 262 | ) -> torch.Tensor: 263 | x_u, x_c = x.chunk(2) 264 | scale = self.scale_rule(scale, c2w, K, input_frame_mask) 265 | x_pred = self.guidance(x_u, x_c, self._expand_scale(sigma, scale)) 266 | return x_pred 267 | 268 | 269 | class MultiviewTemporalCFG(MultiviewCFG): 270 | def __init__(self, num_frames: int, cfg_min: float = 1.0): 271 | super().__init__(cfg_min=cfg_min) 272 | self.num_frames = num_frames 273 | distance_matrix = ( 274 | torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None] 275 | ).abs() 276 | self.distance_matrix = distance_matrix 277 | 278 | def __call__( 279 | self, 280 | x: torch.Tensor, 281 | sigma: float | torch.Tensor, 282 | scale: float | torch.Tensor, 283 | c2w: torch.Tensor, 284 | K: torch.Tensor, 285 | input_frame_mask: torch.Tensor, 286 | ) -> torch.Tensor: 287 | input_frame_mask = rearrange( 288 | input_frame_mask, "(b t) ... 
-> b t ...", t=self.num_frames 289 | ) 290 | min_distance = ( 291 | self.distance_matrix[None].to(x.device) 292 | + (~input_frame_mask[:, None]) * self.num_frames 293 | ).min(-1)[0] 294 | min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1) 295 | scale = min_distance * (scale - self.scale_min) + self.scale_min 296 | scale = rearrange(scale, "b t ... -> (b t) ...") 297 | scale = append_dims(scale, x.ndim) 298 | return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1)) 299 | 300 | 301 | ####################################################### 302 | # Samplers 303 | ####################################################### 304 | 305 | 306 | class GradioTrackedSampler(object): 307 | def __init__(self, *args, abort_event: threading.Event | None = None, **kwargs): 308 | super().__init__(*args, **kwargs) 309 | self.abort_event = abort_event 310 | 311 | def possibly_update_pbar(self, global_pbar: gr.Progress | None): 312 | if global_pbar is not None: 313 | global_pbar.update() 314 | if self.abort_event is not None and self.abort_event.is_set(): 315 | return False 316 | return True 317 | 318 | 319 | class EulerEDMSampler(GradioTrackedSampler): 320 | def __init__( 321 | self, 322 | discretization: Discretization, 323 | guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG, 324 | num_steps: int | None = None, 325 | verbose: bool = False, 326 | device: str | torch.device = "cuda", 327 | s_churn=0.0, 328 | s_tmin=0.0, 329 | s_tmax=float("inf"), 330 | s_noise=1.0, 331 | **kwargs, 332 | ): 333 | super().__init__(**kwargs) 334 | self.num_steps = num_steps 335 | self.discretization = discretization 336 | self.guider = guider 337 | self.verbose = verbose 338 | self.device = device 339 | 340 | self.s_churn = s_churn 341 | self.s_tmin = s_tmin 342 | self.s_tmax = s_tmax 343 | self.s_noise = s_noise 344 | 345 | def prepare_sampling_loop( 346 | self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None 347 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]: 348 | num_steps = num_steps or self.num_steps 349 | assert num_steps is not None, "num_steps must be specified" 350 | sigmas = self.discretization(num_steps, device=self.device) 351 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) 352 | num_sigmas = len(sigmas) 353 | s_in = x.new_ones([x.shape[0]]) 354 | return x, s_in, sigmas, num_sigmas, cond, uc 355 | 356 | def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm: 357 | sigma_generator = range(num_sigmas - 1) 358 | if self.verbose and verbose: 359 | sigma_generator = tqdm( 360 | sigma_generator, 361 | total=num_sigmas - 1, 362 | desc="Sampling", 363 | leave=False, 364 | ) 365 | return sigma_generator 366 | 367 | def sampler_step( 368 | self, 369 | sigma: torch.Tensor, 370 | next_sigma: torch.Tensor, 371 | denoiser, 372 | x: torch.Tensor, 373 | scale: float | torch.Tensor, 374 | cond: dict, 375 | uc: dict, 376 | gamma: float = 0.0, 377 | **guider_kwargs, 378 | ) -> torch.Tensor: 379 | sigma_hat = sigma * (gamma + 1.0) + 1e-6 380 | 381 | eps = torch.randn_like(x) * self.s_noise 382 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 383 | 384 | denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc)) 385 | denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs) 386 | d = to_d(x, sigma_hat, denoised) 387 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 388 | return x + dt * d 389 | 390 | def __call__( 391 | self, 392 | denoiser, 393 | x: torch.Tensor, 394 | 
scale: float | torch.Tensor, 395 | cond: dict, 396 | uc: dict | None = None, 397 | num_steps: int | None = None, 398 | verbose: bool = True, 399 | global_pbar: gr.Progress | None = None, 400 | **guider_kwargs, 401 | ) -> torch.Tensor: 402 | uc = cond if uc is None else uc 403 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 404 | x, 405 | cond, 406 | uc, 407 | num_steps, 408 | ) 409 | for i in self.get_sigma_gen(num_sigmas, verbose=verbose): 410 | gamma = ( 411 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) 412 | if self.s_tmin <= sigmas[i] <= self.s_tmax 413 | else 0.0 414 | ) 415 | x = self.sampler_step( 416 | s_in * sigmas[i], 417 | s_in * sigmas[i + 1], 418 | denoiser, 419 | x, 420 | scale, 421 | cond, 422 | uc, 423 | gamma, 424 | **guider_kwargs, 425 | ) 426 | if not self.possibly_update_pbar(global_pbar): 427 | return None 428 | return x 429 | -------------------------------------------------------------------------------- /seva/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import safetensors.torch 4 | import torch 5 | from huggingface_hub import hf_hub_download 6 | 7 | from seva.model import Seva, SevaParams 8 | 9 | 10 | def seed_everything(seed: int = 0): 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | 17 | 18 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 19 | if len(missing) > 0 and len(unexpected) > 0: 20 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 21 | print("\n" + "-" * 79 + "\n") 22 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 23 | elif len(missing) > 0: 24 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 25 | elif len(unexpected) > 0: 26 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 27 | 28 | 29 | def load_model( 30 | model_version: float = 1.1, 31 | pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera", 32 | weight_name: str = "model.safetensors", 33 | device: str | torch.device = "cuda", 34 | verbose: bool = False, 35 | ) -> Seva: 36 | if os.path.isdir(pretrained_model_name_or_path): 37 | weight_path = os.path.join(pretrained_model_name_or_path, weight_name) 38 | else: 39 | if model_version > 1: 40 | base, ext = os.path.splitext(weight_name) 41 | weight_name = f"{base}v{model_version}{ext}" 42 | weight_path = hf_hub_download( 43 | repo_id=pretrained_model_name_or_path, filename=weight_name 44 | ) 45 | _ = hf_hub_download( 46 | repo_id=pretrained_model_name_or_path, filename="config.yaml" 47 | ) 48 | 49 | state_dict = safetensors.torch.load_file( 50 | weight_path, 51 | device=str(device), 52 | ) 53 | 54 | with torch.device("meta"): 55 | model = Seva(SevaParams()).to(torch.bfloat16) 56 | 57 | missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True) 58 | if verbose: 59 | print_load_warning(missing, unexpected) 60 | return model 61 | --------------------------------------------------------------------------------
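
A minimal, self-contained sketch of how the pieces in seva/sampling.py and seva/utils.py compose: the discretization produces the sigma schedule, DiscreteDenoiser wraps a network with EDM-style preconditioning, the guider merges the unconditional and conditional branches, and EulerEDMSampler drives the denoising loop. Everything below that is not shown in the repository is an assumption for illustration only: ToyNetwork stands in for the real Seva model, the tensor shapes are arbitrary, and the "crossattn" conditioning is a dummy; the real cond/uc dictionaries are assembled elsewhere in the package.

import torch
from torch import nn

from seva.sampling import (
    DDPMDiscretization,
    DiscreteDenoiser,
    EulerEDMSampler,
    VanillaCFG,
)


class ToyNetwork(nn.Module):
    # Placeholder for the Seva network. DiscreteDenoiser calls it as
    # network(input * c_in, c_noise, cond, ...), so accept (and ignore)
    # the timestep indices and the conditioning dict.
    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, c_noise, cond):
        return self.conv(x)


device = "cpu"
network = ToyNetwork().to(device)

denoiser = DiscreteDenoiser(num_idx=1000, device=device)
sampler = EulerEDMSampler(
    discretization=DDPMDiscretization(),
    guider=VanillaCFG(),
    num_steps=10,
    verbose=True,
    device=device,
)


def wrapped_denoiser(x, sigma, cond):
    # EulerEDMSampler expects a callable of (x, sigma, cond); bind the
    # network into the preconditioning wrapper here.
    return denoiser(network, x, sigma, cond)


# Start from pure noise; cond/uc carry a dummy cross-attention embedding,
# one of the keys handled by VanillaCFG.prepare_inputs.
x = torch.randn(2, 4, 32, 32, device=device)
cond = {"crossattn": torch.randn(2, 1, 16, device=device)}
uc = {"crossattn": torch.zeros(2, 1, 16, device=device)}

samples = sampler(wrapped_denoiser, x, scale=2.0, cond=cond, uc=uc)
print(samples.shape)  # torch.Size([2, 4, 32, 32])

In the actual pipeline, ToyNetwork would be replaced by the model returned from load_model in seva/utils.py, and VanillaCFG by MultiviewCFG or MultiviewTemporalCFG, with the camera arguments (c2w, K, input_frame_mask) passed to the sampler call and forwarded through **guider_kwargs.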