├── .figs
│   ├── pani-badge.svg
│   ├── preview.png
│   └── thumbnails
│       ├── 360Beach.png
│       ├── 360Garden.png
│       ├── 360Road.png
│       ├── 360Siegen.png
│       ├── Beppu.png
│       ├── BikeRacks.png
│       ├── BikeShelf.png
│       ├── BluePit.png
│       ├── BluePlane.png
│       ├── Bridge.png
│       ├── CatBar.png
│       ├── CityCars.png
│       ├── Construction.png
│       ├── Convocation.png
│       ├── DarkDistillery.png
│       ├── DarkPeace.png
│       ├── DarkShrine.png
│       ├── DarkTruck.png
│       ├── Eiffel.png
│       ├── Escalatosaur.png
│       ├── Fireworks.png
│       ├── Fukuoka.png
│       ├── GlassGarden.png
│       ├── LanternDeer.png
│       ├── MellonDoor.png
│       ├── MountainTop.png
│       ├── NaraCity.png
│       ├── Ocean.png
│       ├── ParisCity.png
│       ├── PlaneHall.png
│       ├── PondHouse.png
│       ├── RainyPath.png
│       ├── RedPit.png
│       ├── RedShrine.png
│       ├── River.png
│       ├── RockStream.png
│       ├── Seafood.png
│       ├── ShinyPlane.png
│       ├── ShinySticks.png
│       ├── SnowTree.png
│       ├── Stalls.png
│       ├── StatueLeft.png
│       ├── StatueRight.png
│       ├── Tenjin.png
│       ├── Tigers.png
│       ├── Toronto.png
│       ├── UniversityCollege.png
│       ├── Vending.png
│       ├── Waterfall.png
│       ├── WoodOffice.png
│       └── thumbnail.png
├── LICENSE
├── README.md
├── checkpoints
│   └── __init__.py
├── config
│   ├── config_large.json
│   ├── config_medium.json
│   ├── config_small.json
│   ├── config_tiny.json
│   └── config_ultrakill.json
├── data
│   └── __init__.py
├── lightning_logs
│   └── __init__.py
├── outputs
│   └── __init__.py
├── render.py
├── requirements.txt
├── train.py
├── tutorial.ipynb
└── utils
    └── utils.py
/.figs/pani-badge.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.figs/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/preview.png
--------------------------------------------------------------------------------
/.figs/thumbnails/360Beach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Beach.png
--------------------------------------------------------------------------------
/.figs/thumbnails/360Garden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Garden.png
--------------------------------------------------------------------------------
/.figs/thumbnails/360Road.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Road.png
--------------------------------------------------------------------------------
/.figs/thumbnails/360Siegen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Siegen.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Beppu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Beppu.png
--------------------------------------------------------------------------------
/.figs/thumbnails/BikeRacks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BikeRacks.png
--------------------------------------------------------------------------------
/.figs/thumbnails/BikeShelf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BikeShelf.png
--------------------------------------------------------------------------------
/.figs/thumbnails/BluePit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BluePit.png
--------------------------------------------------------------------------------
/.figs/thumbnails/BluePlane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BluePlane.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Bridge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Bridge.png
--------------------------------------------------------------------------------
/.figs/thumbnails/CatBar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/CatBar.png
--------------------------------------------------------------------------------
/.figs/thumbnails/CityCars.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/CityCars.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Construction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Construction.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Convocation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Convocation.png
--------------------------------------------------------------------------------
/.figs/thumbnails/DarkDistillery.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkDistillery.png
--------------------------------------------------------------------------------
/.figs/thumbnails/DarkPeace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkPeace.png
--------------------------------------------------------------------------------
/.figs/thumbnails/DarkShrine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkShrine.png
--------------------------------------------------------------------------------
/.figs/thumbnails/DarkTruck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkTruck.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Eiffel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Eiffel.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Escalatosaur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Escalatosaur.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Fireworks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Fireworks.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Fukuoka.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Fukuoka.png
--------------------------------------------------------------------------------
/.figs/thumbnails/GlassGarden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/GlassGarden.png
--------------------------------------------------------------------------------
/.figs/thumbnails/LanternDeer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/LanternDeer.png
--------------------------------------------------------------------------------
/.figs/thumbnails/MellonDoor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/MellonDoor.png
--------------------------------------------------------------------------------
/.figs/thumbnails/MountainTop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/MountainTop.png
--------------------------------------------------------------------------------
/.figs/thumbnails/NaraCity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/NaraCity.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Ocean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Ocean.png
--------------------------------------------------------------------------------
/.figs/thumbnails/ParisCity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ParisCity.png
--------------------------------------------------------------------------------
/.figs/thumbnails/PlaneHall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/PlaneHall.png
--------------------------------------------------------------------------------
/.figs/thumbnails/PondHouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/PondHouse.png
--------------------------------------------------------------------------------
/.figs/thumbnails/RainyPath.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RainyPath.png
--------------------------------------------------------------------------------
/.figs/thumbnails/RedPit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RedPit.png
--------------------------------------------------------------------------------
/.figs/thumbnails/RedShrine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RedShrine.png
--------------------------------------------------------------------------------
/.figs/thumbnails/River.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/River.png
--------------------------------------------------------------------------------
/.figs/thumbnails/RockStream.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RockStream.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Seafood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Seafood.png
--------------------------------------------------------------------------------
/.figs/thumbnails/ShinyPlane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ShinyPlane.png
--------------------------------------------------------------------------------
/.figs/thumbnails/ShinySticks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ShinySticks.png
--------------------------------------------------------------------------------
/.figs/thumbnails/SnowTree.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/SnowTree.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Stalls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Stalls.png
--------------------------------------------------------------------------------
/.figs/thumbnails/StatueLeft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/StatueLeft.png
--------------------------------------------------------------------------------
/.figs/thumbnails/StatueRight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/StatueRight.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Tenjin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Tenjin.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Tigers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Tigers.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Toronto.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Toronto.png
--------------------------------------------------------------------------------
/.figs/thumbnails/UniversityCollege.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/UniversityCollege.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Vending.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Vending.png
--------------------------------------------------------------------------------
/.figs/thumbnails/Waterfall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Waterfall.png
--------------------------------------------------------------------------------
/.figs/thumbnails/WoodOffice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/WoodOffice.png
--------------------------------------------------------------------------------
/.figs/thumbnails/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/thumbnail.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 princeton-computational-imaging
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Neural Light Spheres for Implicit Image Stitching and View Synthesis
3 |
4 |
5 |
6 |
7 |
8 |
9 |
 10 | This is the official code repository for the SIGGRAPH Asia 2024 work: [Neural Light Spheres for Implicit Image Stitching and View Synthesis](https://light.princeton.edu/publication/neuls/). If you use parts of this work, or otherwise take inspiration from it, please consider citing our paper:
11 | ```
12 | @inproceedings{chugunov2024light,
13 | author = {Chugunov, Ilya and Joshi, Amogh and Murthy, Kiran and Bleibel, Francois and Heide, Felix},
14 | title = {Neural Light Spheres for {Implicit Image Stitching and View Synthesis}},
15 | booktitle = {Proceedings of the ACM SIGGRAPH Asia 2024},
16 | year = {2024},
17 | publisher = {ACM},
18 | doi = {10.1145/3680528.3687660},
19 | url = {https://doi.org/10.1145/3680528.3687660}
20 | }
21 | ```
22 |
23 | ## Requirements:
 24 | - Code was written with PyTorch 2.2.1 and PyTorch Lightning 2.0.1 on an Ubuntu 22.04 machine.
 25 | - Condensed package requirements are in `requirements.txt`. Note that these are the exact package versions at the time of publishing; the code will most likely work with newer versions of the libraries, but you may need to watch out for changes in class/function calls.
26 | - The non-standard packages you may need are `pytorch_lightning`, `commentjson`, `rawpy`, `pygame`, and `tinycudann`. See [NVlabs/tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) for installation instructions. Depending on your system you might just be able to do `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`, or might have to cmake and build it from source.
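
Before training, it can help to confirm that these dependencies import cleanly and that PyTorch sees your GPU. A minimal sanity check along these lines (not part of the repository) should work:

```python
# Quick environment check: the non-standard dependencies import and a GPU is visible.
import torch
import pytorch_lightning
import commentjson
import rawpy
import pygame
import tinycudann as tcnn

print("torch", torch.__version__, "| lightning", pytorch_lightning.__version__)
print("CUDA available:", torch.cuda.is_available())
```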
27 |
28 | ## Project Structure:
29 | ```cpp
 30 | NeuLS
31 | ├── checkpoints
32 | │ └── // folder for network checkpoints
33 | ├── config
34 | │ └── // network and encoding configurations for different sizes of MLPs
35 | ├── data
36 | │ └── // folder for image sequence data
37 | ├── lightning_logs
38 | │ └── // folder for tensorboard logs
39 | ├── outputs
40 | │ └── // folder for model outputs (e.g., final reconstructions)
41 | ├── utils
42 | │ └── utils.py // model helper functions (e.g., RAW demosaicing, quaternion math)
43 | ├── LICENSE // legal stuff
44 | ├── README.md // <- recursion errors
45 | ├── render.py // interactive render demo
46 | ├── requirements.txt // frozen package requirements
47 | ├── train.py // dataloader, network, visualization, and trainer code
48 | └── tutorial.ipynb // interactive tutorial for training the model
49 | ```
50 | ## Getting Started:
51 | We highly recommend you start by going through `tutorial.ipynb`, either on your own machine or [with this Google Colab link](https://colab.research.google.com/github/princeton-computational-imaging/NeuLS/blob/main/tutorial.ipynb).
52 |
53 | TLDR: models can be trained with:
54 |
55 | `python3 train.py --data_path {path_to_data} --name {checkpoint_name}`
56 |
 57 | For a full list of training arguments, we recommend looking through the argument parser section at the bottom of `train.py`.
58 |
 59 | The final model checkpoint will be saved to `checkpoints/{checkpoint_name}/last.ckpt`.
60 |
61 | And you can launch an interactive demo for the scene via:
62 |
63 | `python3 render.py --checkpoint {path_to_checkpoint_folder}`
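
If you prefer to load a trained scene programmatically (e.g., for your own rendering scripts), the checkpoint loading in `render.py` boils down to roughly the following sketch; the checkpoint folder name below is a placeholder:

```python
# Sketch of loading a trained model outside the demo (mirrors the setup code in render.py).
import os
from train import PanoModel

chkpt_dir = "checkpoints/my_scene"                  # placeholder: your checkpoint folder
chkpt_path = os.path.join(chkpt_dir, "last.ckpt")
data_path = os.path.join(chkpt_dir, "data.pkl")     # data cache render.py expects next to the checkpoint
model = PanoModel.load_from_checkpoint(chkpt_path, device="cuda", cached_data=data_path)
model = model.to("cuda").eval()
```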
64 |
65 |
66 |
67 |
68 |
 69 | Press `P` to toggle the view-dependent color model, `O` to toggle the ray offset model, and `+`/`-` to raise or lower the image brightness. Click and drag to rotate the camera, hold Shift and drag to translate it, and scroll to zoom in/out. Press `R` to reset the view and `Escape` to quit.
70 |
71 |
72 | ## Data:
 73 | This model trains on data recorded by our Android RAW capture app [Pani](https://github.com/Ilya-Muromets/Pani); see that repo for details on collecting and processing your own captures.
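
If you record your own scenes, each capture is stored as a `frame_bundle.npz`. A quick way to peek inside one is sketched below; the path is a placeholder, and the fields listed are the ones read by the dataloader in `train.py`:

```python
# Inspect a Pani capture bundle; fields match what BundleDataset in train.py reads.
import numpy as np

data = np.load("data/my_scene/frame_bundle.npz", allow_pickle=True)
print(data["num_raw_frames"].item())           # number of RAW frames in the bundle
print(data["characteristics"].item().keys())   # camera characteristics
print(data["motion"].item().keys())            # motion stream, incl. 'timestamp' and 'quaternion'
print(data["raw_0"].item().keys())             # per-frame dict: 'timestamp', 'intrinsics', 'width', ...
```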
74 |
 75 | You can download the full `30gb` dataset via [this link](https://soap.cs.princeton.edu/neuls/neuls_data.zip), or the individual scenes from the paper via the following links:
76 |
77 | | Image | Name | Num Frames | Camera Lens | ISO | Exposure Time (s) | Aperture | Focal Length | Height | Width |
78 | |-------|------|------------|-------------|-----|-------------------|----------|--------------|--------|-------|
 79 | | | 360Road | 77 | main | 21 | 1/1110 | 1.68 | 6.9 | 3072 | 4080 |
 80 | | | Beppu | 40 | main | 21 | 1/630 | 1.68 | 6.9 | 3072 | 4080 |
 81 | | | BikeRacks | 48 | main | 21 | 1/1960 | 1.68 | 6.9 | 3072 | 4080 |
 82 | | | BikeShelf | 37 | main | 8065 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 83 | | | BluePit | 32 | main | 21 | 1/199 | 1.68 | 6.9 | 3072 | 4080 |
 84 | | | BluePlane | 59 | main | 1005 | 1/120 | 1.68 | 6.9 | 3072 | 4080 |
 85 | | | Bridge | 49 | main | 21 | 1/1384 | 1.68 | 6.9 | 3072 | 4080 |
 86 | | | CityCars | 46 | main | 21 | 1/1744 | 1.68 | 6.9 | 3072 | 4080 |
 87 | | | Construction | 53 | main | 21 | 1/1653 | 1.68 | 6.9 | 3072 | 4080 |
 88 | | | DarkDistillery | 57 | main | 10667 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 89 | | | DarkPeace | 51 | main | 10667 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 90 | | | DarkShrine | 34 | main | 5065 | 1/80 | 1.68 | 6.9 | 3072 | 4080 |
 91 | | | DarkTruck | 43 | main | 10667 | 1/60 | 1.68 | 6.9 | 3072 | 4080 |
 92 | | | Eiffel | 73 | main | 21 | 1/1183 | 1.68 | 6.9 | 3072 | 4080 |
 93 | | | Escalatosaur | 49 | main | 589 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 94 | | | Fireworks | 78 | main | 5000 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 95 | | | Fukuoka | 40 | main | 21 | 1/1312 | 1.68 | 6.9 | 3072 | 4080 |
 96 | | | LanternDeer | 34 | main | 42 | 1/103 | 1.68 | 6.9 | 3072 | 4080 |
 97 | | | MountainTop | 59 | main | 21 | 1/2405 | 1.68 | 6.9 | 3072 | 4080 |
 98 | | | Ocean | 44 | main | 110 | 1/127 | 1.68 | 6.9 | 3072 | 4080 |
 99 | | | ParisCity | 55 | main | 21 | 1/1265 | 1.68 | 6.9 | 3072 | 4080 |
 100 | | | PlaneHall | 77 | main | 1005 | 1/120 | 1.68 | 6.9 | 3072 | 4080 |
 101 | | | PondHouse | 51 | main | 21 | 1/684 | 1.68 | 6.9 | 3072 | 4080 |
 102 | | | RainyPath | 38 | main | 600 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 103 | | | RedShrine | 40 | main | 21 | 1/499 | 1.68 | 6.9 | 3072 | 4080 |
 104 | | | RockStream | 31 | main | 21 | 1/1110 | 1.68 | 6.9 | 3072 | 4080 |
 105 | | | Seafood | 44 | main | 21 | 1/193 | 1.68 | 6.9 | 3072 | 4080 |
 106 | | | ShinyPlane | 50 | main | 21 | 1/210 | 1.68 | 6.9 | 3072 | 4080 |
 107 | | | ShinySticks | 37 | main | 21 | 1/1417 | 1.68 | 6.9 | 3072 | 4080 |
 108 | | | SnowTree | 42 | main | 21 | 1/320 | 1.68 | 6.9 | 3072 | 4080 |
 109 | | | Stalls | 52 | main | 49 | 1/79 | 1.68 | 6.9 | 3072 | 4080 |
 110 | | | StatueLeft | 22 | main | 805 | 1/60 | 1.68 | 6.9 | 3072 | 4080 |
 111 | | | StatueRight | 26 | main | 802 | 1/60 | 1.68 | 6.9 | 3072 | 4080 |
 112 | | | Tenjin | 36 | main | 602 | 1/100 | 1.68 | 6.9 | 3072 | 4080 |
 113 | | | Tigers | 42 | main | 507 | 1/200 | 1.68 | 6.9 | 3072 | 4080 |
 114 | | | Toronto | 31 | main | 21 | 1/1250 | 1.68 | 6.9 | 3072 | 4080 |
 115 | | | Vending | 42 | main | 21 | 1/352 | 1.68 | 6.9 | 3072 | 4080 |
 116 | | | WoodOffice | 83 | main | 206 | 1/120 | 1.68 | 6.9 | 3072 | 4080 |
 117 | | | GlassGarden | 59 | telephoto | 24 | 1/231 | 2.8 | 18.0 | 3024 | 4032 |
 118 | | | NaraCity | 54 | telephoto | 17 | 1/327 | 2.8 | 18.0 | 3024 | 4032 |
 119 | | | 360Beach | 56 | ultrawide | 42 | 1/3175 | 1.95 | 2.23 | 3000 | 4000 |
 120 | | | 360Garden | 77 | ultrawide | 41 | 1/1104 | 1.95 | 2.23 | 3000 | 4000 |
 121 | | | 360Siegen | 67 | ultrawide | 41 | 1/1029 | 1.95 | 2.23 | 3000 | 4000 |
 122 | | | CatBar | 37 | ultrawide | 88 | 1/110 | 1.95 | 2.23 | 3000 | 4000 |
 123 | | | Convocation | 43 | ultrawide | 41 | 1/2309 | 1.95 | 2.23 | 3000 | 4000 |
 124 | | | MellonDoor | 40 | ultrawide | 41 | 1/564 | 1.95 | 2.23 | 3000 | 4000 |
 125 | | | RedPit | 53 | ultrawide | 41 | 1/418 | 1.95 | 2.23 | 3000 | 4000 |
 126 | | | River | 65 | ultrawide | 48 | 1/78 | 1.95 | 2.23 | 3000 | 4000 |
 127 | | | UniversityCollege | 55 | ultrawide | 42 | 1/2177 | 1.95 | 2.23 | 3000 | 4000 |
 128 | | | Waterfall | 74 | ultrawide | 41 | 1/1621 | 1.95 | 2.23 | 3000 | 4000 |
129 |
130 | Glhf,
131 | Ilya
132 |
--------------------------------------------------------------------------------
/checkpoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/checkpoints/__init__.py
--------------------------------------------------------------------------------
/config/config_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "Grid",
4 | "type": "Hash",
5 | "n_levels": 15,
6 | "n_features_per_level": 4,
7 | "log2_hashmap_size": 19,
8 | "base_resolution": 4,
9 | "per_level_scale": 1.61,
10 | "interpolation": "Linear"
11 | },
12 | "network": {
13 | "otype": "FullyFusedMLP",
14 | "activation": "ReLU",
15 | "output_activation": "None",
16 | "n_neurons": 128,
17 | "n_hidden_layers": 5
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config/config_medium.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "Grid",
4 | "type": "Hash",
5 | "n_levels": 12,
6 | "n_features_per_level": 4,
7 | "log2_hashmap_size": 16,
8 | "base_resolution": 4,
9 | "per_level_scale": 1.61,
10 | "interpolation": "Linear"
11 | },
12 | "network": {
13 | "otype": "FullyFusedMLP",
14 | "activation": "ReLU",
15 | "output_activation": "None",
16 | "n_neurons": 64,
17 | "n_hidden_layers": 5
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config/config_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "Grid",
4 | "type": "Hash",
5 | "n_levels": 8,
6 | "n_features_per_level": 4,
7 | "log2_hashmap_size": 16,
8 | "base_resolution": 4,
9 | "per_level_scale": 1.61,
10 | "interpolation": "Linear"
11 | },
12 | "network": {
13 | "otype": "FullyFusedMLP",
14 | "activation": "ReLU",
15 | "output_activation": "None",
16 | "n_neurons": 64,
17 | "n_hidden_layers": 3
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config/config_tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "Grid",
4 | "type": "Hash",
5 | "n_levels": 6,
6 | "n_features_per_level": 4,
7 | "log2_hashmap_size": 12,
8 | "base_resolution": 4,
9 | "per_level_scale": 1.61,
10 | "interpolation": "Linear"
11 | },
12 | "network": {
13 | "otype": "FullyFusedMLP",
14 | "activation": "ReLU",
15 | "output_activation": "None",
16 | "n_neurons": 32,
17 | "n_hidden_layers": 3
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config/config_ultrakill.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "Grid",
4 | "type": "Hash",
5 | "n_levels": 17,
6 | "n_features_per_level": 4,
7 | "log2_hashmap_size": 20,
8 | "base_resolution": 4,
9 | "per_level_scale": 1.61,
10 | "interpolation": "Linear"
11 | },
12 | "network": {
13 | "otype": "CutlassMLP",
14 | "activation": "ReLU",
15 | "output_activation": "None",
16 | "n_neurons": 256,
17 | "n_hidden_layers": 5
18 | }
19 | }
20 |
21 |
--------------------------------------------------------------------------------
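
These JSON configs define the hash-grid encodings and MLP decoders; `train.py` reads the `encoding` and `network` blocks with `commentjson` and passes them directly to `tinycudann`. A minimal sketch of that mapping (assuming it is run from the repository root on a CUDA GPU):

```python
# Build a tinycudann encoding + MLP from one of the configs above
# (mirrors how LightSphereModel in train.py constructs its sub-networks).
import commentjson as json
import tinycudann as tcnn
import torch

with open("config/config_medium.json") as f:
    cfg = json.load(f)

encoding = tcnn.Encoding(n_input_dims=3, encoding_config=cfg["encoding"])
network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=3,
                       network_config=cfg["network"])

x = torch.rand(2048, 3, device="cuda")   # e.g. sphere intersection points mapped to [0, 1]
out = network(encoding(x))               # typically float16; train.py casts with .float()
print(out.shape, out.dtype)
```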
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/data/__init__.py
--------------------------------------------------------------------------------
/lightning_logs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/lightning_logs/__init__.py
--------------------------------------------------------------------------------
/outputs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/outputs/__init__.py
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | import pygame
5 | import argparse
6 |
7 | from utils import utils
8 | from train import PanoModel, BundleDataset
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description="Interactive Rendering")
12 | parser.add_argument("-d", "--checkpoint", required=True, help="Path to the checkpoint file")
13 | return parser.parse_args()
14 |
15 | args = parse_args()
16 | chkpt_path = args.checkpoint
17 | chkpt_path = os.path.join(os.path.dirname(chkpt_path), "last.ckpt")
18 | data_path = os.path.join(os.path.dirname(chkpt_path), "data.pkl")
19 | model = PanoModel.load_from_checkpoint(chkpt_path, device="cuda", cached_data=data_path)
20 | model = model.to("cuda")
21 | model = model.eval()
22 | model.args.no_offset = False
23 | model.args.no_view_color = False
24 |
25 | class InteractiveWindow:
26 | def __init__(self):
27 | self.offset = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32).to("cuda")
28 | self.fov_factor = 1.0
29 | self.width = 1440
30 | self.height = 1080
31 | self.t = 0.5
32 | self.brightness = 1
33 | self.t_tensor = torch.full((self.width * self.height,), self.t, device="cuda", dtype=torch.float32)
34 |
35 | self.generate_camera()
36 |
37 |
38 | pygame.init()
39 | pygame.font.init() # Initialize font module
40 | self.font = pygame.font.SysFont(None, 36) # Set up the font
41 |
42 | self.screen = pygame.display.set_mode((self.width, self.height), pygame.DOUBLEBUF | pygame.HWSURFACE | pygame.SRCALPHA)
43 | pygame.display.set_caption("NeuLS Interactive Rendering")
44 |
45 | self.last_mouse_position = None
46 | self.update_image()
47 |
48 | def generate_camera(self):
49 | self.intrinsics_inv = model.data.intrinsics_inv[int(self.t * model.args.num_frames - 1)].clone()
50 | self.intrinsics_inv[0] *= 1.77 # Widen FOV
51 | self.intrinsics_inv = self.intrinsics_inv[None, :, :].repeat(self.width * self.height, 1, 1).to("cuda")
52 |
53 | self.quaternion_camera_to_world = model.data.quaternion_camera_to_world[int(self.t * model.args.num_frames - 1)].to("cuda")
54 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world.clone()
55 | self.camera_to_world = model.model_rotation(self.quaternion_camera_to_world, self.t_tensor).to("cuda")
56 |
57 | self.ray_origins = model.model_translation(self.t_tensor, 1.0)
58 |
59 | self.uv = utils.make_grid(self.height, self.width, [0, 1], [0, 1]).to("cuda")
60 | self.ray_directions = model.generate_ray_directions(self.uv, self.camera_to_world, self.intrinsics_inv)
61 |
62 |
63 |
64 | def generate_image(self):
65 | with torch.no_grad():
66 | self.camera_to_world = utils.convert_quaternions_to_rot(self.quaternion_camera_to_world_offset).repeat(self.width * self.height, 1, 1).to("cuda")
67 | self.ray_directions = model.generate_ray_directions(self.uv * self.fov_factor + 0.5 * (1 - self.fov_factor), self.camera_to_world, self.intrinsics_inv)
68 |
69 | rgb_transmission = model.inference(self.t_tensor, self.uv * self.fov_factor + 0.5 * (1 - self.fov_factor), self.ray_origins + self.offset, self.ray_directions, 1.0)
70 |
71 | rgb_transmission = model.color(rgb_transmission, self.height, self.width).permute(2, 1, 0)
72 | return (rgb_transmission ** 0.7 * 255 * self.brightness).clamp(0, 255).byte().cpu().numpy()
73 |
74 | def update_image(self):
75 | image = self.generate_image()
76 | surf = pygame.surfarray.make_surface(image)
77 | self.screen.blit(surf, (0, 0))
78 |
79 | # Render and display status text
80 | self.render_status_text()
81 |
82 | pygame.display.flip()
83 |
84 | def render_status_text(self):
 85 |         # Use a slightly larger font for the status overlay
86 | large_font = pygame.font.SysFont(None, int(self.font.get_height() * 1.4))
87 |
88 | view_color_status = "View-Dependent Color: On" if not model.args.no_view_color else "View-Dependent Color: Off"
89 | ray_offset_status = "Ray Offset: On" if not model.args.no_offset else "Ray Offset: Off"
90 |
91 | # Render the text in white
92 | view_color_text = large_font.render(view_color_status, True, (255, 255, 255))
93 | ray_offset_text = large_font.render(ray_offset_status, True, (255, 255, 255))
94 |
95 | # Calculate the size of the black box
96 | box_width = max(view_color_text.get_width(), ray_offset_text.get_width()) + 20
97 | box_height = view_color_text.get_height() + ray_offset_text.get_height() + 20
98 |
99 | # Position for the box
100 | box_x = 10
101 | box_y = self.height - box_height - 10
102 |
103 | # Draw the black box
104 | pygame.draw.rect(self.screen, (0, 0, 0), (box_x, box_y, box_width, box_height))
105 |
106 | # Draw the white text on top of the black box
107 | self.screen.blit(view_color_text, (box_x + 10, box_y + 10))
108 | self.screen.blit(ray_offset_text, (box_x + 10, box_y + view_color_text.get_height() + 10))
109 |
110 |
111 | def handle_keys(self, event):
112 | fov_step = 1.05
113 |
114 | if event.type == pygame.KEYDOWN:
115 | if event.key == pygame.K_o:
116 | model.args.no_offset = not model.args.no_offset
117 | elif event.key == pygame.K_p:
118 | model.args.no_view_color = not model.args.no_view_color
119 | elif event.key == pygame.K_EQUALS:
120 | self.brightness *= 1.05
121 | elif event.key == pygame.K_MINUS:
122 | self.brightness *= 0.9523
123 | elif event.key == pygame.K_LEFT:
124 | self.t = max(0, self.t - 0.01)
125 | self.t_tensor.fill_(self.t)
126 | self.generate_camera()
127 | elif event.key == pygame.K_RIGHT:
128 | self.t = min(1, self.t + 0.01)
129 | self.t_tensor.fill_(self.t)
130 | self.generate_camera()
131 | elif event.key == pygame.K_ESCAPE:
132 | pygame.quit()
133 | sys.exit()
134 | elif event.key == pygame.K_r:
135 | self.offset.zero_()
136 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world.clone()
137 | self.fov_factor = 1.0
138 |
139 | def handle_mouse(self, event):
140 | shift_pressed = pygame.key.get_mods() & pygame.KMOD_SHIFT
141 |
142 | if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1: # Left button
143 | self.last_mouse_position = pygame.mouse.get_pos()
144 | elif event.type == pygame.MOUSEMOTION and self.last_mouse_position is not None:
145 | x, y = pygame.mouse.get_pos()
146 | dx, dy = x - self.last_mouse_position[0], y - self.last_mouse_position[1]
147 | sensitivity = 0.001 * self.fov_factor * model.args.focal_compensation
148 |
149 | if event.buttons[0]: # Left button held down
150 | if shift_pressed:
151 | step = 0.5
152 | movement_camera_space = torch.tensor([-dx * step * sensitivity, dy * step * sensitivity, 0.0]).to("cuda")
153 | movement_world_space = self.camera_to_world[0] @ movement_camera_space
154 | self.offset[0] += movement_world_space
155 | else:
156 | pitch = torch.tensor(-dy * sensitivity, dtype=torch.float32).to("cuda")
157 | yaw = torch.tensor(-dx * sensitivity, dtype=torch.float32).to("cuda")
158 |
159 | # Compute the pitch and yaw quaternions (small rotations)
160 | pitch_quat = torch.tensor([torch.cos(pitch / 2), torch.sin(pitch / 2), 0, 0]).to("cuda")
161 | yaw_quat = torch.tensor([torch.cos(yaw / 2), 0, torch.sin(yaw / 2), 0]).to("cuda")
162 |
163 | # Combine the pitch and yaw quaternions by multiplying them
164 | rotation_quat = utils.quaternion_multiply(yaw_quat, pitch_quat)
165 |
166 | # Apply the rotation incrementally
167 | self.quaternion_camera_to_world_offset = utils.quaternion_multiply(
168 | self.quaternion_camera_to_world_offset, rotation_quat
169 | )
170 |
171 | # Normalize the quaternion to avoid drift
172 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world_offset / torch.norm(self.quaternion_camera_to_world_offset)
173 |
174 |
175 | self.last_mouse_position = (x, y)
176 | elif event.type == pygame.MOUSEBUTTONUP and event.button == 1: # Left button
177 | self.last_mouse_position = None
178 | elif event.type == pygame.MOUSEWHEEL:
179 | fov_step = 1.02
180 | if event.y > 0: # Scroll up
181 | self.fov_factor /= fov_step
182 | elif event.y < 0: # Scroll down
183 | self.fov_factor *= fov_step
184 |
185 | def run(self):
186 | clock = pygame.time.Clock()
187 | while True:
188 | for event in pygame.event.get():
189 | if event.type == pygame.QUIT:
190 | pygame.quit()
191 | sys.exit()
192 | elif event.type in [pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP, pygame.MOUSEMOTION, pygame.MOUSEWHEEL]:
193 | self.handle_mouse(event)
194 | else:
195 | self.handle_keys(event)
196 |
197 | self.update_image()
198 | clock.tick(30) # Limit the frame rate to 30 FPS
199 |
200 | if __name__ == "__main__":
201 | window = InteractiveWindow()
202 | window.run()
203 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | commentjson==0.9.0
2 | ipython==8.21.0
3 | ipywidgets==8.0.6
4 | matplotlib==3.7.0
5 | numpy==1.24.3
6 | Pillow==11.0.0
7 | pygame==2.6.0
8 | pytorch_lightning==2.4.0
9 | torch==2.7.0
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import commentjson as json
3 | import numpy as np
4 | import os
5 | import re
6 | import pickle
7 |
8 | import tinycudann as tcnn
9 |
10 | from utils import utils
11 | from utils.utils import debatch
12 | import matplotlib.pyplot as plt
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 | from torch.utils.data import Dataset
18 | from torch.utils.data import DataLoader
19 | import pytorch_lightning as pl
20 |
21 | #########################################################################################################
22 | ################################################ DATASET ################################################
23 | #########################################################################################################
24 |
25 | class BundleDataset(Dataset):
26 | def __init__(self, args, load_volume=False):
27 | self.args = args
28 | print("Loading from:", self.args.data_path)
29 |
30 | data = np.load(args.data_path, allow_pickle=True)
31 |
32 | self.characteristics = data['characteristics'].item() # camera characteristics
33 | self.motion = data['motion'].item()
34 | self.frame_timestamps = torch.tensor([data[f'raw_{i}'].item()['timestamp'] for i in range(data['num_raw_frames'])], dtype=torch.float64)
35 | self.motion_timestamps = torch.tensor(self.motion['timestamp'], dtype=torch.float64)
36 | self.num_frames = data['num_raw_frames'].item()
37 |
38 | # WXYZ quaternions, remove phase wraps (2pi jumps)
39 | self.quaternions = utils.unwrap_quaternions(torch.tensor(self.motion['quaternion']).float()) # T',4, has different timestamps from frames
 40 |         # flip x/y to match our convention:
41 | # +y towards selfie camera, +x towards buttons, +z towards scene
42 | self.quaternions[:,1] *= -1
43 | self.quaternions[:,2] *= -1
44 |
45 | self.quaternions = utils.multi_interp(self.frame_timestamps, self.motion_timestamps, self.quaternions) # interpolate to frame times
46 |
47 | self.reference_quaternion = self.quaternions[0:1]
48 | self.quaternion_camera_to_world = utils.quaternion_multiply(utils.quaternion_conjugate(self.reference_quaternion), self.quaternions)
49 |
50 | self.intrinsics = torch.tensor(np.array([data[f'raw_{i}'].item()['intrinsics'] for i in range(data['num_raw_frames'])])).float() # T,3,3
51 | # swap cx,cy -> landscape to portrait
52 | cx, cy = self.intrinsics[:, 2, 1].clone(), self.intrinsics[:, 2, 0].clone()
53 | self.intrinsics[:, 2, 0], self.intrinsics[:, 2, 1] = cx, cy
54 | # transpose to put cx,cy in right column
55 | self.intrinsics = self.intrinsics.transpose(1, 2)
56 | self.intrinsics_inv = torch.inverse(self.intrinsics)
57 |
58 | self.lens_distortion = torch.tensor(data['raw_0'].item()['lens_distortion'])
59 | self.tonemap_curve = torch.tensor(data['raw_0'].item()['tonemap_curve'], dtype=torch.float32)
60 | self.ccm = torch.tensor(data['raw_0'].item()['ccm'], dtype=torch.float32)
61 |
62 | self.img_channels = 3
63 | self.img_height = data['raw_0'].item()['width'] # rotated 90
64 | self.img_width = data['raw_0'].item()['height']
65 | self.rgb_volume = None # placeholder volume for fast loading
66 |
 67 |         # rolling shutter timing compensation, off by default (can misbehave for data not captured on a Pixel 8 Pro)
68 | self.rolling_shutter_skew = data['raw_0'].item()['android']['sensor.rollingShutterSkew'] / 1e9 # delay between top and bottom row readout, seconds
69 | self.rolling_shutter_skew_row = self.rolling_shutter_skew / (self.img_height - 1)
70 | self.row_timestamps = torch.zeros(len(self.frame_timestamps), self.img_height, dtype=torch.float64) # NxH
71 | for i, frame_timestamp in enumerate(self.frame_timestamps):
72 | for j in range(self.img_height):
73 | if args.rolling_shutter:
74 | self.row_timestamps[i,j] = frame_timestamp + j * self.rolling_shutter_skew_row
75 | else:
76 | self.row_timestamps[i,j] = frame_timestamp
77 |
78 |
79 | self.row_timestamps = self.row_timestamps - self.row_timestamps[0,0] # zero at start
80 | self.row_timestamps = self.row_timestamps/self.row_timestamps[-1,-1] # normalize to 0-1
81 |
82 | if args.frames is not None:
83 | # subsample frames
84 | self.num_frames = len(args.frames)
85 | self.frame_timestamps = self.frame_timestamps[args.frames]
86 | self.intrinsics = self.intrinsics[args.frames]
87 | self.intrinsics_inv = self.intrinsics_inv[args.frames]
88 | self.quaternions = self.quaternions[args.frames]
89 | self.reference_quaternion = self.quaternions[0:1]
90 | self.quaternion_camera_to_world = self.quaternion_camera_to_world[args.frames]
91 |
92 | if load_volume:
93 | self.load_volume()
94 |
95 | self.frame_batch_size = 2 * (self.args.point_batch_size // self.num_frames // 2) # nearest even cut
96 | self.point_batch_size = self.frame_batch_size * self.num_frames # nearest multiple of num_frames
97 | self.num_batches = self.args.num_batches
98 |
99 | self.training_phase = 0.0 # fraction of training complete
100 | print("Frame Count: ", self.num_frames)
101 |
102 | def load_volume(self):
103 | volume_path = self.args.data_path.replace("frame_bundle.npz", "rgb_volume.npy")
104 | if os.path.exists(volume_path):
105 | print("Loading cached volume from:", volume_path)
106 | self.rgb_volume = torch.tensor(np.load(volume_path)).float()
107 | else:
108 | data = dict(np.load(self.args.data_path, allow_pickle=True))
109 | utils.de_item(data)
110 | self.rgb_volume = (utils.raw_to_rgb(data)) # T,C,H,W
111 | if self.args.cache:
112 | print("Saving cached volume to:", volume_path)
113 | np.save(volume_path, self.rgb_volume.numpy())
114 |
115 | if self.args.max_percentile < 100: # cut off highlights (long-tail-distribution)
116 | self.clip = np.percentile(self.rgb_volume[0], self.args.max_percentile)
117 | self.rgb_volume = torch.clamp(self.rgb_volume, 0, self.clip)
118 | self.rgb_volume = self.rgb_volume/self.clip
119 | else:
120 | self.clip = 1.0
121 |
122 | self.mean = self.rgb_volume[0].mean()
123 | self.rgb_volume = (16 * (self.rgb_volume - self.mean)).to(torch.float16)
124 |
125 | if self.args.frames is not None:
126 | self.rgb_volume = self.rgb_volume[self.args.frames] # subsample frames
127 |
128 | self.img_height, self.img_width = self.rgb_volume.shape[2], self.rgb_volume.shape[3]
129 |
130 |
131 | def __len__(self):
132 | return self.num_batches # arbitrary as we continuously generate random samples
133 |
134 | def __getitem__(self, idx):
135 | uv = torch.rand((self.frame_batch_size * self.num_frames), 2) # uniform random in [0,1]
136 | uv = uv * torch.tensor([[self.img_width-1, self.img_height-1]]) # scale to image dimensions
137 | uv = uv.round() # quantize to pixels
138 | u,v = uv.unbind(-1)
139 |
140 | t = torch.zeros_like(uv[:,0:1])
141 | for frame in range(self.num_frames):
142 | # use row to index into row_timestamps
143 | t[frame * self.frame_batch_size:(frame + 1) * self.frame_batch_size,0] = self.row_timestamps[frame, v[frame * self.frame_batch_size:(frame + 1) * self.frame_batch_size].long()]
144 |
145 | uv = uv / torch.tensor([[self.img_width-1, self.img_height-1]]) # scale back to 0-1
146 |
147 | return self.generate_samples(t, uv)
148 |
149 | def generate_samples(self, t, uv):
150 | """ generate samples from dataset and camera parameters for training
151 | """
152 |
 153 |         # create frame_batch_size quaternions for each frame
154 | quaternion_camera_to_world = (self.quaternion_camera_to_world[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0)
155 | # create frame_batch_size of intrinsics for each frame
156 | intrinsics = (self.intrinsics[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0)
157 | intrinsics_inv = (self.intrinsics_inv[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0)
158 |
159 | # sample grid
160 | u,v = uv.unbind(-1)
161 | u, v = (u * (self.img_width - 1)).round().long(), (v * (self.img_height - 1)).round().long() # pixel coordinates
162 | u, v = torch.clamp(u, 0, self.img_width-1), torch.clamp(v, 0, self.img_height-1) # clamp to image bounds
163 | x, y = u, (self.img_height - 1) - v # convert to array coordinates
164 |
165 | if self.rgb_volume is not None:
166 | rgb_samples = []
167 | for frame in range(self.num_frames):
168 | frame_min, frame_max = frame * self.frame_batch_size, (frame + 1) * self.frame_batch_size
169 | rgb_samples.append(self.rgb_volume[frame,:,y[frame_min:frame_max],x[frame_min:frame_max]].permute(1,0))
170 | rgb_samples = torch.cat(rgb_samples, dim=0)
171 | else:
172 | rgb_samples = torch.zeros(self.frame_batch_size * self.num_frames, 3)
173 |
174 | return t, uv, quaternion_camera_to_world, intrinsics, intrinsics_inv, rgb_samples
175 |
176 | def sample_frame(self, frame, uv):
177 | """ sample frame [frame] at coordinates u,v
178 | """
179 |
180 | u,v = uv.unbind(-1)
181 | u, v = (u * self.img_width).round().long(), (v * self.img_height).round().long() # pixel coordinates
182 | u, v = torch.clamp(u, 0, self.img_width-1), torch.clamp(v, 0, self.img_height-1) # clamp to image bounds
183 | x, y = u, (self.img_height - 1) - v # convert to array coordinates
184 |
185 | if self.rgb_volume is not None:
186 | rgb_samples = self.rgb_volume[frame:frame+1,:,y,x] # frames x 3 x H x W volume
187 | rgb_samples = rgb_samples.permute(0,2,1).flatten(0,1) # point_batch_size x channels
188 | else:
189 | rgb_samples = torch.zeros(u.shape[0], 3)
190 |
191 | return rgb_samples
192 |
193 | #########################################################################################################
 194 | ################################################ MODELS #################################################
195 | #########################################################################################################
196 |
197 | class RotationModel(nn.Module):
198 | def __init__(self, args):
199 | super().__init__()
200 | self.args = args
201 | self.rotations = nn.Parameter(torch.zeros(1, 3, self.args.num_frames, dtype=torch.float32), requires_grad=True)
202 |
203 | def forward(self, quaternion_camera_to_world, t):
204 | self.rotations.data[:, :, 0] = 0.0 # zero out first frame's rotation
205 |
206 | rotations = utils.interpolate_params(self.rotations, t)
207 | rotations = self.args.rotation_weight * rotations
208 | rx, ry, rz = rotations[:, 0], rotations[:, 1], rotations[:, 2]
209 | r1 = torch.ones_like(rx)
210 |
211 | rotation_offsets = torch.stack([torch.stack([r1, -rz, ry], dim=-1),
212 | torch.stack([rz, r1, -rx], dim=-1),
213 | torch.stack([-ry, rx, r1], dim=-1)], dim=-1)
214 |
215 | return rotation_offsets @ utils.convert_quaternions_to_rot(quaternion_camera_to_world)
216 |
217 | class TranslationModel(nn.Module):
218 | def __init__(self, args):
219 | super().__init__()
220 | self.args = args
221 | self.translations_coarse = nn.Parameter(torch.rand(1, 3, 7, dtype=torch.float32) * 1e-5, requires_grad=True)
222 | self.translations_fine = nn.Parameter(torch.rand(1, 3, args.num_frames, dtype=torch.float32) * 1e-5, requires_grad=True)
223 | self.center = nn.Parameter(torch.zeros(1, 3, dtype=torch.float32), requires_grad=True)
224 |
225 | def forward(self, t, training_phase=1.0):
226 | self.translations_coarse.data[:, :, 0] = 0.0 # zero out first frame's translation
227 | self.translations_fine.data[:, :, 0] = 0.0 # zero out first frame's translation
228 |
229 | if training_phase < 0.25:
230 | translation = utils.interpolate_params(self.translations_coarse, t)
231 | else:
232 | translation = utils.interpolate_params(self.translations_coarse, t) + utils.interpolate_params(self.translations_fine, t)
233 |
234 | return self.args.focal_compensation * self.args.translation_weight * (translation + 5 * self.center)
235 |
236 | class LightSphereModel(pl.LightningModule):
237 |
238 | def __init__(self, args):
239 | super().__init__()
240 | self.args = args
241 |
242 | encoding_offset_position_config = json.load(open(f"config/config_{args.encoding_offset_position_config}.json"))["encoding"]
243 | encoding_offset_angle_config = json.load(open(f"config/config_{args.encoding_offset_angle_config}.json"))["encoding"]
244 | network_offset_config = json.load(open(f"config/config_{args.network_offset_config}.json"))["network"]
245 |
246 | encoding_color_position_config = json.load(open(f"config/config_{args.encoding_color_position_config}.json"))["encoding"]
247 | encoding_color_angle_config = json.load(open(f"config/config_{args.encoding_color_angle_config}.json"))["encoding"]
248 | network_color_position_config = json.load(open(f"config/config_{args.network_color_position_config}.json"))["network"]
249 | network_color_angle_config = json.load(open(f"config/config_{args.network_color_angle_config}.json"))["network"]
250 |
251 | self.encoding_offset_position = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_offset_position_config)
252 | self.encoding_offset_angle = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_offset_angle_config)
253 |
254 | self.network_offset = tcnn.Network(n_input_dims=self.encoding_offset_position.n_output_dims + self.encoding_offset_angle.n_output_dims,
255 | n_output_dims=3, network_config=network_offset_config)
256 |
257 | self.encoding_color_position = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_color_position_config)
258 | self.encoding_color_angle = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_color_angle_config)
259 |
260 | self.network_color_position = tcnn.Network(n_input_dims=self.encoding_color_position.n_output_dims, n_output_dims=64, network_config=network_color_position_config)
261 | self.network_color_angle = tcnn.Network(n_input_dims=self.encoding_color_position.n_output_dims + self.encoding_color_angle.n_output_dims, n_output_dims=64, network_config=network_color_angle_config)
262 | self.network_color = nn.Linear(64, 3, dtype=torch.float32, bias=False) # faster than tcnn.Network
263 | self.network_color.weight.data = torch.rand_like(self.network_color.weight.data)
264 |
265 | self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float16), requires_grad=False)
266 |
267 | self.enc_feat_direction = None
268 | self.enc_offset_intersection = None
269 | self.enc_offset_direction = None
270 | self.enc_image_outer = None
271 |
272 | def mask(self, encoding, training_phase, initial):
273 | if self.args.no_mask:
274 | return encoding
275 | else:
276 | return utils.mask(encoding, training_phase, initial)
277 |
278 | @torch.jit.script
279 | def solve_sphere_crossings(ray_origins, ray_directions):
 280 |         # Ray / unit-sphere intersection (ray_directions assumed unit-length): solve t^2 + 2(o·d)t + (|o|^2 - 1) = 0 and keep the larger root
281 | b = 2 * torch.sum(ray_origins * ray_directions, dim=1)
282 | c = torch.sum(ray_origins**2, dim=1) - 1.0
283 |
284 | discriminant = b**2 - 4 * c
285 | sqrt_discriminant = torch.sqrt(discriminant)
286 | t = (-b + sqrt_discriminant) / 2
287 | intersections = ray_origins + t.unsqueeze(-1) * ray_directions
288 | return intersections
289 |
290 | def inference(self, t, uv, ray_origins, ray_directions, training_phase=1.0):
291 | # Slightly slimmed down version of forward() for inference
292 | uv = uv.clamp(0, 1)
293 |
294 | intersections_sphere = self.solve_sphere_crossings(ray_origins, ray_directions)
295 |
296 | if not self.args.no_offset:
297 | encoded_offset_position = self.encoding_offset_position((intersections_sphere + 1) / 2)
298 | encoded_offset_angle = self.encoding_offset_angle(uv)
299 | encoded_offset = torch.cat((encoded_offset_position, encoded_offset_angle), dim=1)
300 |
301 | offset = self.network_offset(encoded_offset).float()
302 | ray_directions_offset = ray_directions + torch.cross(offset, ray_directions, dim=1)
303 | intersections_sphere_offset = self.solve_sphere_crossings(ray_origins, ray_directions_offset)
304 | else:
305 | intersections_sphere_offset = intersections_sphere
306 |
307 | encoded_color_position = self.encoding_color_position((intersections_sphere_offset + 1) / 2)
308 | feat_color = self.network_color_position(encoded_color_position).float()
309 |
310 | if not self.args.no_view_color:
311 | encoded_color_angle = self.encoding_color_angle(uv)
312 | encoded_color = torch.cat((encoded_color_position, encoded_color_angle), dim=1)
313 | feat_color_angle = self.network_color_angle(encoded_color)
314 | feat_color = feat_color + feat_color_angle
315 |
316 | rgb = self.initial_rgb + self.network_color(feat_color)
317 | return rgb
318 |
319 | def forward(self, t, uv, ray_origins, ray_directions, training_phase):
320 | uv = uv.clamp(0,1)
321 |
322 | if training_phase < 0.2: # Apply random perturbation for training during early epochs
323 | factor = 0.015 * self.args.focal_compensation
324 | perturbation = torch.randn_like(ray_origins) * factor / 0.2 * max(0.0, 0.25 - training_phase)
325 | ray_origins = ray_origins + perturbation
326 |
327 | intersections_sphere = self.solve_sphere_crossings(ray_origins, ray_directions)
328 |
329 | if training_phase > 0.2 and not self.args.no_offset:
330 | encoded_offset_position = self.mask(self.encoding_offset_position((intersections_sphere + 1) / 2), training_phase, initial=0.2)
331 | encoded_offset_angle = self.mask(self.encoding_offset_angle(uv), training_phase, initial=0.5)
332 | encoded_offset = torch.cat((encoded_offset_position, encoded_offset_angle), dim=1)
333 |
334 | offset = self.network_offset(encoded_offset).float()
335 | ray_directions_offset = ray_directions + torch.cross(offset, ray_directions, dim=1) # linearized rotation
336 | intersections_sphere_offset = self.solve_sphere_crossings(ray_origins, ray_directions_offset)
337 | else:
338 | offset = torch.ones_like(ray_directions)
339 | intersections_sphere_offset = intersections_sphere
340 |
341 | encoded_color_position = self.mask(self.encoding_color_position((intersections_sphere_offset + 1) / 2), training_phase, initial=0.8)
342 | feat_color = self.network_color_position(encoded_color_position).float()
343 |
344 | if training_phase > 0.25 and not self.args.no_view_color:
345 | encoded_color_angle = self.mask(self.encoding_color_angle(uv), training_phase, initial=0.2)
346 | encoded_color = torch.cat((encoded_color_position, encoded_color_angle), dim=1)
347 | feat_color_angle = self.network_color_angle(encoded_color)
348 | feat_color = feat_color + feat_color_angle
349 | else:
350 | feat_color_angle = torch.zeros_like(feat_color)
351 |
352 | rgb = self.initial_rgb + self.network_color(feat_color)
353 | return rgb, offset, feat_color_angle
354 |
355 |
356 | class DistortionModel(pl.LightningModule):
357 | def __init__(self, args):
358 | super().__init__()
359 | self.args = args
360 |
361 | def forward(self, kappa):
362 | if self.args.no_lens_distortion: # no distortion
363 | return (kappa * 0.0).to(self.device)
364 | else:
365 | return kappa.to(self.device)
366 |
367 |
368 | #########################################################################################################
369 | ################################################ NETWORK ################################################
370 | #########################################################################################################
371 |
372 | class PanoModel(pl.LightningModule):
373 | def __init__(self, args, cached_data=None):
374 | super().__init__()
375 | # load data and construct sub-models
376 |
377 | self.args = args
378 | if cached_data is None:
379 | self.data = BundleDataset(self.args)
380 | else:
381 | with open(cached_data, 'rb') as file:
382 | self.data = pickle.load(file)
383 |
384 | if args.frames is None:
385 | self.args.frames = list(range(self.data.num_frames))
386 | self.args.num_frames = self.data.num_frames
387 |
388 |
389 | # to account for varying focal lengths, scale camera/ray motion to match 82deg Pixel 8 Pro main lens
390 | self.args.focal_compensation = 1.9236 / (self.data.intrinsics[0,0,0].item()/self.data.intrinsics[0,0,2].item())
391 |
392 | self.model_translation = TranslationModel(self.args)
393 | self.model_rotation = RotationModel(self.args)
394 | self.model_distortion = DistortionModel(self.args)
395 | self.model_light_sphere = LightSphereModel(self.args)
396 |
397 | self.training_phase = 1.0
398 | self.save_hyperparameters()
399 |
400 | def load_volume(self):
401 | self.data.load_volume()
402 |
403 | def configure_optimizers(self):
404 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, betas=[0.9,0.99], eps=1e-9, weight_decay=self.args.weight_decay)
405 | gamma = 1.0
406 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
407 |
408 | return [optimizer], [scheduler]
409 |
410 | def inference(self, *args, **kwargs):
411 | with torch.no_grad():
412 | return self.model_light_sphere.inference(*args, **kwargs)
413 |
414 | def forward(self, *args, **kwargs):
415 | return self.model_light_sphere(*args, **kwargs)
416 |
417 | def generate_ray_directions(self, uv, camera_to_world, intrinsics_inv):
418 | u, v = uv[:, 0:1] * self.data.img_width, uv[:, 1:2] * self.data.img_height
419 | uv1 = torch.cat([u, v, torch.ones_like(u)], dim=1) # N x 3
420 | # Transform pixel coordinates to camera coordinates
421 | xy1 = torch.bmm(intrinsics_inv, uv1.unsqueeze(2)).squeeze(2) # N x 3
422 | xy = xy1[:, 0:2]
423 |
424 |
425 | x, y = xy[:, 0:1], xy[:, 1:2]
426 | r2 = x**2 + y**2
427 | r4 = r2**2
428 | r6 = r4 * r2
429 |
430 | kappa1, kappa2, kappa3, kappa4, kappa5 = self.model_distortion(self.data.lens_distortion)
431 |
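    | # This matches the Brown-Conrady model (kappa1-3 radial, kappa4/kappa5 acting as the tangential
    | # coefficients p1/p2): x' = x(1 + k1 r^2 + k2 r^4 + k3 r^6) + 2 p1 x y + p2 (r^2 + 2 x^2),
    | # and symmetrically for y'.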
432 | # Apply radial and tangential distortion
433 | x_distorted = x * (1 + kappa1 * r2 + kappa2 * r4 + kappa3 * r6) + \
434 | 2 * kappa4 * x * y + kappa5 * (r2 + 2 * x**2)
435 | y_distorted = y * (1 + kappa1 * r2 + kappa2 * r4 + kappa3 * r6) + \
436 | 2 * kappa5 * x * y + kappa4 * (r2 + 2 * y**2)
437 |
438 | xy_distorted = torch.cat([x_distorted, y_distorted], dim=1)
439 |
440 | # Combine with z = 1 for direction calculation
441 | ray_directions_unrotated = torch.cat([xy_distorted, torch.ones_like(x)], dim=1) # N x 3
442 |
443 | ray_directions = torch.bmm(camera_to_world, ray_directions_unrotated.unsqueeze(2)).squeeze(2) # Apply camera rotation
444 |
445 | # Normalize ray directions
446 | ray_directions = ray_directions / ray_directions.norm(dim=1, keepdim=True)
447 |
448 | return ray_directions
449 |
450 | def training_step(self, train_batch, batch_idx):
451 | t, uv, quaternion_camera_to_world, intrinsics, intrinsics_inv, rgb_reference = debatch(train_batch) # collapse batch + point dimensions
452 |
453 | camera_to_world = self.model_rotation(quaternion_camera_to_world, t) # apply rotation offset
454 | ray_origins = self.model_translation(t, self.training_phase) # camera center in world coordinates
455 | ray_origins = ray_origins.clamp(-0.99, 0.99) # bound within the sphere (clamp is not in-place)
456 | ray_directions = self.generate_ray_directions(uv, camera_to_world, intrinsics_inv)
457 |
458 | rgb, offset, feat_color_angle = self.forward(t, uv, ray_origins, ray_directions, self.training_phase)
459 |
460 | loss = 0.0
461 |
462 | rgb_loss = F.l1_loss(rgb, rgb_reference)
463 | self.log('loss/rgb', rgb_loss.mean())
464 | loss += rgb_loss.mean()
465 |
466 |
467 | factor = 1.2 # Adjust this parameter to control the bend of the curve (1 for linear, 2 for quadratic, etc.)
468 | normalized_epoch = (self.current_epoch + (batch_idx/self.args.num_batches)) / (self.args.max_epochs - 1)
469 |
470 | # Update the training_phase calculation with the factor
471 | self.training_phase = min(1.0, 0.05 + (normalized_epoch ** factor))
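    | # e.g. with factor = 1.2, halfway through training (normalized_epoch = 0.5) the phase is
    | # roughly 0.05 + 0.5**1.2 ~= 0.49, i.e. slightly behind a purely linear schedule.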
472 | self.data.training_phase = self.training_phase
473 |
474 | return loss
475 |
476 | def color_and_tone(self, rgb_samples, height, width):
477 | """ Apply CCM and tone mapping to raw samples
478 | """
479 |
480 | img = self.color(rgb_samples, height, width)
481 | img = utils.apply_tonemap(img, self.data.tonemap_curve.to(rgb_samples.device))
482 |
483 | return img.clamp(0,1)
484 |
485 | def color(self, rgb_samples, height, width):
486 | """ Apply CCM to raw samples
487 | """
488 |
489 | img = self.data.ccm.to(rgb_samples.device) @ (self.data.mean + rgb_samples.float()/16.0).T
490 | img = img.reshape(3, height, width)
491 |
492 | return img.clamp(0,1)
493 |
494 | @torch.no_grad()
495 | def chunk_forward(self, quaternion_camera_to_world, intrinsics_inv, t, uv, translation, chunk_size=1000000):
496 | """ Forward model with chunking to avoid OOM
497 | """
498 | total_elements = t.shape[0]
499 |
500 | for start_idx in range(0, total_elements, chunk_size):
501 | end_idx = min(start_idx + chunk_size, total_elements)
502 |
503 | camera_to_world_chunk = self.model_rotation(quaternion_camera_to_world[start_idx:end_idx], t[start_idx:end_idx])
504 | ray_directions_chunk = self.generate_ray_directions(uv[start_idx:end_idx], camera_to_world_chunk, intrinsics_inv[start_idx:end_idx])
505 |
506 | if translation is None:
507 | ray_origins_chunk = self.model_translation(t[start_idx:end_idx])
508 | else:
509 | ray_origins_chunk = torch.zeros_like(ray_directions_chunk) + translation
510 |
511 | chunk_outputs = self.forward(t[start_idx:end_idx], uv[start_idx:end_idx], ray_origins_chunk, ray_directions_chunk, self.training_phase)
512 |
513 | if start_idx == 0:
514 | num_outputs = len(chunk_outputs)
515 | final_outputs = [[] for _ in range(num_outputs)]
516 |
517 | for i, output in enumerate(chunk_outputs):
518 | final_outputs[i].append(output.cpu())
519 |
520 | final_outputs = tuple(torch.cat(outputs, dim=0) for outputs in final_outputs)
521 |
522 | return final_outputs
523 |
524 | @torch.no_grad()
525 | def generate_outputs(self, time, height=720, width=720, u_lims=[0,1], v_lims=[0,1], fov_scale=1.0, quaternion_camera_to_world=None, intrinsics_inv=None, translation=None, sensor_size=None):
526 |
527 | device = self.device
528 |
529 | uv = utils.make_grid(height, width, u_lims, v_lims)
530 | frame = int(time * (self.data.num_frames - 1))
531 | t = torch.tensor(time, dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1
532 |
533 | rgb_reference = self.data.sample_frame(frame, uv) # reference rgb samples
534 |
535 | if intrinsics_inv is None :
536 | intrinsics_inv = self.data.intrinsics_inv[frame:frame+2] # 2 x 3 x 3
537 | if time <= 0 or time >= 1.0: # select exact frame timestamp
538 | intrinsics_inv = intrinsics_inv[0:1]
539 | else: # interpolate between frames
540 | fraction = time * (self.data.num_frames - 1) - frame
541 | intrinsics_inv = intrinsics_inv[0:1] * (1 - fraction) + intrinsics_inv[1:2] * fraction
542 |
543 | if quaternion_camera_to_world is None:
544 | quaternion_camera_to_world = self.data.quaternion_camera_to_world[frame:frame+2] # 2 x 4
545 | if time <= 0 or time >= 1.0:
546 | quaternion_camera_to_world = quaternion_camera_to_world[0:1]
547 | else: # interpolate between frames
548 | fraction = time * (self.data.num_frames - 1) - frame
549 | quaternion_camera_to_world = quaternion_camera_to_world[0:1] * (1 - fraction) + quaternion_camera_to_world[1:2] * fraction
550 |
551 | intrinsics_inv = intrinsics_inv.clone()
552 | intrinsics_inv[:,0] = intrinsics_inv[:,0] * fov_scale
553 | # note: fov_scale only rescales the first row of the inverse intrinsics
554 |
555 | intrinsics_inv = intrinsics_inv.repeat(uv.shape[0],1,1) # num_points x 3 x 3
556 | quaternion_camera_to_world = quaternion_camera_to_world.repeat(uv.shape[0],1) # num_points x 4
557 |
558 | quaternion_camera_to_world, intrinsics_inv, t, uv = quaternion_camera_to_world.to(device), intrinsics_inv.to(device), t.to(device), uv.to(device)
559 |
560 | rgb, offset, feat_color_angle = self.chunk_forward(quaternion_camera_to_world, intrinsics_inv, t, uv, translation, chunk_size=3000**2) # break into chunks to avoid OOM
561 |
562 | rgb_reference = self.color_and_tone(rgb_reference, height, width)
563 | rgb = self.color_and_tone(rgb, height, width)
564 |
565 | offset = offset.reshape(height, width, 3).float().permute(2,0,1)
566 |
567 | # Visualize the channel-averaged offset field
569 | offset_img = utils.colorize_tensor(offset.mean(dim=0), vmin=offset.min(), vmax=offset.max(), cmap="RdYlBu")
570 |
571 | return rgb_reference, rgb, offset, offset_img
572 |
573 |
574 | def save_outputs(self, path, high_res=False):
575 | os.makedirs(f"outputs/{self.args.name + path}", exist_ok=True)
576 | if high_res:
577 | rgb_reference, rgb, offset, offset_img = self.generate_outputs(time=0, height=2560, width=1920)
578 | else:
579 | rgb_reference, rgb, offset, offset_img = self.generate_outputs(time=0, height=1080, width=810)
580 |
581 |
582 | #########################################################################################################
583 | ############################################### VALIDATION ##############################################
584 | #########################################################################################################
585 |
586 | class ValidationCallback(pl.Callback):
587 | def __init__(self):
588 | super().__init__()
589 | self.unlock = False
590 |
591 | def bright(self, rgb):
592 | return ((rgb / np.percentile(rgb, 95)) ** 0.7).clamp(0,1)
593 |
594 | def on_train_epoch_start(self, trainer, model):
595 | print(f"Training phase (0-1): {model.training_phase}")
596 |
597 | if model.current_epoch == 1:
598 | model.model_translation.translations_coarse.requires_grad_(True)
599 | model.model_translation.translations_fine.requires_grad_(True)
600 | model.model_rotation.rotations.requires_grad_(True)
601 |
602 | if model.args.fast: # skip tensorboarding except for beginning and end
603 | if model.current_epoch == model.args.max_epochs - 1 or model.current_epoch == 0:
604 | pass
605 | else:
606 | return
607 |
608 | for i, time in enumerate([0.2, 0.5, 0.8]):
609 | if model.args.hd:
610 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, height=1080, width=1080, fov_scale=1.4)
611 | else:
612 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, fov_scale=2.5)
613 |
614 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined', rgb, global_step=trainer.global_step)
615 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined_bright', self.bright(rgb), global_step=trainer.global_step)
616 | model.logger.experiment.add_image(f'pred/{i}_offset', offset_img, global_step=trainer.global_step)
617 |
618 |
619 | def on_train_start(self, trainer, model):
620 | pl.seed_everything(42) # the answer to life, the universe, and everything
621 |
622 | # initialize rgb as average color of first frame of data (minimize the amount the rgb models have to learn)
623 | model.model_light_sphere.initial_rgb.data = torch.mean(model.data.rgb_volume[0], dim=(1,2))[None,:].to(model.device).to(torch.float16)
624 |
625 | model.logger.experiment.add_text("args", str(model.args))
626 |
627 | for i, time in enumerate([0, 0.5, 1.0]):
628 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, height=1080, width=810)
629 | model.logger.experiment.add_image(f'gt/{i}_rgb_reference', rgb_reference, global_step=trainer.global_step)
630 | model.logger.experiment.add_image(f'gt/{i}_rgb_reference_bright', self.bright(rgb_reference), global_step=trainer.global_step)
631 |
632 | model.training_phase = 0.05
633 | model.data.training_phase = model.training_phase
634 |
635 |
636 | def on_train_end(self, trainer, model):
637 | checkpoint_dir = os.path.join("checkpoints", model.args.name, "last.ckpt")
638 | data_dir = os.path.join("checkpoints", model.args.name, "data.pkl")
639 |
640 | os.makedirs(os.path.dirname(checkpoint_dir), exist_ok=True)
641 | checkpoint = trainer._checkpoint_connector.dump_checkpoint()
642 |
643 | # Forcibly remove optimizer states from the checkpoint
644 | if 'optimizer_states' in checkpoint:
645 | del checkpoint['optimizer_states']
646 |
647 | torch.save(checkpoint, checkpoint_dir)
648 |
649 | with open(data_dir, 'wb') as file:
650 | model.data.rgb_volume = None
651 | pickle.dump(model.data, file)
652 |
653 | if __name__ == "__main__":
654 |
655 | # argparse
656 | parser = argparse.ArgumentParser()
657 |
658 | # data
659 | parser.add_argument('--point_batch_size', type=int, default=2**18, help="Number of points to sample per dataloader index.")
660 | parser.add_argument('--num_batches', type=int, default=200, help="Number of training batches.")
661 | parser.add_argument('--max_percentile', type=float, default=99.5, help="Percentile of lightest pixels to cut.")
662 | parser.add_argument('--frames', type=str, help="Which subset of frames to use for training, e.g. 0,10,20,30,40")
663 |
664 | # model
665 | parser.add_argument('--rotation_weight', type=float, default=1e-2, help="Scale learned rotation.")
666 | parser.add_argument('--translation_weight', type=float, default=1e0, help="Scale learned translation.")
667 | parser.add_argument('--rolling_shutter', action='store_true', help="Use rolling shutter compensation.")
668 | parser.add_argument('--no_mask', action='store_true', help="Do not use mask.")
669 | parser.add_argument('--no_offset', action='store_true', help="Do not use ray offset model.")
670 | parser.add_argument('--no_view_color', action='store_true', help="Do not use view dependent color model.")
671 | parser.add_argument('--no_lens_distortion', action='store_true', help="Do not use lens distortion model.")
672 |
673 |
674 |
675 | # light sphere
676 | parser.add_argument('--encoding_offset_position_config', type=str, default="small", help="Encoding offset position configuration (tiny, small, medium, large, ultrakill).")
677 | parser.add_argument('--encoding_offset_angle_config', type=str, default="small", help="Encoding offset angle configuration (tiny, small, medium, large, ultrakill).")
678 | parser.add_argument('--network_offset_config', type=str, default="large", help="Network offset configuration (tiny, small, medium, large, ultrakill).")
679 |
680 | parser.add_argument('--encoding_color_position_config', type=str, default="large", help="Encoding color position configuration (tiny, small, medium, large, ultrakill).")
681 | parser.add_argument('--encoding_color_angle_config', type=str, default="small", help="Encoding color angle configuration (tiny, small, medium, large, ultrakill).")
682 | parser.add_argument('--network_color_position_config', type=str, default="large", help="Network color position configuration (tiny, small, medium, large, ultrakill).")
683 | parser.add_argument('--network_color_angle_config', type=str, default="large", help="Network color angle configuration (tiny, small, medium, large, ultrakill).")
684 |
685 | # training
686 | parser.add_argument('--data_path', '--d', type=str, required=True, help="Path to frame_bundle.npz")
687 | parser.add_argument('--name', type=str, required=True, help="Experiment name for logs and checkpoints.")
688 | parser.add_argument('--max_epochs', type=int, default=100, help="Number of training epochs.")
689 | parser.add_argument('--lr', type=float, default=5e-4, help="Learning rate.")
690 | parser.add_argument('--weight_decay', type=float, default=1e-5, help="Weight decay.")
691 | parser.add_argument('--save_video', action='store_true', help="Store training outputs at each epoch for visualization.")
692 | parser.add_argument('--num_workers', type=int, default=4, help="Number of dataloader workers.")
693 | parser.add_argument('--debug', action='store_true', help="Debug mode, only use 1 batch.")
694 | parser.add_argument('--fast', action='store_true', help="Fast mode.")
695 | parser.add_argument('--cache', action='store_true', help="Cache data.")
696 | parser.add_argument('--hd', action='store_true', help="Make tensorboard HD.")
697 |
698 |
699 | args = parser.parse_args()
700 | # parse frame subset argument
701 | print(args)
702 | if args.frames is not None:
703 | args.frames = [int(x) for x in re.findall(r"[-+]?\d*\.\d+|\d+", args.frames)]
704 |
705 | # model
706 | model = PanoModel(args)
707 | model.load_volume()
708 |
709 | # dataset
710 | data = model.data
711 | train_loader = DataLoader(data, batch_size=1, num_workers=args.num_workers, shuffle=False, pin_memory=True, prefetch_factor=1)
712 |
713 | model.model_translation.translations_coarse.requires_grad_(False)
714 | model.model_translation.translations_fine.requires_grad_(False)
715 | model.model_rotation.rotations.requires_grad_(False)
716 |
717 | torch.set_float32_matmul_precision('medium')
718 |
719 | # training
720 | lr_callback = pl.callbacks.LearningRateMonitor()
721 | logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), version=args.name, name="lightning_logs")
722 | validation_callback = ValidationCallback()
723 | trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), num_nodes=1, strategy="auto", max_epochs=args.max_epochs,
724 | logger=logger, callbacks=[validation_callback, lr_callback], enable_checkpointing=False, fast_dev_run=args.debug)
725 | trainer.fit(model, train_loader)
726 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | plt.switch_backend('Agg') # non-visual
7 |
8 | def interp(x, xp, fp):
9 | """
10 | Linear interpolation of values fp from known points xp to new points x.
11 | """
12 | indices = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
13 | x0, x1 = xp[indices - 1], xp[indices]
14 | y0, y1 = fp[indices - 1], fp[indices]
15 |
16 | return (y0 + (y1 - y0) * (x - x0) / (x1 - x0)).to(fp.dtype)
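   | # e.g. interp(torch.tensor([0.25]), torch.tensor([0.0, 1.0]), torch.tensor([0.0, 4.0])) -> tensor([1.])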
17 |
18 | def multi_interp(x, xp, fp):
19 | """
20 | Linear interpolation of fp from xp to x, applied independently to each column (last dimension) of fp.
21 | """
22 | if torch.is_tensor(fp):
23 | out = [interp(x, xp, fp[:, i]) for i in range(fp.shape[-1])]
24 | return torch.stack(out, dim=-1).to(fp.dtype)
25 | else:
26 | out = [np.interp(x, xp, fp[:, i]) for i in range(fp.shape[-1])]
27 | return np.stack(out, axis=-1).astype(fp.dtype)
28 |
29 | def interpolate_params(params, t):
30 | """
31 | Interpolate parameters over time t, linearly between frames.
32 | """
33 | num_frames = params.shape[-1]
34 | frame_number = t * (num_frames - 1)
35 | integer_frame = torch.floor(frame_number).long()
36 | fractional_frame = frame_number.to(params.dtype) - integer_frame.to(params.dtype)
37 |
38 | # Ensure indices are within valid range
39 | next_frame = torch.clamp(integer_frame + 1, 0, num_frames - 1)
40 | integer_frame = torch.clamp(integer_frame, 0, num_frames - 1)
41 |
42 | param_now = params[:, :, integer_frame]
43 | param_next = params[:, :, next_frame]
44 |
45 | # Linear interpolation between current and next frame parameters
46 | interpolated_params = param_now + fractional_frame * (param_next - param_now)
47 |
48 | return interpolated_params.squeeze(0).squeeze(-1).permute(1, 0)
49 |
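   | # Coarse-to-fine schedule: only the first fraction of encoding features (starting at `initial`
   | # and growing linearly with mask_coef, i.e. the training phase) passes signal and gradients;
   | # the remaining features are zeroed until later in training.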
50 | class MaskFunction(torch.autograd.Function):
51 | @staticmethod
52 | def forward(ctx, encoding, mask_coef, initial=0.4):
53 | """
54 | Forward pass for masking, scales mask_coef to blend encoding.
55 | """
56 | mask_coef = initial + (1 - initial) * mask_coef
57 | mask = torch.zeros_like(encoding[0:1])
58 | mask_ceil = int(np.ceil(mask_coef * encoding.shape[1]))
59 | mask[..., :mask_ceil] = 1.0
60 | ctx.save_for_backward(mask)
61 | return encoding * mask
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | """
66 | Backward pass: propagate gradients only through the unmasked encoding features.
67 | """
68 | mask, = ctx.saved_tensors
69 | return grad_output * mask, None, None
70 |
71 | def mask(encoding, mask_coef, initial=0.4):
72 | """
73 | Apply mask to encoding with scaling factor mask_coef.
74 | """
75 | return MaskFunction.apply(encoding, mask_coef, initial)
76 |
77 | def make_grid(height, width, u_lims, v_lims):
78 | """
79 | Create (u,v) meshgrid of size (height, width) with given limits.
80 | """
81 | u = torch.linspace(u_lims[0], u_lims[1], width)
82 | v = torch.linspace(v_lims[1], v_lims[0], height) # Flip for array convention
83 | u_grid, v_grid = torch.meshgrid([u, v], indexing="xy")
84 | return torch.stack((u_grid.flatten(), v_grid.flatten())).permute(1, 0)
85 |
86 | def unwrap_quaternions(q):
87 | """
88 | Flip quaternion signs so consecutive rotations lie in the same hemisphere (q and -q encode the same rotation).
89 | """
90 | n = q.shape[0]
91 | unwrapped_q = q.clone()
92 | for i in range(1, n):
93 | cos_theta = torch.dot(unwrapped_q[i-1], unwrapped_q[i])
94 | if cos_theta < 0:
95 | unwrapped_q[i] = -unwrapped_q[i]
96 | return unwrapped_q
97 |
98 | @torch.jit.script
99 | def quaternion_conjugate(q):
100 | """
101 | Return the conjugate of a quaternion.
102 | """
103 | q_conj = q.clone()
104 | q_conj[:, 1:] = -q_conj[:, 1:] # Invert vector part
105 | return q_conj
106 |
107 | def quaternion_multiply(q1, q2):
108 | """
109 | Multiply two quaternions.
110 | """
111 | w1, v1 = q1[..., 0], q1[..., 1:]
112 | w2, v2 = q2[..., 0], q2[..., 1:]
113 | w = w1 * w2 - torch.sum(v1 * v2, dim=-1)
114 | v = w1.unsqueeze(-1) * v2 + w2.unsqueeze(-1) * v1 + torch.cross(v1, v2, dim=-1)
115 | return torch.cat((w.unsqueeze(-1), v), dim=-1)
116 |
117 | @torch.jit.script
118 | def convert_quaternions_to_rot(quaternions):
119 | """
120 | Convert quaternions (WXYZ) to 3x3 rotation matrices.
121 | """
122 | w, x, y, z = quaternions.unbind(-1)
123 | r00 = 1 - 2 * (y**2 + z**2)
124 | r01 = 2 * (x * y - z * w)
125 | r02 = 2 * (x * z + y * w)
126 | r10 = 2 * (x * y + z * w)
127 | r11 = 1 - 2 * (x**2 + z**2)
128 | r12 = 2 * (y * z - x * w)
129 | r20 = 2 * (x * z - y * w)
130 | r21 = 2 * (y * z + x * w)
131 | r22 = 1 - 2 * (x**2 + y**2)
132 | R = torch.stack([r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=-1)
133 | return R.reshape(-1, 3, 3)
134 |
135 | @torch.no_grad()
136 | def raw_to_rgb(bundle):
137 | """
138 | Convert RAW mosaic to three-channel RGB volume by filling empty pixels.
139 | """
140 | raw_frames = torch.tensor(np.array([bundle[f'raw_{i}']['raw'] for i in range(bundle['num_raw_frames'])]).astype(np.int32), dtype=torch.float32)[None]
141 | raw_frames = raw_frames.permute(1, 0, 2, 3)
142 | color_correction_gains = bundle['raw_0']['color_correction_gains']
143 | color_filter_arrangement = bundle['characteristics']['color_filter_arrangement']
144 | blacklevel = torch.tensor(np.array([bundle[f'raw_{i}']['blacklevel'] for i in range(bundle['num_raw_frames'])]))[:, :, None, None]
145 | whitelevel = torch.tensor(np.array([bundle[f'raw_{i}']['whitelevel'] for i in range(bundle['num_raw_frames'])]))[:, None, None, None]
146 | shade_maps = torch.tensor(np.array([bundle[f'raw_{i}']['shade_map'] for i in range(bundle['num_raw_frames'])])).permute(0, 3, 1, 2)
147 | shade_maps = F.interpolate(shade_maps, size=(raw_frames.shape[-2]//2, raw_frames.shape[-1]//2), mode='bilinear', align_corners=False)
148 |
149 | top_left = raw_frames[:, :, 0::2, 0::2]
150 | top_right = raw_frames[:, :, 0::2, 1::2]
151 | bottom_left = raw_frames[:, :, 1::2, 0::2]
152 | bottom_right = raw_frames[:, :, 1::2, 1::2]
153 |
154 | if color_filter_arrangement == 0: # RGGB
155 | R, G1, G2, B = top_left, top_right, bottom_left, bottom_right
156 | elif color_filter_arrangement == 1: # GRBG
157 | G1, R, B, G2 = top_left, top_right, bottom_left, bottom_right
158 | elif color_filter_arrangement == 2: # GBRG
159 | G1, B, R, G2 = top_left, top_right, bottom_left, bottom_right
160 | elif color_filter_arrangement == 3: # BGGR
161 | B, G1, G2, R = top_left, top_right, bottom_left, bottom_right
162 |
163 | # Black-level correct, normalize, and apply per-channel color correction gains and lens shading
164 | R = ((R - blacklevel[:, 0:1]) / (whitelevel - blacklevel[:, 0:1]) * color_correction_gains[0])
165 | R *= shade_maps[:, 0:1]
166 | G1 = ((G1 - blacklevel[:, 1:2]) / (whitelevel - blacklevel[:, 1:2]) * color_correction_gains[1])
167 | G1 *= shade_maps[:, 1:2]
168 | G2 = ((G2 - blacklevel[:, 2:3]) / (whitelevel - blacklevel[:, 2:3]) * color_correction_gains[2])
169 | G2 *= shade_maps[:, 2:3]
170 | B = ((B - blacklevel[:, 3:4]) / (whitelevel - blacklevel[:, 3:4]) * color_correction_gains[3])
171 | B *= shade_maps[:, 3:4]
172 |
173 | rgb_volume = torch.zeros(raw_frames.shape[0], 3, raw_frames.shape[-2], raw_frames.shape[-1], dtype=torch.float32)
174 |
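    | # Bilinear demosaic: each missing sample is the average of its nearest same-color Bayer
    | # neighbors, gathered via torch.roll on the half-resolution color planes.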
175 | # Fill gaps in blue channel
176 | rgb_volume[:, 2, 0::2, 0::2] = B.squeeze(1)
177 | rgb_volume[:, 2, 0::2, 1::2] = (B + torch.roll(B, -1, dims=3)).squeeze(1) / 2
178 | rgb_volume[:, 2, 1::2, 0::2] = (B + torch.roll(B, -1, dims=2)).squeeze(1) / 2
179 | rgb_volume[:, 2, 1::2, 1::2] = (B + torch.roll(B, -1, dims=2) + torch.roll(B, -1, dims=3) + torch.roll(B, [-1, -1], dims=[2, 3])).squeeze(1) / 4
180 |
181 | # Fill gaps in green channel
182 | rgb_volume[:, 1, 0::2, 0::2] = G1.squeeze(1)
183 | rgb_volume[:, 1, 0::2, 1::2] = (G1 + torch.roll(G1, -1, dims=3) + G2 + torch.roll(G2, 1, dims=2)).squeeze(1) / 4
184 | rgb_volume[:, 1, 1::2, 0::2] = (G1 + torch.roll(G1, -1, dims=2) + G2 + torch.roll(G2, 1, dims=3)).squeeze(1) / 4
185 | rgb_volume[:, 1, 1::2, 1::2] = G2.squeeze(1)
186 |
187 | # Fill gaps in red channel
188 | rgb_volume[:, 0, 0::2, 0::2] = R.squeeze(1)
189 | rgb_volume[:, 0, 0::2, 1::2] = (R + torch.roll(R, -1, dims=3)).squeeze(1) / 2
190 | rgb_volume[:, 0, 1::2, 0::2] = (R + torch.roll(R, -1, dims=2)).squeeze(1) / 2
191 | rgb_volume[:, 0, 1::2, 1::2] = (R + torch.roll(R, -1, dims=2) + torch.roll(R, -1, dims=3) + torch.roll(R, [-1, -1], dims=[2, 3])).squeeze(1) / 4
192 |
193 | rgb_volume = torch.flip(rgb_volume.transpose(-1, -2), [-1]) # Rotate 90 degrees to portrait
194 | return rgb_volume
195 |
196 | def de_item(bundle):
197 | """
198 | Call .item() on all dictionary items, removing extra dimensions.
199 | """
200 | bundle['motion'] = bundle['motion'].item()
201 | bundle['characteristics'] = bundle['characteristics'].item()
202 |
203 | for i in range(bundle['num_raw_frames']):
204 | bundle[f'raw_{i}'] = bundle[f'raw_{i}'].item()
205 |
206 | def debatch(batch):
207 | """
208 | Collapse batch and point dimensions together, e.g. (B, N, C) -> (B*N, C).
209 | """
210 | debatched = []
211 | for x in batch:
212 | if len(x.shape) <= 1:
213 | raise Exception("This tensor is too small to debatch.")
214 | elif len(x.shape) == 2:
215 | debatched.append(x.reshape(x.shape[0] * x.shape[1]))
216 | else:
217 | debatched.append(x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]))
218 | return debatched
219 |
220 | def apply_ccm(image, ccm):
221 | """
222 | Apply Color Correction Matrix (CCM) to the image.
223 | """
224 | if image.dim() == 3:
225 | corrected_image = torch.einsum('ij,jkl->ikl', ccm, image)
226 | else:
227 | corrected_image = ccm @ image
228 | return corrected_image.clamp(0, 1)
229 |
230 | def apply_tonemap(image, tonemap):
231 | """
232 | Apply tonemapping curve to the image using custom linear interpolation.
233 | """
234 | toned_image = torch.empty_like(image)
235 | for i in range(3):
236 | x_vals = tonemap[i][:, 0].contiguous()
237 | y_vals = tonemap[i][:, 1].contiguous()
238 | toned_image[i] = interp(image[i], x_vals, y_vals)
239 | return toned_image
240 |
241 | def colorize_tensor(value, vmin=None, vmax=None, cmap=None, colorbar=False, height=9.6, width=7.2):
242 | """
243 | Convert tensor to 3-channel RGB array using colormap (similar to plt.imshow).
244 | """
245 | assert len(value.shape) == 2
246 | fig, ax = plt.subplots(1, 1)
247 | fig.set_size_inches(width, height)
248 | a = ax.imshow(value.detach().cpu(), vmin=vmin, vmax=vmax, cmap=cmap)
249 | ax.set_axis_off()
250 | if colorbar:
251 | cbar = plt.colorbar(a, fraction=0.05)
252 | cbar.ax.tick_params(labelsize=30)
253 | plt.tight_layout()
254 |
255 | # Convert figure to numpy array
256 | fig.canvas.draw()
257 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
258 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
259 | img = img / 255.0
260 |
261 | plt.close(fig)
262 | return torch.tensor(img).permute(2, 0, 1).float()
263 |
--------------------------------------------------------------------------------