├── .figs ├── pani-badge.svg ├── preview.png └── thumbnails │ ├── 360Beach.png │ ├── 360Garden.png │ ├── 360Road.png │ ├── 360Siegen.png │ ├── Beppu.png │ ├── BikeRacks.png │ ├── BikeShelf.png │ ├── BluePit.png │ ├── BluePlane.png │ ├── Bridge.png │ ├── CatBar.png │ ├── CityCars.png │ ├── Construction.png │ ├── Convocation.png │ ├── DarkDistillery.png │ ├── DarkPeace.png │ ├── DarkShrine.png │ ├── DarkTruck.png │ ├── Eiffel.png │ ├── Escalatosaur.png │ ├── Fireworks.png │ ├── Fukuoka.png │ ├── GlassGarden.png │ ├── LanternDeer.png │ ├── MellonDoor.png │ ├── MountainTop.png │ ├── NaraCity.png │ ├── Ocean.png │ ├── ParisCity.png │ ├── PlaneHall.png │ ├── PondHouse.png │ ├── RainyPath.png │ ├── RedPit.png │ ├── RedShrine.png │ ├── River.png │ ├── RockStream.png │ ├── Seafood.png │ ├── ShinyPlane.png │ ├── ShinySticks.png │ ├── SnowTree.png │ ├── Stalls.png │ ├── StatueLeft.png │ ├── StatueRight.png │ ├── Tenjin.png │ ├── Tigers.png │ ├── Toronto.png │ ├── UniversityCollege.png │ ├── Vending.png │ ├── Waterfall.png │ ├── WoodOffice.png │ └── thumbnail.png ├── LICENSE ├── README.md ├── checkpoints └── __init__.py ├── config ├── config_large.json ├── config_medium.json ├── config_small.json ├── config_tiny.json └── config_ultrakill.json ├── data └── __init__.py ├── lightning_logs └── __init__.py ├── outputs └── __init__.py ├── render.py ├── requirements.txt ├── train.py ├── tutorial.ipynb └── utils └── utils.py /.figs/pani-badge.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | Android App 45 | 46 | Android App 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /.figs/preview.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/preview.png -------------------------------------------------------------------------------- /.figs/thumbnails/360Beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Beach.png -------------------------------------------------------------------------------- /.figs/thumbnails/360Garden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Garden.png -------------------------------------------------------------------------------- /.figs/thumbnails/360Road.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Road.png -------------------------------------------------------------------------------- /.figs/thumbnails/360Siegen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/360Siegen.png -------------------------------------------------------------------------------- /.figs/thumbnails/Beppu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Beppu.png -------------------------------------------------------------------------------- /.figs/thumbnails/BikeRacks.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BikeRacks.png -------------------------------------------------------------------------------- /.figs/thumbnails/BikeShelf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BikeShelf.png -------------------------------------------------------------------------------- /.figs/thumbnails/BluePit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BluePit.png -------------------------------------------------------------------------------- /.figs/thumbnails/BluePlane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/BluePlane.png -------------------------------------------------------------------------------- /.figs/thumbnails/Bridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Bridge.png -------------------------------------------------------------------------------- /.figs/thumbnails/CatBar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/CatBar.png -------------------------------------------------------------------------------- 
/.figs/thumbnails/CityCars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/CityCars.png -------------------------------------------------------------------------------- /.figs/thumbnails/Construction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Construction.png -------------------------------------------------------------------------------- /.figs/thumbnails/Convocation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Convocation.png -------------------------------------------------------------------------------- /.figs/thumbnails/DarkDistillery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkDistillery.png -------------------------------------------------------------------------------- /.figs/thumbnails/DarkPeace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkPeace.png -------------------------------------------------------------------------------- /.figs/thumbnails/DarkShrine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkShrine.png 
-------------------------------------------------------------------------------- /.figs/thumbnails/DarkTruck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/DarkTruck.png -------------------------------------------------------------------------------- /.figs/thumbnails/Eiffel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Eiffel.png -------------------------------------------------------------------------------- /.figs/thumbnails/Escalatosaur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Escalatosaur.png -------------------------------------------------------------------------------- /.figs/thumbnails/Fireworks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Fireworks.png -------------------------------------------------------------------------------- /.figs/thumbnails/Fukuoka.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Fukuoka.png -------------------------------------------------------------------------------- /.figs/thumbnails/GlassGarden.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/GlassGarden.png -------------------------------------------------------------------------------- /.figs/thumbnails/LanternDeer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/LanternDeer.png -------------------------------------------------------------------------------- /.figs/thumbnails/MellonDoor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/MellonDoor.png -------------------------------------------------------------------------------- /.figs/thumbnails/MountainTop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/MountainTop.png -------------------------------------------------------------------------------- /.figs/thumbnails/NaraCity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/NaraCity.png -------------------------------------------------------------------------------- /.figs/thumbnails/Ocean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Ocean.png -------------------------------------------------------------------------------- /.figs/thumbnails/ParisCity.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ParisCity.png -------------------------------------------------------------------------------- /.figs/thumbnails/PlaneHall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/PlaneHall.png -------------------------------------------------------------------------------- /.figs/thumbnails/PondHouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/PondHouse.png -------------------------------------------------------------------------------- /.figs/thumbnails/RainyPath.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RainyPath.png -------------------------------------------------------------------------------- /.figs/thumbnails/RedPit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RedPit.png -------------------------------------------------------------------------------- /.figs/thumbnails/RedShrine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RedShrine.png -------------------------------------------------------------------------------- 
/.figs/thumbnails/River.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/River.png -------------------------------------------------------------------------------- /.figs/thumbnails/RockStream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/RockStream.png -------------------------------------------------------------------------------- /.figs/thumbnails/Seafood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Seafood.png -------------------------------------------------------------------------------- /.figs/thumbnails/ShinyPlane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ShinyPlane.png -------------------------------------------------------------------------------- /.figs/thumbnails/ShinySticks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/ShinySticks.png -------------------------------------------------------------------------------- /.figs/thumbnails/SnowTree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/SnowTree.png 
-------------------------------------------------------------------------------- /.figs/thumbnails/Stalls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Stalls.png -------------------------------------------------------------------------------- /.figs/thumbnails/StatueLeft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/StatueLeft.png -------------------------------------------------------------------------------- /.figs/thumbnails/StatueRight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/StatueRight.png -------------------------------------------------------------------------------- /.figs/thumbnails/Tenjin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Tenjin.png -------------------------------------------------------------------------------- /.figs/thumbnails/Tigers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Tigers.png -------------------------------------------------------------------------------- /.figs/thumbnails/Toronto.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Toronto.png -------------------------------------------------------------------------------- /.figs/thumbnails/UniversityCollege.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/UniversityCollege.png -------------------------------------------------------------------------------- /.figs/thumbnails/Vending.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Vending.png -------------------------------------------------------------------------------- /.figs/thumbnails/Waterfall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/Waterfall.png -------------------------------------------------------------------------------- /.figs/thumbnails/WoodOffice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/WoodOffice.png -------------------------------------------------------------------------------- /.figs/thumbnails/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/.figs/thumbnails/thumbnail.png -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 princeton-computational-imaging 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Neural Light Spheres for Implicit Image Stitching and View Synthesis 3 | 4 | Open In Colab 5 | 6 | 7 | Android Capture App 8 | 9 | 10 | This is the official code repository for the SIGGRAPH Asia 2024 work: [Neural Light Spheres for Implicit Image Stitching and View Synthesis](https://light.princeton.edu/publication/neuls/). 
If you use parts of this work, or otherwise take inspiration from it, please consider citing our paper: 11 | ``` 12 | @inproceedings{chugunov2024light, 13 | author = {Chugunov, Ilya and Joshi, Amogh and Murthy, Kiran and Bleibel, Francois and Heide, Felix}, 14 | title = {Neural Light Spheres for {Implicit Image Stitching and View Synthesis}}, 15 | booktitle = {Proceedings of the ACM SIGGRAPH Asia 2024}, 16 | year = {2024}, 17 | publisher = {ACM}, 18 | doi = {10.1145/3680528.3687660}, 19 | url = {https://doi.org/10.1145/3680528.3687660} 20 | } 21 | ``` 22 | 23 | ## Requirements: 24 | - Code was written in PyTorch 2.2.1 and PyTorch Lightning 2.0.1 on an Ubuntu 22.04 machine. 25 | - Condensed package requirements are in `requirements.txt`. Note that this contains the exact package versions at the time of publishing. Code will most likely work with newer versions of the libraries, but you will need to watch out for changes in class/function calls. 26 | - The non-standard packages you may need are `pytorch_lightning`, `commentjson`, `rawpy`, `pygame`, and `tinycudann`. See [NVlabs/tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) for installation instructions. Depending on your system you might just be able to do `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`, or you might have to build it from source with CMake.
27 | 28 | ## Project Structure: 29 | ```cpp 30 | NeuLS 31 | ├── checkpoints 32 | │ └── // folder for network checkpoints 33 | ├── config 34 | │ └── // network and encoding configurations for different sizes of MLPs 35 | ├── data 36 | │ └── // folder for image sequence data 37 | ├── lightning_logs 38 | │ └── // folder for tensorboard logs 39 | ├── outputs 40 | │ └── // folder for model outputs (e.g., final reconstructions) 41 | ├── utils 42 | │ └── utils.py // model helper functions (e.g., RAW demosaicing, quaternion math) 43 | ├── LICENSE // legal stuff 44 | ├── README.md // <- recursion errors 45 | ├── render.py // interactive render demo 46 | ├── requirements.txt // frozen package requirements 47 | ├── train.py // dataloader, network, visualization, and trainer code 48 | └── tutorial.ipynb // interactive tutorial for training the model 49 | ``` 50 | ## Getting Started: 51 | We highly recommend you start by going through `tutorial.ipynb`, either on your own machine or [with this Google Colab link](https://colab.research.google.com/github/princeton-computational-imaging/NeuLS/blob/main/tutorial.ipynb). 52 | 53 | TLDR: models can be trained with: 54 | 55 | `python3 train.py --data_path {path_to_data} --name {checkpoint_name}` 56 | 57 | For a full list of training arguments, we recommend looking through the argument parser section at the bottom of `train.py`. 58 | 59 | The final model checkpoint will be saved to `checkpoints/{checkpoint_name}/last.ckpt` 60 | 61 | And you can launch an interactive demo for the scene by passing the path to that checkpoint file (the demo loads `last.ckpt` and `data.pkl` from the file's folder): 62 | 63 | `python3 render.py --checkpoint checkpoints/{checkpoint_name}/last.ckpt` 64 | 65 |
66 | Preview 67 |
68 | 69 | Press `P` to turn the view-dependent color model off/on, `O` to turn the ray offset model off/on, `+` and `-` to raise and lower the image brightness, click and drag to rotate the camera, hold shift and click and drag to translate the camera, scroll to zoom in/out, press `R` to reset the view, and `Escape` to quit. 70 | 71 | 72 | ## Data: 73 | This model trains on data recorded by our Android RAW capture app [Pani](https://github.com/Ilya-Muromets/Pani); see that repo for details on collecting/processing your own captures. 74 | 75 | You can download the full 30 GB dataset via [this link](https://soap.cs.princeton.edu/neuls/neuls_data.zip), or the individual scenes from the paper via the following links: 76 | 77 | | Image | Name | Num Frames | Camera Lens | ISO | Exposure Time (s) | Aperture | Focal Length | Height | Width | 78 | |-------|------|------------|-------------|-----|-------------------|----------|--------------|--------|-------| 79 | | 360Road | 360Road | 77 | main | 21 | 1/1110 | 1.68 | 6.9 | 3072 | 4080 | 80 | | Beppu | Beppu | 40 | main | 21 | 1/630 | 1.68 | 6.9 | 3072 | 4080 | 81 | | BikeRacks | BikeRacks | 48 | main | 21 | 1/1960 | 1.68 | 6.9 | 3072 | 4080 | 82 | | BikeShelf | BikeShelf | 37 | main | 8065 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 83 | | BluePit | BluePit | 32 | main | 21 | 1/199 | 1.68 | 6.9 | 3072 | 4080 | 84 | | BluePlane | BluePlane | 59 | main | 1005 | 1/120 | 1.68 | 6.9 | 3072 | 4080 | 85 | | Bridge | Bridge | 49 | main | 21 | 1/1384 | 1.68 | 6.9 | 3072 | 4080 | 86 | | CityCars | CityCars | 46 | main | 21 | 1/1744 | 1.68 | 6.9 | 3072 | 4080 | 87 | | Construction | Construction | 53 | main | 21 | 1/1653 | 1.68 | 6.9 | 3072 | 4080 | 88 | | DarkDistillery | DarkDistillery | 57 | main | 10667 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 89 | | DarkPeace | DarkPeace | 51 | main | 10667 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 90 | | DarkShrine | DarkShrine | 34 | main | 5065 | 1/80 | 1.68 | 6.9 | 3072 | 4080 | 91 | | DarkTruck | DarkTruck |
43 | main | 10667 | 1/60 | 1.68 | 6.9 | 3072 | 4080 | 92 | | Eiffel | Eiffel | 73 | main | 21 | 1/1183 | 1.68 | 6.9 | 3072 | 4080 | 93 | | Escalatosaur | Escalatosaur | 49 | main | 589 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 94 | | Fireworks | Fireworks | 78 | main | 5000 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 95 | | Fukuoka | Fukuoka | 40 | main | 21 | 1/1312 | 1.68 | 6.9 | 3072 | 4080 | 96 | | LanternDeer | LanternDeer | 34 | main | 42 | 1/103 | 1.68 | 6.9 | 3072 | 4080 | 97 | | MountainTop | MountainTop | 59 | main | 21 | 1/2405 | 1.68 | 6.9 | 3072 | 4080 | 98 | | Ocean | Ocean | 44 | main | 110 | 1/127 | 1.68 | 6.9 | 3072 | 4080 | 99 | | ParisCity | ParisCity | 55 | main | 21 | 1/1265 | 1.68 | 6.9 | 3072 | 4080 | 100 | | PlaneHall | PlaneHall | 77 | main | 1005 | 1/120 | 1.68 | 6.9 | 3072 | 4080 | 101 | | PondHouse | PondHouse | 51 | main | 21 | 1/684 | 1.68 | 6.9 | 3072 | 4080 | 102 | | RainyPath | RainyPath | 38 | main | 600 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 103 | | RedShrine | RedShrine | 40 | main | 21 | 1/499 | 1.68 | 6.9 | 3072 | 4080 | 104 | | RockStream | RockStream | 31 | main | 21 | 1/1110 | 1.68 | 6.9 | 3072 | 4080 | 105 | | Seafood | Seafood | 44 | main | 21 | 1/193 | 1.68 | 6.9 | 3072 | 4080 | 106 | | ShinyPlane | ShinyPlane | 50 | main | 21 | 1/210 | 1.68 | 6.9 | 3072 | 4080 | 107 | | ShinySticks | ShinySticks | 37 | main | 21 | 1/1417 | 1.68 | 6.9 | 3072 | 4080 | 108 | | SnowTree | SnowTree | 42 | main | 21 | 1/320 | 1.68 | 6.9 | 3072 | 4080 | 109 | | Stalls | Stalls | 52 | main | 49 | 1/79 | 1.68 | 6.9 | 3072 | 4080 | 110 | | StatueLeft | StatueLeft | 22 | main | 805 | 1/60 | 1.68 | 6.9 | 3072 | 4080 | 111 | | StatueRight | StatueRight | 26 | main | 802 | 1/60 | 1.68 | 6.9 | 3072 | 4080 | 112 | | Tenjin | Tenjin | 36 | main | 602 | 1/100 | 1.68 | 6.9 | 3072 | 4080 | 113 | | Tigers | Tigers | 42 | main | 507 | 1/200 | 1.68 | 6.9 | 3072 | 4080 | 114 | | Toronto | Toronto | 31 | main | 21 | 1/1250 | 1.68 | 6.9 | 3072 | 4080 | 115 | | Vending | 
Vending | 42 | main | 21 | 1/352 | 1.68 | 6.9 | 3072 | 4080 | 116 | | WoodOffice | WoodOffice | 83 | main | 206 | 1/120 | 1.68 | 6.9 | 3072 | 4080 | 117 | | GlassGarden | GlassGarden | 59 | telephoto | 24 | 1/231 | 2.8 | 18.0 | 3024 | 4032 | 118 | | NaraCity | NaraCity | 54 | telephoto | 17 | 1/327 | 2.8 | 18.0 | 3024 | 4032 | 119 | | 360Beach | 360Beach | 56 | ultrawide | 42 | 1/3175 | 1.95 | 2.23 | 3000 | 4000 | 120 | | 360Garden | 360Garden | 77 | ultrawide | 41 | 1/1104 | 1.95 | 2.23 | 3000 | 4000 | 121 | | 360Siegen | 360Siegen | 67 | ultrawide | 41 | 1/1029 | 1.95 | 2.23 | 3000 | 4000 | 122 | | CatBar | CatBar | 37 | ultrawide | 88 | 1/110 | 1.95 | 2.23 | 3000 | 4000 | 123 | | Convocation | Convocation | 43 | ultrawide | 41 | 1/2309 | 1.95 | 2.23 | 3000 | 4000 | 124 | | MellonDoor | MellonDoor | 40 | ultrawide | 41 | 1/564 | 1.95 | 2.23 | 3000 | 4000 | 125 | | RedPit | RedPit | 53 | ultrawide | 41 | 1/418 | 1.95 | 2.23 | 3000 | 4000 | 126 | | River | River | 65 | ultrawide | 48 | 1/78 | 1.95 | 2.23 | 3000 | 4000 | 127 | | UniversityCollege | UniversityCollege | 55 | ultrawide | 42 | 1/2177 | 1.95 | 2.23 | 3000 | 4000 | 128 | | Waterfall | Waterfall | 74 | ultrawide | 41 | 1/1621 | 1.95 | 2.23 | 3000 | 4000 | 129 | 130 | Glhf, 131 | Ilya 132 | -------------------------------------------------------------------------------- /checkpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/checkpoints/__init__.py -------------------------------------------------------------------------------- /config/config_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "Grid", 4 | "type": "Hash", 5 | "n_levels": 15, 6 | "n_features_per_level": 4, 7 | "log2_hashmap_size": 19, 8 | "base_resolution": 4, 9 | "per_level_scale": 1.61, 10 | 
"interpolation": "Linear" 11 | }, 12 | "network": { 13 | "otype": "FullyFusedMLP", 14 | "activation": "ReLU", 15 | "output_activation": "None", 16 | "n_neurons": 128, 17 | "n_hidden_layers": 5 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config/config_medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "Grid", 4 | "type": "Hash", 5 | "n_levels": 12, 6 | "n_features_per_level": 4, 7 | "log2_hashmap_size": 16, 8 | "base_resolution": 4, 9 | "per_level_scale": 1.61, 10 | "interpolation": "Linear" 11 | }, 12 | "network": { 13 | "otype": "FullyFusedMLP", 14 | "activation": "ReLU", 15 | "output_activation": "None", 16 | "n_neurons": 64, 17 | "n_hidden_layers": 5 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config/config_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "Grid", 4 | "type": "Hash", 5 | "n_levels": 8, 6 | "n_features_per_level": 4, 7 | "log2_hashmap_size": 16, 8 | "base_resolution": 4, 9 | "per_level_scale": 1.61, 10 | "interpolation": "Linear" 11 | }, 12 | "network": { 13 | "otype": "FullyFusedMLP", 14 | "activation": "ReLU", 15 | "output_activation": "None", 16 | "n_neurons": 64, 17 | "n_hidden_layers": 3 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config/config_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "Grid", 4 | "type": "Hash", 5 | "n_levels": 6, 6 | "n_features_per_level": 4, 7 | "log2_hashmap_size": 12, 8 | "base_resolution": 4, 9 | "per_level_scale": 1.61, 10 | "interpolation": "Linear" 11 | }, 12 | "network": { 13 | "otype": "FullyFusedMLP", 14 | "activation": "ReLU", 15 | "output_activation": "None", 16 | "n_neurons": 32, 17 | 
"n_hidden_layers": 3 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config/config_ultrakill.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "Grid", 4 | "type": "Hash", 5 | "n_levels": 17, 6 | "n_features_per_level": 4, 7 | "log2_hashmap_size": 20, 8 | "base_resolution": 4, 9 | "per_level_scale": 1.61, 10 | "interpolation": "Linear" 11 | }, 12 | "network": { 13 | "otype": "CutlassMLP", 14 | "activation": "ReLU", 15 | "output_activation": "None", 16 | "n_neurons": 256, 17 | "n_hidden_layers": 5 18 | } 19 | } 20 | 21 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/data/__init__.py -------------------------------------------------------------------------------- /lightning_logs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/lightning_logs/__init__.py -------------------------------------------------------------------------------- /outputs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NeuLS/56f6e51993147d3bc1e29c06300cfbb29d3b28d6/outputs/__init__.py -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | import pygame 5 | import argparse 6 | 7 | from utils import utils 8 | from train import PanoModel, BundleDataset 9 | 10 | def parse_args(): 11 | parser = 
argparse.ArgumentParser(description="Interactive Rendering") 12 | parser.add_argument("-d", "--checkpoint", required=True, help="Path to the checkpoint file") 13 | return parser.parse_args() 14 | 15 | args = parse_args() 16 | chkpt_path = args.checkpoint 17 | chkpt_path = os.path.join(os.path.dirname(chkpt_path), "last.ckpt") 18 | data_path = os.path.join(os.path.dirname(chkpt_path), "data.pkl") 19 | model = PanoModel.load_from_checkpoint(chkpt_path, device="cuda", cached_data=data_path) 20 | model = model.to("cuda") 21 | model = model.eval() 22 | model.args.no_offset = False 23 | model.args.no_view_color = False 24 | 25 | class InteractiveWindow: 26 | def __init__(self): 27 | self.offset = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32).to("cuda") 28 | self.fov_factor = 1.0 29 | self.width = 1440 30 | self.height = 1080 31 | self.t = 0.5 32 | self.brightness = 1 33 | self.t_tensor = torch.full((self.width * self.height,), self.t, device="cuda", dtype=torch.float32) 34 | 35 | self.generate_camera() 36 | 37 | 38 | pygame.init() 39 | pygame.font.init() # Initialize font module 40 | self.font = pygame.font.SysFont(None, 36) # Set up the font 41 | 42 | self.screen = pygame.display.set_mode((self.width, self.height), pygame.DOUBLEBUF | pygame.HWSURFACE | pygame.SRCALPHA) 43 | pygame.display.set_caption("NeuLS Interactive Rendering") 44 | 45 | self.last_mouse_position = None 46 | self.update_image() 47 | 48 | def generate_camera(self): 49 | self.intrinsics_inv = model.data.intrinsics_inv[int(self.t * model.args.num_frames - 1)].clone() 50 | self.intrinsics_inv[0] *= 1.77 # Widen FOV 51 | self.intrinsics_inv = self.intrinsics_inv[None, :, :].repeat(self.width * self.height, 1, 1).to("cuda") 52 | 53 | self.quaternion_camera_to_world = model.data.quaternion_camera_to_world[int(self.t * model.args.num_frames - 1)].to("cuda") 54 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world.clone() 55 | self.camera_to_world = 
model.model_rotation(self.quaternion_camera_to_world, self.t_tensor).to("cuda") 56 | 57 | self.ray_origins = model.model_translation(self.t_tensor, 1.0) 58 | 59 | self.uv = utils.make_grid(self.height, self.width, [0, 1], [0, 1]).to("cuda") 60 | self.ray_directions = model.generate_ray_directions(self.uv, self.camera_to_world, self.intrinsics_inv) 61 | 62 | 63 | 64 | def generate_image(self): 65 | with torch.no_grad(): 66 | self.camera_to_world = utils.convert_quaternions_to_rot(self.quaternion_camera_to_world_offset).repeat(self.width * self.height, 1, 1).to("cuda") 67 | self.ray_directions = model.generate_ray_directions(self.uv * self.fov_factor + 0.5 * (1 - self.fov_factor), self.camera_to_world, self.intrinsics_inv) 68 | 69 | rgb_transmission = model.inference(self.t_tensor, self.uv * self.fov_factor + 0.5 * (1 - self.fov_factor), self.ray_origins + self.offset, self.ray_directions, 1.0) 70 | 71 | rgb_transmission = model.color(rgb_transmission, self.height, self.width).permute(2, 1, 0) 72 | return (rgb_transmission ** 0.7 * 255 * self.brightness).clamp(0, 255).byte().cpu().numpy() 73 | 74 | def update_image(self): 75 | image = self.generate_image() 76 | surf = pygame.surfarray.make_surface(image) 77 | self.screen.blit(surf, (0, 0)) 78 | 79 | # Render and display status text 80 | self.render_status_text() 81 | 82 | pygame.display.flip() 83 | 84 | def render_status_text(self): 85 | # Double the font size 86 | large_font = pygame.font.SysFont(None, int(self.font.get_height() * 1.4)) 87 | 88 | view_color_status = "View-Dependent Color: On" if not model.args.no_view_color else "View-Dependent Color: Off" 89 | ray_offset_status = "Ray Offset: On" if not model.args.no_offset else "Ray Offset: Off" 90 | 91 | # Render the text in white 92 | view_color_text = large_font.render(view_color_status, True, (255, 255, 255)) 93 | ray_offset_text = large_font.render(ray_offset_status, True, (255, 255, 255)) 94 | 95 | # Calculate the size of the black box 96 | box_width = 
max(view_color_text.get_width(), ray_offset_text.get_width()) + 20 97 | box_height = view_color_text.get_height() + ray_offset_text.get_height() + 20 98 | 99 | # Position for the box 100 | box_x = 10 101 | box_y = self.height - box_height - 10 102 | 103 | # Draw the black box 104 | pygame.draw.rect(self.screen, (0, 0, 0), (box_x, box_y, box_width, box_height)) 105 | 106 | # Draw the white text on top of the black box 107 | self.screen.blit(view_color_text, (box_x + 10, box_y + 10)) 108 | self.screen.blit(ray_offset_text, (box_x + 10, box_y + view_color_text.get_height() + 10)) 109 | 110 | 111 | def handle_keys(self, event): 112 | fov_step = 1.05 113 | 114 | if event.type == pygame.KEYDOWN: 115 | if event.key == pygame.K_o: 116 | model.args.no_offset = not model.args.no_offset 117 | elif event.key == pygame.K_p: 118 | model.args.no_view_color = not model.args.no_view_color 119 | elif event.key == pygame.K_EQUALS: 120 | self.brightness *= 1.05 121 | elif event.key == pygame.K_MINUS: 122 | self.brightness *= 0.9523 123 | elif event.key == pygame.K_LEFT: 124 | self.t = max(0, self.t - 0.01) 125 | self.t_tensor.fill_(self.t) 126 | self.generate_camera() 127 | elif event.key == pygame.K_RIGHT: 128 | self.t = min(1, self.t + 0.01) 129 | self.t_tensor.fill_(self.t) 130 | self.generate_camera() 131 | elif event.key == pygame.K_ESCAPE: 132 | pygame.quit() 133 | sys.exit() 134 | elif event.key == pygame.K_r: 135 | self.offset.zero_() 136 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world.clone() 137 | self.fov_factor = 1.0 138 | 139 | def handle_mouse(self, event): 140 | shift_pressed = pygame.key.get_mods() & pygame.KMOD_SHIFT 141 | 142 | if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1: # Left button 143 | self.last_mouse_position = pygame.mouse.get_pos() 144 | elif event.type == pygame.MOUSEMOTION and self.last_mouse_position is not None: 145 | x, y = pygame.mouse.get_pos() 146 | dx, dy = x - self.last_mouse_position[0], y - 
self.last_mouse_position[1] 147 | sensitivity = 0.001 * self.fov_factor * model.args.focal_compensation 148 | 149 | if event.buttons[0]: # Left button held down 150 | if shift_pressed: 151 | step = 0.5 152 | movement_camera_space = torch.tensor([-dx * step * sensitivity, dy * step * sensitivity, 0.0]).to("cuda") 153 | movement_world_space = self.camera_to_world[0] @ movement_camera_space 154 | self.offset[0] += movement_world_space 155 | else: 156 | pitch = torch.tensor(-dy * sensitivity, dtype=torch.float32).to("cuda") 157 | yaw = torch.tensor(-dx * sensitivity, dtype=torch.float32).to("cuda") 158 | 159 | # Compute the pitch and yaw quaternions (small rotations) 160 | pitch_quat = torch.tensor([torch.cos(pitch / 2), torch.sin(pitch / 2), 0, 0]).to("cuda") 161 | yaw_quat = torch.tensor([torch.cos(yaw / 2), 0, torch.sin(yaw / 2), 0]).to("cuda") 162 | 163 | # Combine the pitch and yaw quaternions by multiplying them 164 | rotation_quat = utils.quaternion_multiply(yaw_quat, pitch_quat) 165 | 166 | # Apply the rotation incrementally 167 | self.quaternion_camera_to_world_offset = utils.quaternion_multiply( 168 | self.quaternion_camera_to_world_offset, rotation_quat 169 | ) 170 | 171 | # Normalize the quaternion to avoid drift 172 | self.quaternion_camera_to_world_offset = self.quaternion_camera_to_world_offset / torch.norm(self.quaternion_camera_to_world_offset) 173 | 174 | 175 | self.last_mouse_position = (x, y) 176 | elif event.type == pygame.MOUSEBUTTONUP and event.button == 1: # Left button 177 | self.last_mouse_position = None 178 | elif event.type == pygame.MOUSEWHEEL: 179 | fov_step = 1.02 180 | if event.y > 0: # Scroll up 181 | self.fov_factor /= fov_step 182 | elif event.y < 0: # Scroll down 183 | self.fov_factor *= fov_step 184 | 185 | def run(self): 186 | clock = pygame.time.Clock() 187 | while True: 188 | for event in pygame.event.get(): 189 | if event.type == pygame.QUIT: 190 | pygame.quit() 191 | sys.exit() 192 | elif event.type in [pygame.MOUSEBUTTONDOWN, 
pygame.MOUSEBUTTONUP, pygame.MOUSEMOTION, pygame.MOUSEWHEEL]: 193 | self.handle_mouse(event) 194 | else: 195 | self.handle_keys(event) 196 | 197 | self.update_image() 198 | clock.tick(30) # Limit the frame rate to 30 FPS 199 | 200 | if __name__ == "__main__": 201 | window = InteractiveWindow() 202 | window.run() 203 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | commentjson==0.9.0 2 | ipython==8.21.0 3 | ipywidgets==8.0.6 4 | matplotlib==3.7.0 5 | numpy==1.24.3 6 | Pillow==11.0.0 7 | pygame==2.6.0 8 | pytorch_lightning==2.4.0 9 | torch==2.7.0 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import commentjson as json 3 | import numpy as np 4 | import os 5 | import re 6 | import pickle 7 | 8 | import tinycudann as tcnn 9 | 10 | from utils import utils 11 | from utils.utils import debatch 12 | import matplotlib.pyplot as plt 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | from torch.utils.data import Dataset 18 | from torch.utils.data import DataLoader 19 | import pytorch_lightning as pl 20 | 21 | ######################################################################################################### 22 | ################################################ DATASET ################################################ 23 | ######################################################################################################### 24 | 25 | class BundleDataset(Dataset): 26 | def __init__(self, args, load_volume=False): 27 | self.args = args 28 | print("Loading from:", self.args.data_path) 29 | 30 | data = np.load(args.data_path, allow_pickle=True) 31 | 32 | self.characteristics = data['characteristics'].item() # camera characteristics 33 | 
self.motion = data['motion'].item() 34 | self.frame_timestamps = torch.tensor([data[f'raw_{i}'].item()['timestamp'] for i in range(data['num_raw_frames'])], dtype=torch.float64) 35 | self.motion_timestamps = torch.tensor(self.motion['timestamp'], dtype=torch.float64) 36 | self.num_frames = data['num_raw_frames'].item() 37 | 38 | # WXYZ quaternions, remove phase wraps (2pi jumps) 39 | self.quaternions = utils.unwrap_quaternions(torch.tensor(self.motion['quaternion']).float()) # T',4, has different timestamps from frames 40 | # flip x/y to match out convention: 41 | # +y towards selfie camera, +x towards buttons, +z towards scene 42 | self.quaternions[:,1] *= -1 43 | self.quaternions[:,2] *= -1 44 | 45 | self.quaternions = utils.multi_interp(self.frame_timestamps, self.motion_timestamps, self.quaternions) # interpolate to frame times 46 | 47 | self.reference_quaternion = self.quaternions[0:1] 48 | self.quaternion_camera_to_world = utils.quaternion_multiply(utils.quaternion_conjugate(self.reference_quaternion), self.quaternions) 49 | 50 | self.intrinsics = torch.tensor(np.array([data[f'raw_{i}'].item()['intrinsics'] for i in range(data['num_raw_frames'])])).float() # T,3,3 51 | # swap cx,cy -> landscape to portrait 52 | cx, cy = self.intrinsics[:, 2, 1].clone(), self.intrinsics[:, 2, 0].clone() 53 | self.intrinsics[:, 2, 0], self.intrinsics[:, 2, 1] = cx, cy 54 | # transpose to put cx,cy in right column 55 | self.intrinsics = self.intrinsics.transpose(1, 2) 56 | self.intrinsics_inv = torch.inverse(self.intrinsics) 57 | 58 | self.lens_distortion = torch.tensor(data['raw_0'].item()['lens_distortion']) 59 | self.tonemap_curve = torch.tensor(data['raw_0'].item()['tonemap_curve'], dtype=torch.float32) 60 | self.ccm = torch.tensor(data['raw_0'].item()['ccm'], dtype=torch.float32) 61 | 62 | self.img_channels = 3 63 | self.img_height = data['raw_0'].item()['width'] # rotated 90 64 | self.img_width = data['raw_0'].item()['height'] 65 | self.rgb_volume = None # placeholder 
volume for fast loading 66 | 67 | # rolling shutter timing compensation, off by default (can bug for data not from a Pixel 8 Pro) 68 | self.rolling_shutter_skew = data['raw_0'].item()['android']['sensor.rollingShutterSkew'] / 1e9 # delay between top and bottom row readout, seconds 69 | self.rolling_shutter_skew_row = self.rolling_shutter_skew / (self.img_height - 1) 70 | self.row_timestamps = torch.zeros(len(self.frame_timestamps), self.img_height, dtype=torch.float64) # NxH 71 | for i, frame_timestamp in enumerate(self.frame_timestamps): 72 | for j in range(self.img_height): 73 | if args.rolling_shutter: 74 | self.row_timestamps[i,j] = frame_timestamp + j * self.rolling_shutter_skew_row 75 | else: 76 | self.row_timestamps[i,j] = frame_timestamp 77 | 78 | 79 | self.row_timestamps = self.row_timestamps - self.row_timestamps[0,0] # zero at start 80 | self.row_timestamps = self.row_timestamps/self.row_timestamps[-1,-1] # normalize to 0-1 81 | 82 | if args.frames is not None: 83 | # subsample frames 84 | self.num_frames = len(args.frames) 85 | self.frame_timestamps = self.frame_timestamps[args.frames] 86 | self.intrinsics = self.intrinsics[args.frames] 87 | self.intrinsics_inv = self.intrinsics_inv[args.frames] 88 | self.quaternions = self.quaternions[args.frames] 89 | self.reference_quaternion = self.quaternions[0:1] 90 | self.quaternion_camera_to_world = self.quaternion_camera_to_world[args.frames] 91 | 92 | if load_volume: 93 | self.load_volume() 94 | 95 | self.frame_batch_size = 2 * (self.args.point_batch_size // self.num_frames // 2) # nearest even cut 96 | self.point_batch_size = self.frame_batch_size * self.num_frames # nearest multiple of num_frames 97 | self.num_batches = self.args.num_batches 98 | 99 | self.training_phase = 0.0 # fraction of training complete 100 | print("Frame Count: ", self.num_frames) 101 | 102 | def load_volume(self): 103 | volume_path = self.args.data_path.replace("frame_bundle.npz", "rgb_volume.npy") 104 | if 
os.path.exists(volume_path): 105 | print("Loading cached volume from:", volume_path) 106 | self.rgb_volume = torch.tensor(np.load(volume_path)).float() 107 | else: 108 | data = dict(np.load(self.args.data_path, allow_pickle=True)) 109 | utils.de_item(data) 110 | self.rgb_volume = (utils.raw_to_rgb(data)) # T,C,H,W 111 | if self.args.cache: 112 | print("Saving cached volume to:", volume_path) 113 | np.save(volume_path, self.rgb_volume.numpy()) 114 | 115 | if self.args.max_percentile < 100: # cut off highlights (long-tail-distribution) 116 | self.clip = np.percentile(self.rgb_volume[0], self.args.max_percentile) 117 | self.rgb_volume = torch.clamp(self.rgb_volume, 0, self.clip) 118 | self.rgb_volume = self.rgb_volume/self.clip 119 | else: 120 | self.clip = 1.0 121 | 122 | self.mean = self.rgb_volume[0].mean() 123 | self.rgb_volume = (16 * (self.rgb_volume - self.mean)).to(torch.float16) 124 | 125 | if self.args.frames is not None: 126 | self.rgb_volume = self.rgb_volume[self.args.frames] # subsample frames 127 | 128 | self.img_height, self.img_width = self.rgb_volume.shape[2], self.rgb_volume.shape[3] 129 | 130 | 131 | def __len__(self): 132 | return self.num_batches # arbitrary as we continuously generate random samples 133 | 134 | def __getitem__(self, idx): 135 | uv = torch.rand((self.frame_batch_size * self.num_frames), 2) # uniform random in [0,1] 136 | uv = uv * torch.tensor([[self.img_width-1, self.img_height-1]]) # scale to image dimensions 137 | uv = uv.round() # quantize to pixels 138 | u,v = uv.unbind(-1) 139 | 140 | t = torch.zeros_like(uv[:,0:1]) 141 | for frame in range(self.num_frames): 142 | # use row to index into row_timestamps 143 | t[frame * self.frame_batch_size:(frame + 1) * self.frame_batch_size,0] = self.row_timestamps[frame, v[frame * self.frame_batch_size:(frame + 1) * self.frame_batch_size].long()] 144 | 145 | uv = uv / torch.tensor([[self.img_width-1, self.img_height-1]]) # scale back to 0-1 146 | 147 | return self.generate_samples(t, uv) 
148 | 149 | def generate_samples(self, t, uv): 150 | """ generate samples from dataset and camera parameters for training 151 | """ 152 | 153 | # create frame_batch_size of quaterions for each frame 154 | quaternion_camera_to_world = (self.quaternion_camera_to_world[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0) 155 | # create frame_batch_size of intrinsics for each frame 156 | intrinsics = (self.intrinsics[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0) 157 | intrinsics_inv = (self.intrinsics_inv[:self.num_frames]).repeat_interleave(self.frame_batch_size, dim=0) 158 | 159 | # sample grid 160 | u,v = uv.unbind(-1) 161 | u, v = (u * (self.img_width - 1)).round().long(), (v * (self.img_height - 1)).round().long() # pixel coordinates 162 | u, v = torch.clamp(u, 0, self.img_width-1), torch.clamp(v, 0, self.img_height-1) # clamp to image bounds 163 | x, y = u, (self.img_height - 1) - v # convert to array coordinates 164 | 165 | if self.rgb_volume is not None: 166 | rgb_samples = [] 167 | for frame in range(self.num_frames): 168 | frame_min, frame_max = frame * self.frame_batch_size, (frame + 1) * self.frame_batch_size 169 | rgb_samples.append(self.rgb_volume[frame,:,y[frame_min:frame_max],x[frame_min:frame_max]].permute(1,0)) 170 | rgb_samples = torch.cat(rgb_samples, dim=0) 171 | else: 172 | rgb_samples = torch.zeros(self.frame_batch_size * self.num_frames, 3) 173 | 174 | return t, uv, quaternion_camera_to_world, intrinsics, intrinsics_inv, rgb_samples 175 | 176 | def sample_frame(self, frame, uv): 177 | """ sample frame [frame] at coordinates u,v 178 | """ 179 | 180 | u,v = uv.unbind(-1) 181 | u, v = (u * self.img_width).round().long(), (v * self.img_height).round().long() # pixel coordinates 182 | u, v = torch.clamp(u, 0, self.img_width-1), torch.clamp(v, 0, self.img_height-1) # clamp to image bounds 183 | x, y = u, (self.img_height - 1) - v # convert to array coordinates 184 | 185 | if self.rgb_volume is not None: 186 | 
rgb_samples = self.rgb_volume[frame:frame+1,:,y,x] # frames x 3 x H x W volume 187 | rgb_samples = rgb_samples.permute(0,2,1).flatten(0,1) # point_batch_size x channels 188 | else: 189 | rgb_samples = torch.zeros(u.shape[0], 3) 190 | 191 | return rgb_samples 192 | 193 | ######################################################################################################### 194 | ################################################ MODELS #################$############################### 195 | ######################################################################################################### 196 | 197 | class RotationModel(nn.Module): 198 | def __init__(self, args): 199 | super().__init__() 200 | self.args = args 201 | self.rotations = nn.Parameter(torch.zeros(1, 3, self.args.num_frames, dtype=torch.float32), requires_grad=True) 202 | 203 | def forward(self, quaternion_camera_to_world, t): 204 | self.rotations.data[:, :, 0] = 0.0 # zero out first frame's rotation 205 | 206 | rotations = utils.interpolate_params(self.rotations, t) 207 | rotations = self.args.rotation_weight * rotations 208 | rx, ry, rz = rotations[:, 0], rotations[:, 1], rotations[:, 2] 209 | r1 = torch.ones_like(rx) 210 | 211 | rotation_offsets = torch.stack([torch.stack([r1, -rz, ry], dim=-1), 212 | torch.stack([rz, r1, -rx], dim=-1), 213 | torch.stack([-ry, rx, r1], dim=-1)], dim=-1) 214 | 215 | return rotation_offsets @ utils.convert_quaternions_to_rot(quaternion_camera_to_world) 216 | 217 | class TranslationModel(nn.Module): 218 | def __init__(self, args): 219 | super().__init__() 220 | self.args = args 221 | self.translations_coarse = nn.Parameter(torch.rand(1, 3, 7, dtype=torch.float32) * 1e-5, requires_grad=True) 222 | self.translations_fine = nn.Parameter(torch.rand(1, 3, args.num_frames, dtype=torch.float32) * 1e-5, requires_grad=True) 223 | self.center = nn.Parameter(torch.zeros(1, 3, dtype=torch.float32), requires_grad=True) 224 | 225 | def forward(self, t, training_phase=1.0): 226 | 
self.translations_coarse.data[:, :, 0] = 0.0 # zero out first frame's translation 227 | self.translations_fine.data[:, :, 0] = 0.0 # zero out first frame's translation 228 | 229 | if training_phase < 0.25: 230 | translation = utils.interpolate_params(self.translations_coarse, t) 231 | else: 232 | translation = utils.interpolate_params(self.translations_coarse, t) + utils.interpolate_params(self.translations_fine, t) 233 | 234 | return self.args.focal_compensation * self.args.translation_weight * (translation + 5 * self.center) 235 | 236 | class LightSphereModel(pl.LightningModule): 237 | 238 | def __init__(self, args): 239 | super().__init__() 240 | self.args = args 241 | 242 | encoding_offset_position_config = json.load(open(f"config/config_{args.encoding_offset_position_config}.json"))["encoding"] 243 | encoding_offset_angle_config = json.load(open(f"config/config_{args.encoding_offset_angle_config}.json"))["encoding"] 244 | network_offset_config = json.load(open(f"config/config_{args.network_offset_config}.json"))["network"] 245 | 246 | encoding_color_position_config = json.load(open(f"config/config_{args.encoding_color_position_config}.json"))["encoding"] 247 | encoding_color_angle_config = json.load(open(f"config/config_{args.encoding_color_angle_config}.json"))["encoding"] 248 | network_color_position_config = json.load(open(f"config/config_{args.network_color_position_config}.json"))["network"] 249 | network_color_angle_config = json.load(open(f"config/config_{args.network_color_angle_config}.json"))["network"] 250 | 251 | self.encoding_offset_position = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_offset_position_config) 252 | self.encoding_offset_angle = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_offset_angle_config) 253 | 254 | self.network_offset = tcnn.Network(n_input_dims=self.encoding_offset_position.n_output_dims + self.encoding_offset_angle.n_output_dims, 255 | n_output_dims=3, network_config=network_offset_config) 256 | 257 | 
self.encoding_color_position = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_color_position_config) 258 | self.encoding_color_angle = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_color_angle_config) 259 | 260 | self.network_color_position = tcnn.Network(n_input_dims=self.encoding_color_position.n_output_dims, n_output_dims=64, network_config=network_color_position_config) 261 | self.network_color_angle = tcnn.Network(n_input_dims=self.encoding_color_position.n_output_dims + self.encoding_color_angle.n_output_dims, n_output_dims=64, network_config=network_color_angle_config) 262 | self.network_color = nn.Linear(64, 3, dtype=torch.float32, bias=False) # faster than tcnn.Network 263 | self.network_color.weight.data = torch.rand_like(self.network_color.weight.data) 264 | 265 | self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float16), requires_grad=False) 266 | 267 | self.enc_feat_direction = None 268 | self.enc_offset_intersection = None 269 | self.enc_offset_direction = None 270 | self.enc_image_outer = None 271 | 272 | def mask(self, encoding, training_phase, initial): 273 | if self.args.no_mask: 274 | return encoding 275 | else: 276 | return utils.mask(encoding, training_phase, initial) 277 | 278 | @torch.jit.script 279 | def solve_sphere_crossings(ray_origins, ray_directions): 280 | # Coefficients for the quadratic equation 281 | b = 2 * torch.sum(ray_origins * ray_directions, dim=1) 282 | c = torch.sum(ray_origins**2, dim=1) - 1.0 283 | 284 | discriminant = b**2 - 4 * c 285 | sqrt_discriminant = torch.sqrt(discriminant) 286 | t = (-b + sqrt_discriminant) / 2 287 | intersections = ray_origins + t.unsqueeze(-1) * ray_directions 288 | return intersections 289 | 290 | def inference(self, t, uv, ray_origins, ray_directions, training_phase=1.0): 291 | # Slightly slimmed down version of forward() for inference 292 | uv = uv.clamp(0, 1) 293 | 294 | intersections_sphere = self.solve_sphere_crossings(ray_origins, ray_directions) 
295 | 296 | if not self.args.no_offset: 297 | encoded_offset_position = self.encoding_offset_position((intersections_sphere + 1) / 2) 298 | encoded_offset_angle = self.encoding_offset_angle(uv) 299 | encoded_offset = torch.cat((encoded_offset_position, encoded_offset_angle), dim=1) 300 | 301 | offset = self.network_offset(encoded_offset).float() 302 | ray_directions_offset = ray_directions + torch.cross(offset, ray_directions, dim=1) 303 | intersections_sphere_offset = self.solve_sphere_crossings(ray_origins, ray_directions_offset) 304 | else: 305 | intersections_sphere_offset = intersections_sphere 306 | 307 | encoded_color_position = self.encoding_color_position((intersections_sphere_offset + 1) / 2) 308 | feat_color = self.network_color_position(encoded_color_position).float() 309 | 310 | if not self.args.no_view_color: 311 | encoded_color_angle = self.encoding_color_angle(uv) 312 | encoded_color = torch.cat((encoded_color_position, encoded_color_angle), dim=1) 313 | feat_color_angle = self.network_color_angle(encoded_color) 314 | feat_color = feat_color + feat_color_angle 315 | 316 | rgb = self.initial_rgb + self.network_color(feat_color) 317 | return rgb 318 | 319 | def forward(self, t, uv, ray_origins, ray_directions, training_phase): 320 | uv = uv.clamp(0,1) 321 | 322 | if training_phase < 0.2: # Apply random perturbation for training during early epochs 323 | factor = 0.015 * self.args.focal_compensation 324 | perturbation = torch.randn_like(ray_origins) * factor / 0.2 * max(0.0, 0.25 - training_phase) 325 | ray_origins = ray_origins + perturbation 326 | 327 | intersections_sphere = self.solve_sphere_crossings(ray_origins, ray_directions) 328 | 329 | if training_phase > 0.2 and not self.args.no_offset: 330 | encoded_offset_position = self.mask(self.encoding_offset_position((intersections_sphere + 1) / 2), training_phase, initial=0.2) 331 | encoded_offset_angle = self.mask(self.encoding_offset_angle(uv), training_phase, initial=0.5) 332 | encoded_offset = 
torch.cat((encoded_offset_position, encoded_offset_angle), dim=1) 333 | 334 | offset = self.network_offset(encoded_offset).float() 335 | ray_directions_offset = ray_directions + torch.cross(offset, ray_directions, dim=1) # linearized rotation 336 | intersections_sphere_offset = self.solve_sphere_crossings(ray_origins, ray_directions_offset) 337 | else: 338 | offset = torch.ones_like(ray_directions) 339 | intersections_sphere_offset = intersections_sphere 340 | 341 | encoded_color_position = self.mask(self.encoding_color_position((intersections_sphere_offset + 1) / 2), training_phase, initial=0.8) 342 | feat_color = self.network_color_position(encoded_color_position).float() 343 | 344 | if training_phase > 0.25 and not self.args.no_view_color: 345 | encoded_color_angle = self.mask(self.encoding_color_angle(uv), training_phase, initial=0.2) 346 | encoded_color = torch.cat((encoded_color_position, encoded_color_angle), dim=1) 347 | feat_color_angle = self.network_color_angle(encoded_color) 348 | feat_color = feat_color + feat_color_angle 349 | else: 350 | feat_color_angle = torch.zeros_like(feat_color) 351 | 352 | rgb = self.initial_rgb + self.network_color(feat_color) 353 | return rgb, offset, feat_color_angle 354 | 355 | 356 | class DistortionModel(pl.LightningModule): 357 | def __init__(self, args): 358 | super().__init__() 359 | self.args = args 360 | 361 | def forward(self, kappa): 362 | if self.args.no_lens_distortion: # no distortion 363 | return (kappa * 0.0).to(self.device) 364 | else: 365 | return kappa.to(self.device) 366 | 367 | 368 | ######################################################################################################### 369 | ################################################ NETWORK ################################################ 370 | ######################################################################################################### 371 | 372 | class PanoModel(pl.LightningModule): 373 | def __init__(self, args, cached_data=None): 
374 | super().__init__() 375 | # load network configs 376 | 377 | self.args = args 378 | if cached_data is None: 379 | self.data = BundleDataset(self.args) 380 | else: 381 | with open(cached_data, 'rb') as file: 382 | self.data = pickle.load(file) 383 | 384 | if args.frames is None: 385 | self.args.frames = list(range(self.data.num_frames)) 386 | self.args.num_frames = self.data.num_frames 387 | 388 | 389 | # to account for varying focal lengths, scale camera/ray motion to match 82deg Pixel 8 Pro main lens 390 | self.args.focal_compensation = 1.9236 / (self.data.intrinsics[0,0,0].item()/self.data.intrinsics[0,0,2].item()) 391 | 392 | self.model_translation = TranslationModel(self.args) 393 | self.model_rotation = RotationModel(self.args) 394 | self.model_distortion = DistortionModel(self.args) 395 | self.model_light_sphere = LightSphereModel(self.args) 396 | 397 | self.training_phase = 1.0 398 | self.save_hyperparameters() 399 | 400 | def load_volume(self): 401 | self.data.load_volume() 402 | 403 | def configure_optimizers(self): 404 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, betas=[0.9,0.99], eps=1e-9, weight_decay=self.args.weight_decay) 405 | gamma = 1.0 406 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 407 | 408 | return [optimizer], [scheduler] 409 | 410 | def inference(self, *args, **kwargs): 411 | with torch.no_grad(): 412 | return self.model_light_sphere.inference(*args, **kwargs) 413 | 414 | def forward(self, *args, **kwargs): 415 | return self.model_light_sphere(*args, **kwargs) 416 | 417 | def generate_ray_directions(self, uv, camera_to_world, intrinsics_inv): 418 | u, v = uv[:, 0:1] * self.data.img_width, uv[:, 1:2] * self.data.img_height 419 | uv1 = torch.cat([u, v, torch.ones_like(u)], dim=1) # N x 3 420 | # Transform pixel coordinates to camera coordinates 421 | xy1 = torch.bmm(intrinsics_inv, uv1.unsqueeze(2)).squeeze(2) # N x 3 422 | xy = xy1[:, 0:2] 423 | 424 | 425 | x, y = xy[:, 0:1], xy[:, 
1:2] 426 | r2 = x**2 + y**2 427 | r4 = r2**2 428 | r6 = r4 * r2 429 | 430 | kappa1, kappa2, kappa3, kappa4, kappa5 = self.model_distortion(self.data.lens_distortion) 431 | 432 | # Apply radial distortion 433 | x_distorted = x * (1 + kappa1 * r2 + kappa2 * r4 + kappa3 * r6) + \ 434 | 2 * kappa4 * x * y + kappa5 * (r2 + 2 * x**2) 435 | y_distorted = y * (1 + kappa1 * r2 + kappa2 * r4 + kappa3 * r6) + \ 436 | 2 * kappa5 * x * y + kappa4 * (r2 + 2 * y**2) 437 | 438 | xy_distorted = torch.cat([x_distorted, y_distorted], dim=1) 439 | 440 | # Combine with z = 1 for direction calculation 441 | ray_directions_unrotated = torch.cat([xy_distorted, torch.ones_like(x)], dim=1) # N x 3 442 | 443 | ray_directions = torch.bmm(camera_to_world, ray_directions_unrotated.unsqueeze(2)).squeeze(2) # Apply camera rotation 444 | 445 | # Normalize ray directions 446 | ray_directions = ray_directions / ray_directions.norm(dim=1, keepdim=True) 447 | 448 | return ray_directions 449 | 450 | def training_step(self, train_batch, batch_idx): 451 | t, uv, quaternion_camera_to_world, intrinsics, intrinsics_inv, rgb_reference = debatch(train_batch) # collapse batch + point dimensions 452 | 453 | camera_to_world = self.model_rotation(quaternion_camera_to_world, t) # apply rotation offset 454 | ray_origins = self.model_translation(t, self.training_phase) # camera center in world coordinates 455 | ray_origins.clamp(-0.99,0.99) # bound within the sphere 456 | ray_directions = self.generate_ray_directions(uv, camera_to_world, intrinsics_inv) 457 | 458 | rgb, offset, feat_color_angle = self.forward(t, uv, ray_origins, ray_directions, self.training_phase) 459 | 460 | loss = 0.0 461 | 462 | rgb_loss = F.l1_loss(rgb, rgb_reference) 463 | self.log('loss/rgb', rgb_loss.mean()) 464 | loss += rgb_loss.mean() 465 | 466 | 467 | factor = 1.2 # Adjust this parameter to control the bend of the curve (1 for linear, 2 for quadratic, etc.) 
468 | normalized_epoch = (self.current_epoch + (batch_idx/self.args.num_batches)) / (self.args.max_epochs - 1) 469 | 470 | # Update the training_phase calculation with the factor 471 | self.training_phase = min(1.0, 0.05 + (normalized_epoch ** factor)) 472 | self.data.training_phase = self.training_phase 473 | 474 | return loss 475 | 476 | def color_and_tone(self, rgb_samples, height, width): 477 | """ Apply CCM and tone mapping to raw samples 478 | """ 479 | 480 | img = self.color(rgb_samples, height, width) 481 | img = utils.apply_tonemap(img, self.data.tonemap_curve.to(rgb_samples.device)) 482 | 483 | return img.clamp(0,1) 484 | 485 | def color(self, rgb_samples, height, width): 486 | """ Apply CCM to raw samples 487 | """ 488 | 489 | img = self.data.ccm.to(rgb_samples.device) @ (self.data.mean + rgb_samples.float()/16.0).T 490 | img = img.reshape(3, height, width) 491 | 492 | return img.clamp(0,1) 493 | 494 | @torch.no_grad() 495 | def chunk_forward(self, quaternion_camera_to_world, intrinsics_inv, t, uv, translation, chunk_size=1000000): 496 | """ Forward model with chunking to avoid OOM 497 | """ 498 | total_elements = t.shape[0] 499 | 500 | for start_idx in range(0, total_elements, chunk_size): 501 | end_idx = min(start_idx + chunk_size, total_elements) 502 | 503 | camera_to_world_chunk = self.model_rotation(quaternion_camera_to_world[start_idx:end_idx], t[start_idx:end_idx]) 504 | ray_directions_chunk = self.generate_ray_directions(uv[start_idx:end_idx], camera_to_world_chunk, intrinsics_inv[start_idx:end_idx]) 505 | 506 | if translation is None: 507 | ray_origins_chunk = self.model_translation(t[start_idx:end_idx]) 508 | else: 509 | ray_origins_chunk = torch.zeros_like(ray_directions_chunk) + translation 510 | 511 | chunk_outputs = self.forward(t[start_idx:end_idx], uv[start_idx:end_idx], ray_origins_chunk, ray_directions_chunk, self.training_phase) 512 | 513 | if start_idx == 0: 514 | num_outputs = len(chunk_outputs) 515 | final_outputs = [[] for _ in 
range(num_outputs)] 516 | 517 | for i, output in enumerate(chunk_outputs): 518 | final_outputs[i].append(output.cpu()) 519 | 520 | final_outputs = tuple(torch.cat(outputs, dim=0) for outputs in final_outputs) 521 | 522 | return final_outputs 523 | 524 | @torch.no_grad() 525 | def generate_outputs(self, time, height=720, width=720, u_lims=[0,1], v_lims=[0,1], fov_scale=1.0, quaternion_camera_to_world=None, intrinsics_inv=None, translation=None, sensor_size=None): 526 | 527 | device = self.device 528 | 529 | uv = utils.make_grid(height, width, u_lims, v_lims) 530 | frame = int(time * (self.data.num_frames - 1)) 531 | t = torch.tensor(time, dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1 532 | 533 | rgb_reference = self.data.sample_frame(frame, uv) # reference rgb samples 534 | 535 | if intrinsics_inv is None : 536 | intrinsics_inv = self.data.intrinsics_inv[frame:frame+2] # 2 x 3 x 3 537 | if time <= 0 or time >= 1.0: # select exact frame timestamp 538 | intrinsics_inv = intrinsics_inv[0:1] 539 | else: # interpolate between frames 540 | fraction = time * (self.data.num_frames - 1) - frame 541 | intrinsics_inv = intrinsics_inv[0:1] * (1 - fraction) + intrinsics_inv[1:2] * fraction 542 | 543 | if quaternion_camera_to_world is None: 544 | quaternion_camera_to_world = self.data.quaternion_camera_to_world[frame:frame+2] # 2 x 3 x 3 545 | if time <= 0 or time >= 1.0: 546 | quaternion_camera_to_world = quaternion_camera_to_world[0:1] 547 | else: # interpolate between frames 548 | fraction = time * (self.data.num_frames - 1) - frame 549 | quaternion_camera_to_world = quaternion_camera_to_world[0:1] * (1 - fraction) + quaternion_camera_to_world[1:2] * fraction 550 | 551 | intrinsics_inv = intrinsics_inv.clone() 552 | intrinsics_inv[:,0] = intrinsics_inv[:,0] * fov_scale 553 | intrinsics_inv[:,1] = intrinsics_inv[:,1] 554 | 555 | intrinsics_inv = intrinsics_inv.repeat(uv.shape[0],1,1) # num_points x 3 x 3 556 | quaternion_camera_to_world = 
quaternion_camera_to_world.repeat(uv.shape[0],1) # num_points x 4 557 | 558 | quaternion_camera_to_world, intrinsics_inv, t, uv = quaternion_camera_to_world.to(device), intrinsics_inv.to(device), t.to(device), uv.to(device) 559 | 560 | rgb, offset, feat_color_angle = self.chunk_forward(quaternion_camera_to_world, intrinsics_inv, t, uv, translation, chunk_size=3000**2) # break into chunks to avoid OOM 561 | 562 | rgb_reference = self.color_and_tone(rgb_reference, height, width) 563 | rgb = self.color_and_tone(rgb, height, width) 564 | 565 | offset = offset.reshape(height, width, 3).float().permute(2,0,1) 566 | 567 | # Normalize the offset tensor along the axis 568 | offset = offset 569 | offset_img = utils.colorize_tensor(offset.mean(dim=0), vmin=offset.min(), vmax=offset.max(), cmap="RdYlBu") 570 | 571 | return rgb_reference, rgb, offset, offset_img 572 | 573 | 574 | def save_outputs(self, path, high_res=False): 575 | os.makedirs(f"outputs/{self.args.name + path}", exist_ok=True) 576 | if high_res: 577 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=0, height=2560, width=1920) 578 | else: 579 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=0, height=1080, width=810) 580 | 581 | 582 | ######################################################################################################### 583 | ############################################### VALIDATION ############################################## 584 | ######################################################################################################### 585 | 586 | class ValidationCallback(pl.Callback): 587 | def __init__(self): 588 | super().__init__() 589 | self.unlock = False 590 | 591 | def bright(self, rgb): 592 | return ((rgb / np.percentile(rgb, 95)) ** 0.7).clamp(0,1) 593 | 594 | def on_train_epoch_start(self, trainer, model): 595 | print(f"Training phase (0-1): {model.training_phase}") 596 | 597 | if model.current_epoch == 1: 598 | 
model.model_translation.translations_coarse.requires_grad_(True) 599 | model.model_translation.translations_fine.requires_grad_(True) 600 | model.model_rotation.rotations.requires_grad_(True) 601 | 602 | if model.args.fast: # skip tensorboarding except for beginning and end 603 | if model.current_epoch == model.args.max_epochs - 1 or model.current_epoch == 0: 604 | pass 605 | else: 606 | return 607 | 608 | for i, time in enumerate([0.2, 0.5, 0.8]): 609 | if model.args.hd: 610 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, height=1080, width=1080, fov_scale=1.4) 611 | else: 612 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, fov_scale=2.5) 613 | 614 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined', rgb, global_step=trainer.global_step) 615 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined_bright', self.bright(rgb), global_step=trainer.global_step) 616 | model.logger.experiment.add_image(f'pred/{i}_offset', offset_img, global_step=trainer.global_step) 617 | 618 | 619 | def on_train_start(self, trainer, model): 620 | pl.seed_everything(42) # the answer to life, the universe, and everything 621 | 622 | # initialize rgb as average color of first frame of data (minimize the amount the rgb models have to learn) 623 | model.model_light_sphere.initial_rgb.data = torch.mean(model.data.rgb_volume[0], dim=(1,2))[None,:].to(model.device).to(torch.float16) 624 | 625 | model.logger.experiment.add_text("args", str(model.args)) 626 | 627 | for i, time in enumerate([0, 0.5, 1.0]): 628 | rgb_reference, rgb, offset, offset_img = model.generate_outputs(time=time, height=1080, width=810) 629 | model.logger.experiment.add_image(f'gt/{i}_rgb_reference', rgb_reference, global_step=trainer.global_step) 630 | model.logger.experiment.add_image(f'gt/{i}_rgb_reference_bright', self.bright(rgb_reference), global_step=trainer.global_step) 631 | 632 | model.training_phase = 0.05 633 | model.data.training_phase = 
model.training_phase 634 | 635 | 636 | def on_train_end(self, trainer, model): 637 | checkpoint_dir = os.path.join("checkpoints", model.args.name, "last.ckpt") 638 | data_dir = os.path.join("checkpoints", model.args.name, "data.pkl") 639 | 640 | os.makedirs(os.path.dirname(checkpoint_dir), exist_ok=True) 641 | checkpoint = trainer._checkpoint_connector.dump_checkpoint() 642 | 643 | # Forcibly remove optimizer states from the checkpoint 644 | if 'optimizer_states' in checkpoint: 645 | del checkpoint['optimizer_states'] 646 | 647 | torch.save(checkpoint, checkpoint_dir) 648 | 649 | with open(data_dir, 'wb') as file: 650 | model.data.rgb_volume = None 651 | pickle.dump(model.data, file) 652 | 653 | if __name__ == "__main__": 654 | 655 | # argparse 656 | parser = argparse.ArgumentParser() 657 | 658 | # data 659 | parser.add_argument('--point_batch_size', type=int, default=2**18, help="Number of points to sample per dataloader index.") 660 | parser.add_argument('--num_batches', type=int, default=200, help="Number of training batches.") 661 | parser.add_argument('--max_percentile', type=float, default=99.5, help="Percentile of lightest pixels to cut.") 662 | parser.add_argument('--frames', type=str, help="Which subset of frames to use for training, e.g. 
0,10,20,30,40") 663 | 664 | # model 665 | parser.add_argument('--rotation_weight', type=float, default=1e-2, help="Scale learned rotation.") 666 | parser.add_argument('--translation_weight', type=float, default=1e0, help="Scale learned translation.") 667 | parser.add_argument('--rolling_shutter', action='store_true', help="Use rolling shutter compensation.") 668 | parser.add_argument('--no_mask', action='store_true', help="Do not use mask.") 669 | parser.add_argument('--no_offset', action='store_true', help="Do not use ray offset model.") 670 | parser.add_argument('--no_view_color', action='store_true', help="Do not use view dependent color model.") 671 | parser.add_argument('--no_lens_distortion', action='store_true', help="Do not use lens distortion model.") 672 | 673 | 674 | 675 | # light sphere 676 | parser.add_argument('--encoding_offset_position_config', type=str, default="small", help="Encoding offset position configuration (tiny, small, medium, large, ultrakill).") 677 | parser.add_argument('--encoding_offset_angle_config', type=str, default="small", help="Encoding offset angle configuration (tiny, small, medium, large, ultrakill).") 678 | parser.add_argument('--network_offset_config', type=str, default="large", help="Network offset configuration (tiny, small, medium, large, ultrakill).") 679 | 680 | parser.add_argument('--encoding_color_position_config', type=str, default="large", help="Encoding color position configuration (tiny, small, medium, large, ultrakill).") 681 | parser.add_argument('--encoding_color_angle_config', type=str, default="small", help="Encoding color angle configuration (tiny, small, medium, large, ultrakill).") 682 | parser.add_argument('--network_color_position_config', type=str, default="large", help="Network color position configuration (tiny, small, medium, large, ultrakill).") 683 | parser.add_argument('--network_color_angle_config', type=str, default="large", help="Network color angle configuration (tiny, small, medium, large, 
ultrakill).") 684 | 685 | # training 686 | parser.add_argument('--data_path', '--d', type=str, required=True, help="Path to frame_bundle.npz") 687 | parser.add_argument('--name', type=str, required=True, help="Experiment name for logs and checkpoints.") 688 | parser.add_argument('--max_epochs', type=int, default=100, help="Number of training epochs.") 689 | parser.add_argument('--lr', type=float, default=5e-4, help="Learning rate.") 690 | parser.add_argument('--weight_decay', type=float, default=1e-5, help="Weight decay.") 691 | parser.add_argument('--save_video', action='store_true', help="Store training outputs at each epoch for visualization.") 692 | parser.add_argument('--num_workers', type=int, default=4, help="Number of dataloader workers.") 693 | parser.add_argument('--debug', action='store_true', help="Debug mode, only use 1 batch.") 694 | parser.add_argument('--fast', action='store_true', help="Fast mode.") 695 | parser.add_argument('--cache', action='store_true', help="Cache data.") 696 | parser.add_argument('--hd', action='store_true', help="Make tensorboard HD.") 697 | 698 | 699 | args = parser.parse_args() 700 | # parse plane args 701 | print(args) 702 | if args.frames is not None: 703 | args.frames = [int(x) for x in re.findall(r"[-+]?\d*\.\d+|\d+", args.frames)] 704 | 705 | # model 706 | model = PanoModel(args) 707 | model.load_volume() 708 | 709 | # dataset 710 | data = model.data 711 | train_loader = DataLoader(data, batch_size=1, num_workers=args.num_workers, shuffle=False, pin_memory=True, prefetch_factor=1) 712 | 713 | model.model_translation.translations_coarse.requires_grad_(False) 714 | model.model_translation.translations_fine.requires_grad_(False) 715 | model.model_rotation.rotations.requires_grad_(False) 716 | 717 | torch.set_float32_matmul_precision('medium') 718 | 719 | # training 720 | lr_callback = pl.callbacks.LearningRateMonitor() 721 | logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), version=args.name, 
name="lightning_logs") 722 | validation_callback = ValidationCallback() 723 | trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), num_nodes=1, strategy="auto", max_epochs=args.max_epochs, 724 | logger=logger, callbacks=[validation_callback, lr_callback], enable_checkpointing=False, fast_dev_run=args.debug) 725 | trainer.fit(model, train_loader) 726 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | plt.switch_backend('Agg') # non-visual 7 | 8 | def interp(x, xp, fp): 9 | """ 10 | Linear interpolation of values fp from known points xp to new points x. 11 | """ 12 | indices = torch.searchsorted(xp, x).clamp(1, len(xp) - 1) 13 | x0, x1 = xp[indices - 1], xp[indices] 14 | y0, y1 = fp[indices - 1], fp[indices] 15 | 16 | return (y0 + (y1 - y0) * (x - x0) / (x1 - x0)).to(fp.dtype) 17 | 18 | def multi_interp(x, xp, fp): 19 | """ 20 | Multi-dimensional linear interpolation of fp from xp to x along all axes. 21 | """ 22 | if torch.is_tensor(fp): 23 | out = [interp(x, xp, fp[:, i]) for i in range(fp.shape[-1])] 24 | return torch.stack(out, dim=-1).to(fp.dtype) 25 | else: 26 | out = [np.interp(x, xp, fp[:, i]) for i in range(fp.shape[-1])] 27 | return np.stack(out, axis=-1).astype(fp.dtype) 28 | 29 | def interpolate_params(params, t): 30 | """ 31 | Interpolate parameters over time t, linearly between frames. 
32 | """ 33 | num_frames = params.shape[-1] 34 | frame_number = t * (num_frames - 1) 35 | integer_frame = torch.floor(frame_number).long() 36 | fractional_frame = frame_number.to(params.dtype) - integer_frame.to(params.dtype) 37 | 38 | # Ensure indices are within valid range 39 | next_frame = torch.clamp(integer_frame + 1, 0, num_frames - 1) 40 | integer_frame = torch.clamp(integer_frame, 0, num_frames - 1) 41 | 42 | param_now = params[:, :, integer_frame] 43 | param_next = params[:, :, next_frame] 44 | 45 | # Linear interpolation between current and next frame parameters 46 | interpolated_params = param_now + fractional_frame * (param_next - param_now) 47 | 48 | return interpolated_params.squeeze(0).squeeze(-1).permute(1, 0) 49 | 50 | class MaskFunction(torch.autograd.Function): 51 | @staticmethod 52 | def forward(ctx, encoding, mask_coef, initial=0.4): 53 | """ 54 | Forward pass for masking, scales mask_coef to blend encoding. 55 | """ 56 | mask_coef = initial + (1 - initial) * mask_coef 57 | mask = torch.zeros_like(encoding[0:1]) 58 | mask_ceil = int(np.ceil(mask_coef * encoding.shape[1])) 59 | mask[..., :mask_ceil] = 1.0 60 | ctx.save_for_backward(mask) 61 | return encoding * mask 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | """ 66 | Backward pass to retain masked gradients. 67 | """ 68 | mask, = ctx.saved_tensors 69 | return grad_output * mask, None, None 70 | 71 | def mask(encoding, mask_coef, initial=0.4): 72 | """ 73 | Apply mask to encoding with scaling factor mask_coef. 74 | """ 75 | return MaskFunction.apply(encoding, mask_coef, initial) 76 | 77 | def make_grid(height, width, u_lims, v_lims): 78 | """ 79 | Create (u,v) meshgrid of size (height, width) with given limits. 
80 | """ 81 | u = torch.linspace(u_lims[0], u_lims[1], width) 82 | v = torch.linspace(v_lims[1], v_lims[0], height) # Flip for array convention 83 | u_grid, v_grid = torch.meshgrid([u, v], indexing="xy") 84 | return torch.stack((u_grid.flatten(), v_grid.flatten())).permute(1, 0) 85 | 86 | def unwrap_quaternions(q): 87 | """ 88 | Remove 2pi wraps from quaternion rotations. 89 | """ 90 | n = q.shape[0] 91 | unwrapped_q = q.clone() 92 | for i in range(1, n): 93 | cos_theta = torch.dot(unwrapped_q[i-1], unwrapped_q[i]) 94 | if cos_theta < 0: 95 | unwrapped_q[i] = -unwrapped_q[i] 96 | return unwrapped_q 97 | 98 | @torch.jit.script 99 | def quaternion_conjugate(q): 100 | """ 101 | Return the conjugate of a quaternion. 102 | """ 103 | q_conj = q.clone() 104 | q_conj[:, 1:] = -q_conj[:, 1:] # Invert vector part 105 | return q_conj 106 | 107 | def quaternion_multiply(q1, q2): 108 | """ 109 | Multiply two quaternions. 110 | """ 111 | w1, v1 = q1[..., 0], q1[..., 1:] 112 | w2, v2 = q2[..., 0], q2[..., 1:] 113 | w = w1 * w2 - torch.sum(v1 * v2, dim=-1) 114 | v = w1.unsqueeze(-1) * v2 + w2.unsqueeze(-1) * v1 + torch.cross(v1, v2, dim=-1) 115 | return torch.cat((w.unsqueeze(-1), v), dim=-1) 116 | 117 | @torch.jit.script 118 | def convert_quaternions_to_rot(quaternions): 119 | """ 120 | Convert quaternions (WXYZ) to 3x3 rotation matrices. 121 | """ 122 | w, x, y, z = quaternions.unbind(-1) 123 | r00 = 1 - 2 * (y**2 + z**2) 124 | r01 = 2 * (x * y - z * w) 125 | r02 = 2 * (x * z + y * w) 126 | r10 = 2 * (x * y + z * w) 127 | r11 = 1 - 2 * (x**2 + z**2) 128 | r12 = 2 * (y * z - x * w) 129 | r20 = 2 * (x * z - y * w) 130 | r21 = 2 * (y * z + x * w) 131 | r22 = 1 - 2 * (x**2 + y**2) 132 | R = torch.stack([r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=-1) 133 | return R.reshape(-1, 3, 3) 134 | 135 | @torch.no_grad() 136 | def raw_to_rgb(bundle): 137 | """ 138 | Convert RAW mosaic to three-channel RGB volume by filling empty pixels. 
139 | """ 140 | raw_frames = torch.tensor(np.array([bundle[f'raw_{i}']['raw'] for i in range(bundle['num_raw_frames'])]).astype(np.int32), dtype=torch.float32)[None] 141 | raw_frames = raw_frames.permute(1, 0, 2, 3) 142 | color_correction_gains = bundle['raw_0']['color_correction_gains'] 143 | color_filter_arrangement = bundle['characteristics']['color_filter_arrangement'] 144 | blacklevel = torch.tensor(np.array([bundle[f'raw_{i}']['blacklevel'] for i in range(bundle['num_raw_frames'])]))[:, :, None, None] 145 | whitelevel = torch.tensor(np.array([bundle[f'raw_{i}']['whitelevel'] for i in range(bundle['num_raw_frames'])]))[:, None, None, None] 146 | shade_maps = torch.tensor(np.array([bundle[f'raw_{i}']['shade_map'] for i in range(bundle['num_raw_frames'])])).permute(0, 3, 1, 2) 147 | shade_maps = F.interpolate(shade_maps, size=(raw_frames.shape[-2]//2, raw_frames.shape[-1]//2), mode='bilinear', align_corners=False) 148 | 149 | top_left = raw_frames[:, :, 0::2, 0::2] 150 | top_right = raw_frames[:, :, 0::2, 1::2] 151 | bottom_left = raw_frames[:, :, 1::2, 0::2] 152 | bottom_right = raw_frames[:, :, 1::2, 1::2] 153 | 154 | if color_filter_arrangement == 0: # RGGB 155 | R, G1, G2, B = top_left, top_right, bottom_left, bottom_right 156 | elif color_filter_arrangement == 1: # GRBG 157 | G1, R, B, G2 = top_left, top_right, bottom_left, bottom_right 158 | elif color_filter_arrangement == 2: # GBRG 159 | G1, B, R, G2 = top_left, top_right, bottom_left, bottom_right 160 | elif color_filter_arrangement == 3: # BGGR 161 | B, G1, G2, R = top_left, top_right, bottom_left, bottom_right 162 | 163 | # Apply color correction gains, flip to portrait 164 | R = ((R - blacklevel[:, 0:1]) / (whitelevel - blacklevel[:, 0:1]) * color_correction_gains[0]) 165 | R *= shade_maps[:, 0:1] 166 | G1 = ((G1 - blacklevel[:, 1:2]) / (whitelevel - blacklevel[:, 1:2]) * color_correction_gains[1]) 167 | G1 *= shade_maps[:, 1:2] 168 | G2 = ((G2 - blacklevel[:, 2:3]) / (whitelevel - blacklevel[:, 
2:3]) * color_correction_gains[2]) 169 | G2 *= shade_maps[:, 2:3] 170 | B = ((B - blacklevel[:, 3:4]) / (whitelevel - blacklevel[:, 3:4]) * color_correction_gains[3]) 171 | B *= shade_maps[:, 3:4] 172 | 173 | rgb_volume = torch.zeros(raw_frames.shape[0], 3, raw_frames.shape[-2], raw_frames.shape[-1], dtype=torch.float32) 174 | 175 | # Fill gaps in blue channel 176 | rgb_volume[:, 2, 0::2, 0::2] = B.squeeze(1) 177 | rgb_volume[:, 2, 0::2, 1::2] = (B + torch.roll(B, -1, dims=3)).squeeze(1) / 2 178 | rgb_volume[:, 2, 1::2, 0::2] = (B + torch.roll(B, -1, dims=2)).squeeze(1) / 2 179 | rgb_volume[:, 2, 1::2, 1::2] = (B + torch.roll(B, -1, dims=2) + torch.roll(B, -1, dims=3) + torch.roll(B, [-1, -1], dims=[2, 3])).squeeze(1) / 4 180 | 181 | # Fill gaps in green channel 182 | rgb_volume[:, 1, 0::2, 0::2] = G1.squeeze(1) 183 | rgb_volume[:, 1, 0::2, 1::2] = (G1 + torch.roll(G1, -1, dims=3) + G2 + torch.roll(G2, 1, dims=2)).squeeze(1) / 4 184 | rgb_volume[:, 1, 1::2, 0::2] = (G1 + torch.roll(G1, -1, dims=2) + G2 + torch.roll(G2, 1, dims=3)).squeeze(1) / 4 185 | rgb_volume[:, 1, 1::2, 1::2] = G2.squeeze(1) 186 | 187 | # Fill gaps in red channel 188 | rgb_volume[:, 0, 0::2, 0::2] = R.squeeze(1) 189 | rgb_volume[:, 0, 0::2, 1::2] = (R + torch.roll(R, -1, dims=3)).squeeze(1) / 2 190 | rgb_volume[:, 0, 1::2, 0::2] = (R + torch.roll(R, -1, dims=2)).squeeze(1) / 2 191 | rgb_volume[:, 0, 1::2, 1::2] = (R + torch.roll(R, -1, dims=2) + torch.roll(R, -1, dims=3) + torch.roll(R, [-1, -1], dims=[2, 3])).squeeze(1) / 4 192 | 193 | rgb_volume = torch.flip(rgb_volume.transpose(-1, -2), [-1]) # Rotate 90 degrees to portrait 194 | return rgb_volume 195 | 196 | def de_item(bundle): 197 | """ 198 | Call .item() on all dictionary items, removing extra dimensions. 
199 | """ 200 | bundle['motion'] = bundle['motion'].item() 201 | bundle['characteristics'] = bundle['characteristics'].item() 202 | 203 | for i in range(bundle['num_raw_frames']): 204 | bundle[f'raw_{i}'] = bundle[f'raw_{i}'].item() 205 | 206 | def debatch(batch): 207 | """ 208 | Collapse batch and channel dimensions together. 209 | """ 210 | debatched = [] 211 | for x in batch: 212 | if len(x.shape) <= 1: 213 | raise Exception("This tensor is too small to debatch.") 214 | elif len(x.shape) == 2: 215 | debatched.append(x.reshape(x.shape[0] * x.shape[1])) 216 | else: 217 | debatched.append(x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])) 218 | return debatched 219 | 220 | def apply_ccm(image, ccm): 221 | """ 222 | Apply Color Correction Matrix (CCM) to the image. 223 | """ 224 | if image.dim() == 3: 225 | corrected_image = torch.einsum('ij,jkl->ikl', ccm, image) 226 | else: 227 | corrected_image = ccm @ image 228 | return corrected_image.clamp(0, 1) 229 | 230 | def apply_tonemap(image, tonemap): 231 | """ 232 | Apply tonemapping curve to the image using custom linear interpolation. 233 | """ 234 | toned_image = torch.empty_like(image) 235 | for i in range(3): 236 | x_vals = tonemap[i][:, 0].contiguous() 237 | y_vals = tonemap[i][:, 1].contiguous() 238 | toned_image[i] = interp(image[i], x_vals, y_vals) 239 | return toned_image 240 | 241 | def colorize_tensor(value, vmin=None, vmax=None, cmap=None, colorbar=False, height=9.6, width=7.2): 242 | """ 243 | Convert tensor to 3-channel RGB array using colormap (similar to plt.imshow). 
244 | """ 245 | assert len(value.shape) == 2 246 | fig, ax = plt.subplots(1, 1) 247 | fig.set_size_inches(width, height) 248 | a = ax.imshow(value.detach().cpu(), vmin=vmin, vmax=vmax, cmap=cmap) 249 | ax.set_axis_off() 250 | if colorbar: 251 | cbar = plt.colorbar(a, fraction=0.05) 252 | cbar.ax.tick_params(labelsize=30) 253 | plt.tight_layout() 254 | 255 | # Convert figure to numpy array 256 | fig.canvas.draw() 257 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 258 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 259 | img = img / 255.0 260 | 261 | plt.close(fig) 262 | return torch.tensor(img).permute(2, 0, 1).float() 263 | --------------------------------------------------------------------------------