├── .figs ├── extras.png ├── occ-main.png ├── occ-supp.png ├── occ-wild.png ├── pani-badge.svg ├── ref-main.png ├── ref-supp.png └── ref-wild.png ├── LICENSE ├── README.md ├── checkpoints └── __init__.py ├── config ├── config_large.json ├── config_medium.json ├── config_small.json └── config_tiny.json ├── data └── __init__.py ├── lightning_logs └── __init__.py ├── outputs └── __init__.py ├── requirements.txt ├── scripts ├── dehaze.sh ├── fusion.sh ├── occlusion-wild.sh ├── occlusion.sh ├── reflection-wild.sh ├── reflection.sh ├── segmentation.sh └── shadow.sh ├── train.py ├── tutorial.ipynb └── utils └── utils.py /.figs/extras.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/extras.png -------------------------------------------------------------------------------- /.figs/occ-main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-main.png -------------------------------------------------------------------------------- /.figs/occ-supp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-supp.png -------------------------------------------------------------------------------- /.figs/occ-wild.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-wild.png -------------------------------------------------------------------------------- /.figs/pani-badge.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 22 | 23 | 
24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | Android App 45 | 46 | Android App 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /.figs/ref-main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-main.png -------------------------------------------------------------------------------- /.figs/ref-supp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-supp.png -------------------------------------------------------------------------------- /.figs/ref-wild.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-wild.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ilya Chugunov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Neural Spline Fields for Burst Image Fusion and Layer Separation 3 | 4 | Open In Colab 5 | 6 | 7 | Android Capture App 8 | 9 | 10 | This is the official code repository for the CVPR 2024 work: [Neural Spline Fields for Burst Image Fusion and Layer Separation](https://light.princeton.edu/publication/nsf/). If you use parts of this work, or otherwise take inspiration from it, please consider citing our paper: 11 | ``` 12 | @inproceedings{chugunov2024neural, 13 | title={Neural spline fields for burst image fusion and layer separation}, 14 | author={Chugunov, Ilya and Shustin, David and Yan, Ruyu and Lei, Chenyang and Heide, Felix}, 15 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 16 | pages={25763--25773}, 17 | year={2024} 18 | } 19 | ``` 20 | 21 | ## Requirements: 22 | - Code was written in PyTorch 2.0 on an Ubuntu 22.04 machine. 23 | - Condensed package requirements are in `requirements.txt`. Note that these are the package versions used at the time of publishing, listed as minimum (`>=`) versions. Code will most likely work with newer versions of the libraries, but you will need to watch out for changes in class/function calls. 24 | - The non-standard packages you may need are `pytorch_lightning`, `commentjson`, `rawpy`, and `tinycudann`.
See [NVlabs/tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) for installation instructions. Depending on your system you might just be able to do `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`, or you might have to build it from source with CMake. 25 | 26 | ## Project Structure: 27 | ```cpp 28 | NSF 29 | ├── checkpoints 30 | │ └── // folder for network checkpoints 31 | ├── config 32 | │ └── // network and encoding configurations for different sizes of MLPs 33 | ├── data 34 | │ └── // folder for long-burst data 35 | ├── lightning_logs 36 | │ └── // folder for tensorboard logs 37 | ├── outputs 38 | │ └── // folder for model outputs (e.g., final reconstructions) 39 | ├── scripts 40 | │ └── // training scripts for different tasks (e.g., occlusion/reflection/shadow separation) 41 | ├── utils 42 | │ └── utils.py // network helper functions (e.g., RAW demosaicing, spline interpolation) 43 | ├── LICENSE // legal stuff 44 | ├── README.md // <- you are here 45 | ├── requirements.txt // package requirements (minimum versions) 46 | ├── train.py // dataloader, network, visualization, and trainer code 47 | └── tutorial.ipynb // interactive tutorial for training the model 48 | ``` 49 | ## Getting Started: 50 | We highly recommend you start by going through `tutorial.ipynb`, either on your own machine or [with this Google Colab link](https://colab.research.google.com/github/princeton-computational-imaging/NSF/blob/main/tutorial.ipynb). 51 | 52 | TLDR: models can be trained with: 53 | 54 | `bash scripts/{application}.sh --bundle_path {path_to_data} --name {checkpoint_name}` 55 | 56 | And reconstruction outputs will get saved to `outputs/{checkpoint_name}-final` 57 | 58 | For a full list of training arguments, we recommend looking through the argument parser section at the bottom of `train.py`. 59 | 60 | ## Data: 61 | You can download the long-burst data used in the paper (and extra bonus scenes) via the following links: 62 | 63 | 1.
Main occlusion scenes: [occlusion-main.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-main.zip) (use `scripts/occlusion.sh` to train) 64 | ![Main Occlusion](.figs/occ-main.png) 65 | 66 | 2. Supplementary occlusion scenes: [occlusion-supp.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-supp.zip) (use `scripts/occlusion.sh` to train) 67 | ![Supplementary Occlusion](.figs/occ-supp.png) 68 | 69 | 3. In-the-wild occlusion scenes: [occlusion-wild.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-wild.zip) (use `scripts/occlusion-wild.sh` to train) 70 | ![Wild Occlusion](.figs/occ-wild.png) 71 | 72 | 4. Main reflection scenes: [reflection-main.zip](https://soap.cs.princeton.edu/nsf/data/reflection-main.zip) (use `scripts/reflection.sh` to train) 73 | ![Main Reflection](.figs/ref-main.png) 74 | 75 | 5. Supplementary reflection scenes: [reflection-supp.zip](https://soap.cs.princeton.edu/nsf/data/reflection-supp.zip) (use `scripts/reflection.sh` to train) 76 | ![Supplementary Reflection](.figs/ref-supp.png) 77 | 78 | 6. In-the-wild reflection scenes: [reflection-wild.zip](https://soap.cs.princeton.edu/nsf/data/reflection-wild.zip) (use `scripts/reflection-wild.sh` to train) 79 | ![Wild Reflection](.figs/ref-wild.png) 80 | 81 | 7. Extra scenes: [extras.zip](https://soap.cs.princeton.edu/nsf/data/extras.zip) (use `scripts/dehaze.sh`, `scripts/segmentation.sh`, or `scripts/shadow.sh`) 82 | ![Extras](.figs/extras.png) 83 | 84 | 8. Synthetic validation: [synthetic-validation.zip](https://soap.cs.princeton.edu/nsf/data/synthetic-validation.zip) (use `scripts/reflection.sh` or `scripts/occlusion.sh` with flag `--rgb`) 85 | 86 | We recommend you download and extract these into the `data/` folder. 87 | 88 | ## App: 89 | Want to record your own long-burst data?
Check out our Android RAW capture app [Pani!](https://github.com/Ilya-Muromets/Pani) 90 | 91 | Good luck have fun, 92 | Ilya 93 | -------------------------------------------------------------------------------- /checkpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/checkpoints/__init__.py -------------------------------------------------------------------------------- /config/config_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "HashGrid", 4 | "n_levels": 16, 5 | "n_features_per_level": 4, 6 | "log2_hashmap_size": 17, 7 | "base_resolution": 4, 8 | "per_level_scale": 1.61, 9 | "interpolation": "Linear" 10 | }, 11 | "network": { 12 | "otype": "FullyFusedMLP", 13 | "activation": "ReLU", 14 | "output_activation": "None", 15 | "n_neurons": 64, 16 | "n_hidden_layers": 5 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /config/config_medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "HashGrid", 4 | "n_levels": 12, 5 | "n_features_per_level": 4, 6 | "log2_hashmap_size": 15, 7 | "base_resolution": 4, 8 | "per_level_scale": 1.61, 9 | "interpolation": "Linear" 10 | }, 11 | "network": { 12 | "otype": "FullyFusedMLP", 13 | "activation": "ReLU", 14 | "output_activation": "None", 15 | "n_neurons": 64, 16 | "n_hidden_layers": 4 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /config/config_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "HashGrid", 4 | "n_levels": 8, 5 | "n_features_per_level": 4, 6 | "log2_hashmap_size": 13, 7 | "base_resolution": 4, 8 | "per_level_scale": 
1.61, 9 | "interpolation": "Linear" 10 | }, 11 | "network": { 12 | "otype": "FullyFusedMLP", 13 | "activation": "ReLU", 14 | "output_activation": "None", 15 | "n_neurons": 64, 16 | "n_hidden_layers": 3 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /config/config_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding": { 3 | "otype": "HashGrid", 4 | "n_levels": 6, 5 | "n_features_per_level": 4, 6 | "log2_hashmap_size": 12, 7 | "base_resolution": 4, 8 | "per_level_scale": 1.61, 9 | "interpolation": "Linear" 10 | }, 11 | "network": { 12 | "otype": "FullyFusedMLP", 13 | "activation": "ReLU", 14 | "output_activation": "None", 15 | "n_neurons": 64, 16 | "n_hidden_layers": 3 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/data/__init__.py -------------------------------------------------------------------------------- /lightning_logs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/lightning_logs/__init__.py -------------------------------------------------------------------------------- /outputs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/outputs/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | commentjson>=0.9.0 2 | matplotlib>=3.7.0 3 | 
natsort>=8.3.1 4 | numpy>=1.24.3 5 | opencv_python>=4.5.4.58 6 | pytorch_lightning>=2.0.1 7 | rawpy>=0.18.1 8 | torch>=2.0.0+cu118 9 | tqdm>=4.64.1 10 | -------------------------------------------------------------------------------- /scripts/dehaze.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=tiny \ 4 | --obstruction_alpha_grid_size=small \ 5 | --obstruction_initial_alpha=0.5 \ 6 | --obstruction_initial_depth=0.5 \ 7 | --transmission_flow_grid_size=tiny \ 8 | --transmission_image_grid_size=large \ 9 | --transmission_initial_depth=1.0 \ 10 | --alpha_weight=1e-2 \ 11 | --alpha_temperature=4.0 \ 12 | --translation_weight=1e-1 \ 13 | "$@" -------------------------------------------------------------------------------- /scripts/fusion.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --transmission_flow_grid_size=tiny \ 3 | --transmission_image_grid_size=large \ 4 | --transmission_initial_depth=0.5 \ 5 | --transmission_control_points_flow=31 \ 6 | --single_plane \ 7 | --camera_control_points=31 \ 8 | --alpha_weight=0.0 \ 9 | "$@" -------------------------------------------------------------------------------- /scripts/occlusion-wild.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=medium \ 4 | --obstruction_alpha_grid_size=medium \ 5 | --obstruction_initial_depth=0.5 \ 6 | --transmission_flow_grid_size=tiny \ 7 | --transmission_image_grid_size=large \ 8 | --transmission_initial_depth=1.0 \ 9 | --alpha_weight=1e-2 \ 10 | --alpha_temperature=0.1 \ 11 | --lr=3e-5 \ 12 | "$@" 13 | -------------------------------------------------------------------------------- /scripts/occlusion.sh: 
-------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=medium \ 4 | --obstruction_alpha_grid_size=medium \ 5 | --obstruction_initial_depth=0.5 \ 6 | --transmission_flow_grid_size=tiny \ 7 | --transmission_image_grid_size=large \ 8 | --transmission_initial_depth=1.0 \ 9 | --alpha_weight=2e-2 \ 10 | --alpha_temperature=3.0 \ 11 | --lr=3e-5 \ 12 | "$@" 13 | -------------------------------------------------------------------------------- /scripts/reflection-wild.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=large \ 4 | --obstruction_alpha_grid_size=tiny \ 5 | --obstruction_initial_depth=2.0 \ 6 | --transmission_flow_grid_size=tiny \ 7 | --transmission_image_grid_size=large \ 8 | --transmission_initial_depth=1.0 \ 9 | --alpha_weight=5e-3 \ 10 | --alpha_temperature=0.1 \ 11 | --lr=2e-4 \ 12 | --translation_weight=1e-1 \ 13 | "$@" 14 | -------------------------------------------------------------------------------- /scripts/reflection.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=large \ 4 | --obstruction_alpha_grid_size=tiny \ 5 | --obstruction_initial_depth=2.0 \ 6 | --transmission_flow_grid_size=tiny \ 7 | --transmission_image_grid_size=large \ 8 | --transmission_initial_depth=1.0 \ 9 | --alpha_weight=0.0 \ 10 | --alpha_temperature=0.2 \ 11 | --lr=2e-4 \ 12 | --translation_weight=1e-1 \ 13 | "$@" 14 | -------------------------------------------------------------------------------- /scripts/segmentation.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=small \ 3 | 
--obstruction_image_grid_size=large \ 4 | --obstruction_alpha_grid_size=medium \ 5 | --obstruction_control_points_flow=15 \ 6 | --obstruction_initial_depth=0.5 \ 7 | --obstruction_initial_alpha=0.5 \ 8 | --transmission_flow_grid_size=small \ 9 | --transmission_image_grid_size=large \ 10 | --transmission_initial_depth=1.0 \ 11 | --transmission_control_points_flow=15 \ 12 | --camera_control_points=15 \ 13 | --alpha_weight=2e-3 \ 14 | --alpha_temperature=0.15 \ 15 | --translation_weight=1e0 \ 16 | "$@" -------------------------------------------------------------------------------- /scripts/shadow.sh: -------------------------------------------------------------------------------- 1 | python3 train.py \ 2 | --obstruction_flow_grid_size=tiny \ 3 | --obstruction_image_grid_size=tiny \ 4 | --obstruction_alpha_grid_size=medium \ 5 | --obstruction_initial_alpha=0.5 \ 6 | --obstruction_initial_depth=10.0 \ 7 | --transmission_flow_grid_size=tiny \ 8 | --transmission_image_grid_size=large \ 9 | --transmission_initial_depth=0.5 \ 10 | --alpha_weight=0.0 \ 11 | --alpha_temperature=4.0 \ 12 | --translation_weight=1e-1 \ 13 | "$@" -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import commentjson as json 3 | import numpy as np 4 | import os 5 | import re 6 | import pickle 7 | 8 | import tinycudann as tcnn 9 | 10 | from utils import utils 11 | from utils.utils import debatch 12 | import matplotlib.pyplot as plt 13 | 14 | import torch 15 | from torch.nn import functional as F 16 | from torch.utils.data import Dataset 17 | from torch.utils.data import DataLoader 18 | import pytorch_lightning as pl 19 | 20 | ######################################################################################################### 21 | ################################################ DATASET ################################################ 22 | 
#########################################################################################################

class BundleDataset(Dataset):
    """Long-burst dataset that serves random pixel samples from every visible frame.

    Each ``__getitem__`` call draws ``frame_batch_size`` random UV locations in each of the
    first ``frame_cutoff`` frames. It returns them together with per-sample frame times, camera
    rotations and intrinsics. The index argument is ignored; ``__len__`` only sets how many
    batches make up an epoch.

    Args:
        args: parsed training arguments. Reads ``bundle_path``, ``rgb_data``, ``frames``,
            ``point_batch_size``, ``num_batches``, ``max_percentile`` and ``frame_cutoff``.
        load_volume: RAW bundles only. If True, decode the full image volume now instead of
            keeping a tiny placeholder volume for lazy loading. RGB bundles are always loaded.
    """

    def __init__(self, args, load_volume=False):
        self.args = args
        print("Loading from:", self.args.bundle_path)

        if self.args.rgb_data:
            self.init_RGB(self.args, load_volume)
        else:
            self.init_RAW(self.args, load_volume)

    def _init_batching(self):
        """Derive batch sizes from ``args`` and reset the progressive-training state."""
        # largest even number of samples per frame that fits in the requested point budget
        self.frame_batch_size = 2 * (self.args.point_batch_size // self.num_frames // 2)
        # total points per batch: point_batch_size rounded down to a multiple of num_frames
        self.point_batch_size = self.frame_batch_size * self.num_frames
        self.num_batches = self.args.num_batches

        self.sin_epoch = 0.0  # fraction of training complete
        self.frame_cutoff = self.num_frames
        print("Frame Count: ", self.num_frames)

    def init_RGB(self, args, load_volume=False):
        """Initialize from a processed RGB bundle.

        The bundle provides ``rgb_volume`` (T,C,H,W), ``intrinsics`` (T,3,3) and ``rotations``
        (T,3,3). ``reference_img`` and ``translations`` are optional.

        ``load_volume`` is accepted for signature parity with ``init_RAW`` but is ignored here.
        ``load_volume()`` is always called, so the RGB volume gets the same percentile scaling,
        clamping and frame subsampling as RAW data.
        """
        # identity camera pipeline: no lens distortion, identity color matrix and tonemap curve
        self.lens_distortion = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]).float()
        self.ccm = torch.tensor(np.eye(3)).float()
        self.tonemap_curve = torch.tensor(np.linspace(0, 1, 65)[None, :, None]).float().repeat(3, 1, 2)

        bundle = dict(np.load(args.bundle_path, allow_pickle=True))

        if "reference_img" in bundle:
            self.reference_img = torch.tensor(bundle["reference_img"]).float()
        if "translations" in bundle:
            self.translations = torch.tensor(bundle["translations"]).float()  # T,3

        # Loaded here only for its shape (the bundle has no separate size metadata).
        # load_volume() below reloads and normalizes it.
        self.rgb_volume = torch.tensor(bundle["rgb_volume"]).float()[:, :3]  # T,C,H,W

        self.intrinsics = torch.tensor(bundle["intrinsics"]).float()  # T,3,3
        self.intrinsics = self.intrinsics.transpose(1, 2)
        self.intrinsics_inv = torch.inverse(self.intrinsics)
        self.rotations = torch.tensor(bundle["rotations"]).float()  # T,3,3
        self.reference_rotation = self.rotations[0]
        # rotations expressed relative to the first (reference) frame
        self.camera_to_world = self.reference_rotation.T @ self.rotations

        self.num_frames = self.rgb_volume.shape[0]
        self.img_channels = self.rgb_volume.shape[1]
        self.img_height = self.rgb_volume.shape[2]
        self.img_width = self.rgb_volume.shape[3]

        if args.frames is not None:  # subsample frames
            self.num_frames = len(args.frames)
            self.rotations = self.rotations[args.frames]
            self.camera_to_world = self.camera_to_world[args.frames]
            self.intrinsics = self.intrinsics[args.frames]
            self.intrinsics_inv = self.intrinsics_inv[args.frames]

        self.load_volume()

        self._init_batching()

    def init_RAW(self, args, load_volume=False):
        """Initialize from a RAW capture bundle.

        Reads camera characteristics, motion data (timestamps and quaternions) and per-frame
        ``raw_{i}`` metadata. Unless ``load_volume`` is True, ``rgb_volume`` is only a tiny
        placeholder; ``load_volume()`` must be called before sampling real pixels.
        """
        bundle = np.load(args.bundle_path, allow_pickle=True)

        self.characteristics = bundle['characteristics'].item()  # camera characteristics
        self.motion = bundle['motion'].item()

        # Read each raw_{i} entry once and keep only the metadata needed here. An npz archive
        # re-reads an entry on every lookup, so avoid repeated per-frame lookups.
        frame_timestamps, frame_intrinsics = [], []
        for i in range(bundle['num_raw_frames']):
            frame_meta = bundle[f'raw_{i}'].item()
            frame_timestamps.append(frame_meta['timestamp'])
            frame_intrinsics.append(frame_meta['intrinsics'])
        first_meta = bundle['raw_0'].item()

        self.frame_timestamps = torch.tensor(frame_timestamps)
        self.motion_timestamps = torch.tensor(self.motion['timestamp'])

        # T',4 -- sampled at the motion timestamps, which differ from the frame timestamps
        self.quaternions = torch.tensor(self.motion['quaternion']).float()
        # The scene uses a "+z towards scene" convention, but the phone uses "+z towards face".
        # So rotate 180 degrees around the y axis, or equivalently flip over z,y.
        self.quaternions[:, 2] = -self.quaternions[:, 2]  # invert y
        self.quaternions[:, 3] = -self.quaternions[:, 3]  # invert z

        # resample the motion track at the frame timestamps
        self.quaternions = utils.multi_interp(self.frame_timestamps, self.motion_timestamps, self.quaternions)
        self.rotations = utils.convert_quaternions_to_rot(self.quaternions)

        self.reference_quaternion = self.quaternions[0]
        self.reference_rotation = self.rotations[0]

        # rotations expressed relative to the first (reference) frame
        self.camera_to_world = self.reference_rotation.T @ self.rotations

        self.intrinsics = torch.tensor(np.array(frame_intrinsics)).float()  # T,3,3
        # swap cx,cy -> landscape to portrait
        cx, cy = self.intrinsics[:, 2, 1].clone(), self.intrinsics[:, 2, 0].clone()
        self.intrinsics[:, 2, 0], self.intrinsics[:, 2, 1] = cx, cy
        # transpose to put cx,cy in the right column
        self.intrinsics = self.intrinsics.transpose(1, 2)
        self.intrinsics_inv = torch.inverse(self.intrinsics)

        self.lens_distortion = first_meta['lens_distortion']
        self.tonemap_curve = torch.tensor(first_meta['tonemap_curve'])
        self.ccm = utils.parse_ccm(first_meta['android']['colorCorrection.transform'])

        self.num_frames = bundle['num_raw_frames'].item()
        self.img_channels = 3
        self.img_height = first_meta['width']  # rotated 90
        self.img_width = first_meta['height']
        # T,C,3,3 tiny fake volume for lazy loading
        self.rgb_volume = torch.ones([self.num_frames, self.img_channels, 3, 3]).float()

        if args.frames is not None:  # subsample frames
            self.num_frames = len(args.frames)
            self.frame_timestamps = self.frame_timestamps[args.frames]
            self.quaternions = self.quaternions[args.frames]
            self.rotations = self.rotations[args.frames]
            self.camera_to_world = self.camera_to_world[args.frames]
            self.intrinsics = self.intrinsics[args.frames]
            self.intrinsics_inv = self.intrinsics_inv[args.frames]

        if load_volume:
            self.load_volume()

        self._init_batching()

    def load_volume(self):
        """Load the full image volume into ``self.rgb_volume`` (T,C,H,W), scaled to [0, 1].

        RGB bundles are read directly (first three channels). RAW bundles are decoded with
        ``utils.raw_to_rgb``.

        If ``args.max_percentile < 100``, the volume is divided by that percentile (computed
        over all frames, before ``args.frames`` subsampling) so that highlights saturate. The
        result is then clamped to [0, 1].
        """
        bundle = dict(np.load(self.args.bundle_path, allow_pickle=True))
        if self.args.rgb_data:
            self.rgb_volume = torch.tensor(bundle["rgb_volume"]).float()[:, :3]
        else:  # need to unpack RAW data
            utils.de_item(bundle)
            self.rgb_volume = utils.raw_to_rgb(bundle)  # T,C,H,W

        if self.args.max_percentile < 100:  # cut off highlights for scaling (long-tail distribution)
            self.rgb_volume = self.rgb_volume / np.percentile(self.rgb_volume, self.args.max_percentile)

        self.rgb_volume = self.rgb_volume.clamp(0, 1)

        if self.args.frames is not None:
            self.rgb_volume = self.rgb_volume[self.args.frames]  # subsample frames

    def __len__(self):
        # arbitrary: samples are generated randomly on every __getitem__ call
        return self.num_batches

    def __getitem__(self, idx):
        """Draw one random training batch. ``idx`` is ignored.

        Returns the tuple produced by ``generate_samples``.
        """
        if self.args.frame_cutoff:
            # Gradually reveal more frames as training advances (sin_epoch = fraction complete).
            # Keep at least one frame so short bursts (< 10 frames) never yield an empty batch.
            cutoff = int((0.1 + 2 * self.sin_epoch) * self.num_frames)
            self.frame_cutoff = max(1, min(cutoff, self.num_frames))
        else:
            self.frame_cutoff = self.num_frames

        # uniform random in [0.01, 0.99]
        uv = torch.rand(self.frame_batch_size * self.frame_cutoff, 2) * 0.98 + 0.01

        # Normalized frame time per sample, each repeated frame_batch_size times:
        # [0, ..., 0, 1/(N-1), ..., 1/(N-1), 2/(N-1), ...] with N = num_frames
        t = torch.linspace(0, 1, self.num_frames)[:self.frame_cutoff].repeat_interleave(self.frame_batch_size)[:, None]

        return self.generate_samples(t, uv)

    def generate_samples(self, t, uv):
        """Gather training samples for the first ``frame_cutoff`` frames.

        Args:
            t: (frame_cutoff * frame_batch_size, 1) per-sample frame times.
            uv: (frame_cutoff * frame_batch_size, 2) coordinates in [0, 1], grouped by frame.

        Returns:
            ``(t, uv, camera_to_world, intrinsics, intrinsics_inv, rgb_samples)``.
            The camera tensors are repeated once per sample. ``rgb_samples`` is
            (frame_cutoff * frame_batch_size, C), bilinearly sampled from ``rgb_volume``.
        """
        # repeat each frame's camera parameters once per sample drawn from that frame
        camera_to_world = self.camera_to_world[:self.frame_cutoff].repeat_interleave(self.frame_batch_size, dim=0)
        intrinsics = self.intrinsics[:self.frame_cutoff].repeat_interleave(self.frame_batch_size, dim=0)
        intrinsics_inv = self.intrinsics_inv[:self.frame_cutoff].repeat_interleave(self.frame_batch_size, dim=0)

        # grid_sample expects coordinates in [-1, 1]; one (frame_batch_size x 1) grid per frame
        grid_uv = ((uv - 0.5) * 2).reshape(self.frame_cutoff, self.frame_batch_size, 1, 2)
        rgb_samples = F.grid_sample(self.rgb_volume[:self.frame_cutoff], grid_uv, mode="bilinear", padding_mode="border", align_corners=True)
        # grid_sample returns frame_cutoff x C x frame_batch_size x 1. Drop only the trailing
        # singleton: a bare squeeze() would also drop the frame/batch axis when it has size 1.
        rgb_samples = rgb_samples.permute(0, 2, 1, 3).squeeze(-1).flatten(0, 1)  # point_batch_size x C

        return t, uv, camera_to_world, intrinsics, intrinsics_inv, rgb_samples

    def sample_frame(self, uv, frame):
        """Bilinearly sample frame ``frame`` of ``rgb_volume`` at coordinates ``uv``.

        Args:
            uv: (P, 2) coordinates in [0, 1].
            frame: integer frame index.

        Returns:
            (P, C) tensor of sampled colors; P == 1 keeps its axis.
        """
        grid_uv = ((uv - 0.5) * 2)[None, :, None, :]  # 1, P, 1, 2
        rgb_samples = F.grid_sample(self.rgb_volume[frame:frame + 1], grid_uv, mode="bilinear", padding_mode="border", align_corners=True)
        # output is 1 x C x P x 1; index the singleton axes explicitly instead of squeeze()
        return rgb_samples[0, :, :, 0].permute(1, 0)  # P, C

#########################################################################################################
################################################ MODELS #################################################
#########################################################################################################

class RotationModel(pl.LightningModule):
    """Learnable, time-varying additive correction to per-sample camera rotations.

    Stores 3 x ``args.camera_control_points`` zero-initialized control points. They are
    interpolated at each sample time with ``utils.interpolate``.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.stabilize = False  # NOTE(review): not read in the visible code -- confirm it is still used
        self.delta_control_points = torch.nn.Parameter(data=torch.zeros(1, 3, self.args.camera_control_points, dtype=torch.float32), requires_grad=True)

    def forward(self, camera_to_world, t):
        """Return ``camera_to_world + args.rotation_weight * offset`` for each sample.

        Args:
            camera_to_world: per-sample rotation matrices, presumably (B, 3, 3) -- TODO confirm.
            t: per-sample times; ``t.shape[0]`` is the batch size B.
        """
        # one copy of the control points per sample, interpolated at that sample's t
        delta_control_points = self.delta_control_points.repeat(t.shape[0],1,1)
        rotation_deltas = utils.interpolate(delta_control_points, t)  # presumably (B, 3) -- TODO confirm
        rx, ry, rz = rotation_deltas[:,0], rotation_deltas[:,1], rotation_deltas[:,2]
        r0 = torch.zeros_like(rx)

        # Skew-symmetric offset built from (rx, ry, rz), presumably a small-angle rotation
        # update. Because the outer stack uses dim=-1, each inner list below becomes a
        # *column* of the resulting 3x3 matrix.
        rotation_offsets = torch.stack([torch.stack([ r0, -rz, ry], dim=-1),
                                        torch.stack([ rz, r0, -rx], dim=-1),
                                        torch.stack([-ry, rx, r0], dim=-1)], dim=-1)

        return camera_to_world + self.args.rotation_weight * rotation_offsets

class TranslationModel(pl.LightningModule):
    """Learnable, time-varying camera translation.

    Stores 3 x ``args.camera_control_points`` zero-initialized control points. They are scaled
    by ``args.translation_weight`` and interpolated at each sample time with
    ``utils.interpolate``.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.stabilize = False  # NOTE(review): not read in the visible code -- confirm it is still used
        self.delta_control_points = torch.nn.Parameter(data=torch.zeros(1, 3, self.args.camera_control_points, dtype=torch.float32), requires_grad=True)

    def forward(self, t):
        """Return the interpolated translation for each time in ``t`` (batch size ``t.shape[0]``)."""
        control_points = self.args.translation_weight * self.delta_control_points.repeat(t.shape[0],1,1)
        translation = utils.interpolate(control_points, t)

        return translation

class PlaneModel(pl.LightningModule):
    """ Plane reprojection model with learnable z-depth

    Moves each ray origin along its direction by (depth - origin z). It then maps the resulting
    point's x/y, divided by that distance, to UV coordinates clamped to [0, 1].
    """
    def __init__(self, args, depth):
        super().__init__()

        self.args = args
        self.depth = torch.nn.Parameter(data=torch.tensor([depth/1.0], dtype=torch.float32), requires_grad=True)

        # fixed (requires_grad=False) selectors for the x and y components of the intersection point
        self.u0 = torch.nn.Parameter(data=torch.tensor([1.0, 0.0], dtype=torch.float32), requires_grad=False)
        self.v0 = torch.nn.Parameter(data=torch.tensor([0.0, 1.0], dtype=torch.float32), requires_grad=False)

    def forward(self, ray_origins, ray_directions):
        """Map rays to plane UV coordinates.

        Args:
            ray_origins: (N, 3) ray origins.
            ray_directions: (N, 3) ray directions. The computed point lies on z = depth only if
                their z-component is 1 -- NOTE(review): confirm callers normalize that way.

        Returns:
            (N, 2) UV coordinates clamped to [0, 1].
        """
        # termination is just plane depth - ray origin z
        termination = ((1.0 * self.depth) - ray_origins[:, 2]).unsqueeze(1)

        # compute intersection points (N x 3)
        intersection_points = ray_origins + (termination * ray_directions)

        # project to (u, v) coordinates (N x 1 for each), avoid zero div
        u = 0.5 + 0.4 * torch.sum(intersection_points[:, :2] * (self.u0 / (torch.abs(termination) + 1e-6)), dim=1)
        v = 0.5 + 0.4 * torch.sum(intersection_points[:, :2] * (self.v0 / (torch.abs(termination) + 1e-6)), dim=1)
        uv = torch.stack((u, v), dim=1)

        return uv.clamp(0, 1) # ensure UV coordinates stay within neural field bounds

class PlaneTransmissionModel(pl.LightningModule):
    """Transmission (background) layer on a plane: an RGB neural field plus a per-point flow field.

    Encodings and MLPs are tiny-cuda-nn modules configured from
    ``config/config_{size}.json``. The sizes come from ``args.transmission_image_grid_size``
    and ``args.transmission_flow_grid_size``.
    """

    def __init__(self, args):
        super().__init__()
        with open(f"config/config_{args.transmission_image_grid_size}.json") as config_image:
            config_image = json.load(config_image)
        with open(f"config/config_{args.transmission_flow_grid_size}.json" ) as config_flow:
            config_flow = json.load(config_flow)

        self.args = args

        self.encoding_image = tcnn.Encoding(n_input_dims=2, encoding_config=config_image["encoding"])
        self.encoding_flow = tcnn.Encoding(n_input_dims=2, encoding_config=config_flow["encoding"])

        self.network_image = tcnn.Network(n_input_dims=self.encoding_image.n_output_dims, n_output_dims=3, network_config=config_image["network"])
        # flow network outputs 2 (u, v) x transmission_control_points_flow control points per point
        self.network_flow = tcnn.Network(n_input_dims=self.encoding_flow.n_output_dims,
                                         n_output_dims=2*(self.args.transmission_control_points_flow), network_config=config_flow["network"])

        self.model_plane = PlaneModel(args, args.transmission_initial_depth)
        self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float32), requires_grad=True)

    def forward(self, t, ray_origins, ray_directions, sin_epoch):
        """Sample the transmission layer.

        Args:
            t: per-ray times passed to ``utils.interpolate``.
            ray_origins, ray_directions: rays mapped to plane UVs by ``self.model_plane``.
            sin_epoch: training-progress value forwarded to ``utils.mask`` (presumably a
                coarse-to-fine mask over encoding levels -- TODO confirm in utils).

        Returns:
            ``(rgb, flow)``:
                rgb: 3-channel colors clamped to [0, 1].
                flow: the UV offset (scaled by 0.01) added before querying the image field.
        """
        uv_plane = self.model_plane(ray_origins, ray_directions)


        flow = self.network_flow(utils.mask(self.encoding_flow(uv_plane), sin_epoch)) # B x 2*control_points, reshaped below

        flow = flow.reshape(-1,2,self.args.transmission_control_points_flow)
        flow = 0.01 * utils.interpolate(flow, t)

        rgb = self.network_image(utils.mask(self.encoding_image(uv_plane + flow), sin_epoch)).float()
        rgb = (self.initial_rgb + rgb).clamp(0,1)

        return rgb, flow

class PlaneObstructionModel(pl.LightningModule):
    """Obstruction (foreground) layer on a plane: RGB, alpha and per-point flow neural fields.

    Configs come from ``config/config_{size}.json``. The sizes come from
    ``args.obstruction_image_grid_size``, ``args.obstruction_alpha_grid_size`` and
    ``args.obstruction_flow_grid_size``.
    """

    def __init__(self, args):
        super().__init__()
        with open(f"config/config_{args.obstruction_image_grid_size}.json") as config_image:
            config_image = json.load(config_image)
        with open(f"config/config_{args.obstruction_alpha_grid_size}.json" ) as config_alpha:
            config_alpha = json.load(config_alpha)
        with open(f"config/config_{args.obstruction_flow_grid_size}.json" ) as config_flow:
            config_flow = json.load(config_flow)

        self.args = args

        self.encoding_image = tcnn.Encoding(n_input_dims=2, encoding_config=config_image["encoding"])
        self.encoding_alpha = tcnn.Encoding(n_input_dims=2, encoding_config=config_alpha["encoding"])
        self.encoding_flow = tcnn.Encoding(n_input_dims=2, encoding_config=config_flow["encoding"])

        self.network_image = tcnn.Network(n_input_dims=self.encoding_image.n_output_dims, n_output_dims=3, network_config=config_image["network"])
        self.network_alpha = tcnn.Network(n_input_dims=self.encoding_alpha.n_output_dims, n_output_dims=1, network_config=config_alpha["network"])
        self.network_flow = tcnn.Network(n_input_dims=self.encoding_flow.n_output_dims, n_output_dims=2*(self.args.obstruction_control_points_flow), network_config=config_flow["network"])

        self.model_plane = PlaneModel(args, args.obstruction_initial_depth)
        # Learnable alpha prior. It must stay inside (0, 1) or the logit in forward() becomes NaN.
        # NOTE(review): nothing here constrains it during training.
        self.initial_alpha = torch.nn.Parameter(data=torch.tensor(args.obstruction_initial_alpha, dtype=torch.float32), requires_grad=True)
        self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float32), requires_grad=True)

    def forward(self, t, ray_origins, ray_directions, sin_epoch):
        """Sample the obstruction layer.

        Args:
            t: per-ray times passed to ``utils.interpolate``.
            ray_origins, ray_directions: rays mapped to plane UVs by ``self.model_plane``.
            sin_epoch: training-progress value forwarded to ``utils.mask``.

        Returns:
            ``(rgb, flow, alpha)``:
                rgb: colors clamped to [0, 1].
                flow: the UV offset (scaled by 0.01).
                alpha: opacity in (0, 1). It equals ``initial_alpha`` when the alpha network
                    outputs 0.
        """
        uv_plane = self.model_plane(ray_origins, ray_directions)

        flow = self.network_flow(utils.mask(self.encoding_flow(uv_plane), sin_epoch)) # B x 2*control_points, reshaped below
        flow = flow.reshape(-1,2,self.args.obstruction_control_points_flow)
        flow = 0.01 * utils.interpolate(flow, t)

        rgb = self.network_image(utils.mask(self.encoding_image(uv_plane + flow), sin_epoch)).float()
        rgb = (self.initial_rgb + rgb).clamp(0,1)

        alpha = self.network_alpha(utils.mask(self.encoding_alpha(uv_plane + flow), sin_epoch)).float()
        # -log(1/a - 1) is logit(initial_alpha); alpha_temperature scales the network's contribution
        alpha = torch.sigmoid((-torch.log(1/self.initial_alpha - 1) + self.args.alpha_temperature * alpha))

        return rgb, flow, alpha

#########################################################################################################
################################################ NETWORK ################################################
#########################################################################################################

class BundleMLP(pl.LightningModule):
    def __init__(self, args, cached_bundle=None):
        super().__init__()
        # load network configs

        self.args = args
        if cached_bundle is None:
            self.bundle = BundleDataset(self.args)
        else:
            with open(cached_bundle, 'rb') as file:
                self.bundle = pickle.load(file)

        self.img_width = self.bundle.img_width
        self.img_height = self.bundle.img_height
        self.lens_distortion = self.bundle.lens_distortion
        self.num_frames = args.num_frames = self.bundle.num_frames
        if args.frames is None:
            self.args.frames =
list(range(self.num_frames)) 364 | 365 | self.model_transmission = PlaneTransmissionModel(args) 366 | self.model_obstruction = PlaneObstructionModel(args) 367 | self.model_translation = TranslationModel(args) 368 | self.model_rotation = RotationModel(args) 369 | 370 | self.sin_epoch = 1.0 371 | self.save_hyperparameters() 372 | 373 | def load_volume(self): 374 | self.bundle.load_volume() 375 | 376 | def configure_optimizers(self): 377 | 378 | optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr) 379 | #constant lr 380 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0) 381 | 382 | return [optimizer], [scheduler] 383 | 384 | 385 | def forward(self, t, ray_origins, ray_directions): 386 | """ Forward model pass, estimate motion, implicit depth + image. 387 | """ 388 | 389 | rgb_transmission, flow_transmission = self.model_transmission(t, ray_origins, ray_directions, self.sin_epoch) 390 | rgb_obstruction, flow_obstruction, alpha_obstruction = self.model_obstruction(t, ray_origins, ray_directions, self.sin_epoch) 391 | 392 | if self.args.single_plane: 393 | rgb_combined = rgb_transmission 394 | else: 395 | rgb_combined = rgb_transmission * (1 - alpha_obstruction) + rgb_obstruction * alpha_obstruction 396 | 397 | return rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction 398 | 399 | def generate_ray_directions(self, uv, camera_to_world, intrinsics_inv): 400 | u, v = uv[:,0:1] * self.img_width, uv[:,1:2] * self.img_height 401 | uv1 = torch.cat([u, v, torch.ones_like(uv[:,0:1])], dim=1) # N x 3 402 | # scale by image width/height 403 | xy1 = torch.bmm(intrinsics_inv, uv1.unsqueeze(2)).squeeze(2) # N x 3 404 | xy = xy1[:,0:2] 405 | 406 | f_div_cx = -1 / intrinsics_inv[:,0,2] 407 | f_div_cy = -1 / intrinsics_inv[:,1,2] 408 | 409 | r2 = torch.sum(xy**2, dim=1, keepdim=True) # N x 1 410 | r4 = r2**2 411 | r6 = r2**3 412 | kappa1, kappa2, kappa3 = self.lens_distortion[0:3] 413 | 414 | # 
apply lens distortion correction 415 | xy = xy * (1 + kappa1*r2 + kappa2*r4 + kappa3*r6) 416 | 417 | xy = xy * torch.min(f_div_cx[:, None], f_div_cy[:, None]) # scale long dimension to -1, 1 418 | ray_directions = torch.cat([xy, torch.ones_like(xy[:,0:1])], dim=1) # N x 3 419 | ray_directions = torch.bmm(camera_to_world, ray_directions.unsqueeze(2)).squeeze(2) # apply camera rotation 420 | ray_directions = ray_directions / ray_directions[:,2:3] # normalize by z 421 | return ray_directions 422 | 423 | def training_step(self, train_batch, batch_idx): 424 | t, uv, camera_to_world, intrinsics, intrinsics_inv, rgb_reference = debatch(train_batch) # collapse batch + point dimensions 425 | 426 | camera_to_world = self.model_rotation(camera_to_world, t) # apply rotation offset 427 | ray_origins = self.model_translation(t) # camera center in world coordinates 428 | ray_directions = self.generate_ray_directions(uv, camera_to_world, intrinsics_inv) 429 | 430 | rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = self.forward(t, ray_origins, ray_directions) 431 | 432 | loss = 0.0 433 | 434 | rgb_loss = torch.abs((rgb_combined - rgb_reference)/(rgb_combined.detach() + 0.001)) 435 | self.log('loss/rgb', rgb_loss.mean()) 436 | loss += rgb_loss.mean() 437 | 438 | self.log(f'plane_depth/image', self.model_transmission.model_plane.depth) 439 | self.log(f'plane_depth/obstruction', self.model_obstruction.model_plane.depth) 440 | 441 | if (np.abs(self.args.alpha_weight) > 0 and self.sin_epoch) > 0.6: 442 | alpha_loss = self.args.alpha_weight * self.sin_epoch * alpha_obstruction 443 | self.log('loss/alpha', alpha_loss.mean()) 444 | loss += alpha_loss.mean() 445 | 446 | return loss 447 | 448 | def color_and_tone(self, rgb_samples, height, width): 449 | """ Apply CCM and tone curve to raw samples 450 | """ 451 | 452 | img = self.bundle.ccm.to(rgb_samples.device) @ rgb_samples.T 453 | img = img.reshape(3, height, width) 454 | img = 
utils.apply_tonemap_curve(img, self.bundle.tonemap_curve) 455 | 456 | return img 457 | 458 | def make_grid(self, height, width, u_lims, v_lims): 459 | """ Create (u,v) meshgrid with size (height,width) extent (u_lims, v_lims) 460 | """ 461 | u = torch.linspace(u_lims[0], u_lims[1], width) 462 | v = torch.linspace(v_lims[0], v_lims[1], height) 463 | u_grid, v_grid = torch.meshgrid([u, v], indexing="xy") # u/v grid 464 | return torch.stack((u_grid.flatten(), v_grid.flatten())).t() 465 | 466 | def generate_img(self, frame, height=960, width=720, u_lims=[0,1], v_lims=[0,1]): 467 | """ Produce reference image for tensorboard/visualization 468 | """ 469 | device = self.device 470 | uv = self.make_grid(height, width, u_lims, v_lims) 471 | 472 | rgb_samples = self.bundle.sample_frame(uv, frame).to(device) 473 | img = self.color_and_tone(rgb_samples, height, width) 474 | 475 | return img 476 | 477 | def generate_outputs(self, frame=0, height=720, width=540, u_lims=[0,1], v_lims=[0,1], time=None): 478 | """ Use forward model to sample implicit image I(u,v), depth D(u,v) and raw images 479 | at reprojected u,v, coordinates. 
Results should be aligned (sampled at (u',v')) 480 | """ 481 | device = self.device 482 | uv = self.make_grid(height, width, u_lims, v_lims) 483 | if time is None: 484 | t = torch.tensor(frame/(self.bundle.num_frames - 1), dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1 485 | else: 486 | t = torch.tensor(time, dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1 487 | frame = int(np.floor(time * (self.bundle.num_frames - 1))) 488 | 489 | rgb_reference = self.bundle.sample_frame(uv, frame).to(device) 490 | intrinsics_inv = self.bundle.intrinsics_inv[frame:frame+2] # 2 x 3 x 3 491 | camera_to_world = self.bundle.camera_to_world[frame:frame+2] # 2 x 3 x 3 492 | 493 | if time is None or time >= 1.0: # select exact frame timestamp 494 | intrinsics_inv = intrinsics_inv[0:1] 495 | camera_to_world = camera_to_world[0:1] 496 | else: # interpolate between frames 497 | fraction = time * (self.bundle.num_frames - 1) - frame 498 | intrinsics_inv = intrinsics_inv[0:1] * (1 - fraction) + intrinsics_inv[1:2] * fraction 499 | camera_to_world = camera_to_world[0:1] * (1 - fraction) + camera_to_world[1:2] * fraction 500 | 501 | intrinsics_inv = intrinsics_inv.repeat(uv.shape[0],1,1) # num_points x 3 x 3 502 | camera_to_world = camera_to_world.repeat(uv.shape[0],1,1) # num_points x 3 x 3 503 | 504 | with torch.no_grad(): 505 | rgb_combined_chunks = [] 506 | rgb_transmission_chunks = [] 507 | rgb_obstruction_chunks = [] 508 | flow_transmission_chunks = [] 509 | flow_obstruction_chunks = [] 510 | alpha_obstruction_chunks = [] 511 | 512 | chunk_size = 42 * self.args.point_batch_size 513 | for i in range((t.shape[0] // chunk_size) + 1): 514 | t_chunk, uv_chunk = t[i*chunk_size:(i+1)*chunk_size].to(device), uv[i*chunk_size:(i+1)*chunk_size].to(device) 515 | intrinsics_inv_chunk = intrinsics_inv[i*chunk_size:(i+1)*chunk_size].to(device) 516 | camera_to_world_chunk = camera_to_world[i*chunk_size:(i+1)*chunk_size].to(device) 517 | 518 | camera_to_world_chunk = 
self.model_rotation(camera_to_world_chunk, t_chunk) # apply rotation offset 519 | ray_origins = self.model_translation(t_chunk) # camera center in world coordinates 520 | ray_directions = self.generate_ray_directions(uv_chunk, camera_to_world_chunk, intrinsics_inv_chunk) 521 | rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = self.forward(t_chunk, ray_origins, ray_directions) 522 | 523 | rgb_combined_chunks.append(rgb_combined.detach().cpu()) 524 | rgb_transmission_chunks.append(rgb_transmission.detach().cpu()) 525 | rgb_obstruction_chunks.append(rgb_obstruction.detach().cpu()) 526 | flow_transmission_chunks.append(flow_transmission.detach().cpu()) 527 | flow_obstruction_chunks.append(flow_obstruction.detach().cpu()) 528 | alpha_obstruction_chunks.append(alpha_obstruction.detach().cpu()) 529 | 530 | rgb_combined = torch.cat(rgb_combined_chunks, dim=0) 531 | 532 | rgb_reference = self.color_and_tone(rgb_reference, height, width) 533 | rgb_combined = self.color_and_tone(rgb_combined, height, width) 534 | rgb_transmission = self.color_and_tone(torch.cat(rgb_transmission_chunks, dim=0), height, width) 535 | rgb_obstruction = self.color_and_tone(torch.cat(rgb_obstruction_chunks, dim=0), height, width) 536 | flow_transmission = torch.cat(flow_transmission_chunks, dim=0).reshape(height, width, 2) 537 | flow_obstruction = torch.cat(flow_obstruction_chunks, dim=0).reshape(height, width, 2) 538 | alpha_obstruction = torch.cat(alpha_obstruction_chunks, dim=0).reshape(height, width) 539 | 540 | return rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction 541 | 542 | def save_outputs(self, path, high_res=False): 543 | os.makedirs(f"outputs/{self.args.name + path}", exist_ok=True) 544 | if high_res: 545 | rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_occlusion = model.generate_outputs(frame=0, 
height=2560, width=1920, u_lims=[0,1], v_lims=[0,1], time=0.0) 546 | np.save(f"outputs/{self.args.name + path}/flow_transmission.npy", flow_transmission.detach().cpu().numpy()) 547 | np.save(f"outputs/{self.args.name + path}/flow_obstruction.npy", flow_obstruction.detach().cpu().numpy()) 548 | else: 549 | rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_occlusion = model.generate_outputs(frame=0, height=1080, width=810, u_lims=[0,1], v_lims=[0,1], time=0.0) 550 | 551 | plt.imsave(f"outputs/{self.args.name + path}/reference.png", rgb_reference.permute(1,2,0).detach().cpu().numpy()) 552 | plt.imsave(f"outputs/{self.args.name + path}/alpha.png", alpha_occlusion.detach().cpu().numpy(), cmap="gray") 553 | plt.imsave(f"outputs/{self.args.name + path}/transmission.png", rgb_transmission.permute(1,2,0).detach().cpu().numpy()) 554 | plt.imsave(f"outputs/{self.args.name + path}/obstruction.png", rgb_obstruction.permute(1,2,0).detach().cpu().numpy()) 555 | plt.imsave(f"outputs/{self.args.name + path}/combined.png", rgb_combined.permute(1,2,0).detach().cpu().numpy()) 556 | 557 | 558 | ######################################################################################################### 559 | ############################################### VALIDATION ############################################## 560 | ######################################################################################################### 561 | 562 | class ValidationCallback(pl.Callback): 563 | def __init__(self): 564 | super().__init__() 565 | 566 | def on_train_epoch_start(self, trainer, model): 567 | model.sin_epoch = min(1.0, 0.05 + np.sin(model.current_epoch/(model.args.max_epochs - 1) * np.pi/2)) # progression of training 568 | trainer.train_dataloader.dataset.sin_epoch = model.sin_epoch 569 | print(f" Sin of Current Epoch: {model.sin_epoch:.3f}") 570 | 571 | if model.sin_epoch > 0.4: 572 | # unlock flow model 573 | 
model.model_transmission.encoding_flow.requires_grad_(True) 574 | model.model_transmission.network_flow.requires_grad_(True) 575 | model.model_obstruction.encoding_flow.requires_grad_(True) 576 | model.model_obstruction.network_flow.requires_grad_(True) 577 | 578 | model.model_transmission.network_flow.train(True) 579 | model.model_transmission.encoding_flow.train(True) 580 | model.model_obstruction.encoding_flow.train(True) 581 | model.model_obstruction.network_flow.train(True) 582 | 583 | if model.args.fast: # skip tensorboarding except for beginning and end 584 | if model.current_epoch == model.args.max_epochs - 1 or model.current_epoch == 0: 585 | pass 586 | else: 587 | return 588 | 589 | # for i, frame in enumerate([0, model.bundle.num_frames//2, model.bundle.num_frames-1]): # can sample more frames 590 | for i, frame in enumerate([0]): 591 | rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = model.generate_outputs(frame) 592 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined', rgb_combined, global_step=trainer.global_step) 593 | model.logger.experiment.add_image(f'pred/{i}_rgb_transmission', rgb_transmission, global_step=trainer.global_step) 594 | model.logger.experiment.add_image(f'pred/{i}_rgb_obstruction', rgb_obstruction, global_step=trainer.global_step) 595 | model.logger.experiment.add_image(f'pred/{i}_rgb_obstruction_alpha', rgb_obstruction * alpha_obstruction, global_step=trainer.global_step) 596 | model.logger.experiment.add_image(f'pred/{i}_alpha_obstruction', utils.colorize_tensor(alpha_obstruction, vmin=0, vmax=1, cmap="gray"), global_step=trainer.global_step) 597 | 598 | 599 | if model.args.save_video: # save the evolution of the model 600 | model.save_outputs(path=f"/{model.current_epoch}") 601 | 602 | def on_train_start(self, trainer, model): 603 | pl.seed_everything(42) # the answer to life, the universe, and everything 604 | 605 | # initialize rgb as average color 
of first frame of data (minimize the amount the rgb models have to learn) 606 | model.model_transmission.initial_rgb.data = torch.mean(model.bundle.rgb_volume[0], dim=(1,2))[None,:].to(model.device) 607 | model.model_obstruction.initial_rgb.data = torch.mean(model.bundle.rgb_volume[0], dim=(1,2))[None,:].to(model.device) 608 | 609 | model.logger.experiment.add_text("args", str(model.args)) 610 | 611 | for i, frame in enumerate([0, model.bundle.num_frames//2, model.bundle.num_frames-1]): 612 | rgb_raw = model.generate_img(frame) 613 | model.logger.experiment.add_image(f'gt/{i}_rgb_raw', rgb_raw, global_step=trainer.global_step) 614 | 615 | 616 | def on_train_end(self, trainer, model): 617 | checkpoint_dir = os.path.join("checkpoints", model.args.name, "last.ckpt") 618 | bundle_dir = os.path.join("checkpoints", model.args.name, "bundle.pkl") 619 | trainer.save_checkpoint(checkpoint_dir) 620 | 621 | model.save_outputs(path=f"-final", high_res=True) 622 | 623 | with open(bundle_dir, 'wb') as file: 624 | model.bundle.rgb_volume = torch.ones([model.bundle.num_frames, model.bundle.img_channels, 3,3]).float() 625 | pickle.dump(model.bundle, file) 626 | 627 | if __name__ == "__main__": 628 | 629 | # argparse 630 | parser = argparse.ArgumentParser() 631 | 632 | # data 633 | parser.add_argument('--point_batch_size', type=int, default=2**18, help="Number of points to sample per dataloader index.") 634 | parser.add_argument('--num_batches', type=int, default=80, help="Number of training batches.") 635 | parser.add_argument('--max_percentile', type=float, default=100, help="Percentile of brightest pixels to cut.") 636 | parser.add_argument('--frames', type=str, help="Which subset of frames to use for training, e.g. 
0,10,20,30,40") 637 | parser.add_argument('--rgb_data', action='store_true', help="Input data is pre-processed RGB.") 638 | 639 | # model 640 | parser.add_argument('--camera_control_points', type=int, default=22, help="Spline control points for translation/rotation model.") 641 | parser.add_argument('--alpha_weight', type=float, default=1e-2, help="Alpha regularization weight.") 642 | parser.add_argument('--rotation_weight', type=float, default=1e-3, help="Scale learned rotation.") 643 | parser.add_argument('--translation_weight', type=float, default=1e-2, help="Scale learned translation.") 644 | parser.add_argument('--alpha_temperature', type=float, default=1.0, help="Temperature for sigmoid in alpha matte calculation.") 645 | 646 | # planes 647 | parser.add_argument('--obstruction_control_points_flow', type=int, default=11, help="Spline control points for flow models.") 648 | parser.add_argument('--obstruction_flow_grid_size', type=str, default="tiny", help="Obstruction flow grid size (small, medium, large).") 649 | parser.add_argument('--obstruction_image_grid_size', type=str, default="large", help="Obstruction image grid size (small, medium, large).") 650 | parser.add_argument('--obstruction_alpha_grid_size', type=str, default="large", help="Obstruction alpha grid size (small, medium, large).") 651 | parser.add_argument('--obstruction_initial_depth', type=float, default=1.0, help="Obstruction initial plane depth.") 652 | parser.add_argument('--obstruction_initial_alpha', type=float, default=0.5, help="Obstruction initial alpha.") 653 | parser.add_argument('--transmission_control_points_flow', type=int, default=11, help="Spline control points for flow models.") 654 | parser.add_argument('--transmission_flow_grid_size', type=str, default="tiny", help="Transmission flow grid size (small, medium, large).") 655 | parser.add_argument('--transmission_image_grid_size', type=str, default="large", help="Transmission image grid size (small, medium, large).") 656 | 
parser.add_argument('--transmission_initial_depth', type=float, default=0.4, help="Transmission initial plane depth.") 657 | parser.add_argument('--single_plane', action='store_true', help="Use single plane model.") 658 | 659 | # training 660 | parser.add_argument('--bundle_path', type=str, required=True, help="Path to frame_bundle.npz") 661 | parser.add_argument('--name', type=str, required=True, help="Experiment name for logs and checkpoints.") 662 | parser.add_argument('--max_epochs', type=int, default=75, help="Number of training epochs.") 663 | parser.add_argument('--lr', type=float, default=3e-5, help="Learning rate.") 664 | parser.add_argument('--save_video', action='store_true', help="Store training outputs at each epoch for visualization.") 665 | parser.add_argument('--num_workers', type=int, default=4, help="Number of dataloader workers.") 666 | parser.add_argument('--debug', action='store_true', help="Debug mode, only use 1 batch.") 667 | parser.add_argument('--frame_cutoff', action='store_true', help="Use frame cutoff.") 668 | parser.add_argument('--fast', action='store_true', help="Fast mode.") 669 | 670 | 671 | args = parser.parse_args() 672 | # parse plane args 673 | print(args) 674 | if args.frames is not None: 675 | args.frames = [int(x) for x in re.findall(r"[-+]?\d*\.\d+|\d+", args.frames)] 676 | 677 | # model 678 | model = BundleMLP(args) 679 | model.load_volume() 680 | 681 | # freeze flow at the start of training as it will otherwise fight the camera model during early image fitting 682 | # these can be omitted at the cost of learning really weird camera translations 683 | model.model_transmission.encoding_flow.requires_grad_(False) 684 | model.model_transmission.encoding_flow.train(False) 685 | model.model_transmission.network_flow.requires_grad_(False) 686 | model.model_transmission.network_flow.train(False) 687 | model.model_obstruction.encoding_flow.requires_grad_(False) 688 | model.model_obstruction.encoding_flow.train(False) 689 | 
model.model_obstruction.network_flow.requires_grad_(False) 690 | model.model_obstruction.network_flow.train(False) 691 | 692 | # dataset 693 | bundle = model.bundle 694 | train_loader = DataLoader(bundle, batch_size=1, num_workers=args.num_workers, shuffle=False, pin_memory=True, prefetch_factor=1) 695 | 696 | 697 | torch.set_float32_matmul_precision('high') 698 | 699 | # training 700 | lr_callback = pl.callbacks.LearningRateMonitor() 701 | logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), version=args.name, name="lightning_logs") 702 | validation_callback = ValidationCallback() 703 | trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), num_nodes=1, strategy="auto", max_epochs=args.max_epochs, 704 | logger=logger, callbacks=[validation_callback, lr_callback], enable_checkpointing=False, fast_dev_run=args.debug) 705 | trainer.fit(model, train_loader) 706 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import re 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | @torch.no_grad() 8 | def raw_to_rgb(bundle): 9 | """ Convert RAW mosaic into three-channel RGB volume 10 | by only in-filling empty pixels. 
11 | Returns volume of shape: (T, C, H, W) 12 | """ 13 | 14 | raw_frames = torch.tensor(np.array([bundle[f'raw_{i}']['raw'] for i in range(bundle['num_raw_frames'])]).astype(np.int32), dtype=torch.float32)[None] # C,T,H,W 15 | raw_frames = raw_frames.permute(1,0,2,3) # T,C,H,W 16 | color_correction_gains = bundle['raw_0']['android']['colorCorrection.gains'] 17 | color_correction_gains = np.array([float(el) for el in re.sub(r'[^0-9.,]', '', color_correction_gains).split(',')]) # RGGB gains 18 | color_filter_arrangement = bundle['characteristics']['color_filter_arrangement'] 19 | blacklevel = torch.tensor(np.array([bundle[f'raw_{i}']['blacklevel'] for i in range(bundle['num_raw_frames'])]))[:,:,None,None] 20 | whitelevel = torch.tensor(np.array([bundle[f'raw_{i}']['whitelevel'] for i in range(bundle['num_raw_frames'])]))[:,None,None,None] 21 | shade_maps = torch.tensor(np.array([bundle[f'raw_{i}']['shade_map'] for i in range(bundle['num_raw_frames'])])).permute(0,3,1,2) # T,C,H,W 22 | # interpolate to size of image 23 | shade_maps = F.interpolate(shade_maps, size=(raw_frames.shape[-2]//2, raw_frames.shape[-1]//2), mode='bilinear', align_corners=False) 24 | 25 | top_left = raw_frames[:,:,0::2,0::2] 26 | top_right = raw_frames[:,:,0::2,1::2] 27 | bottom_left = raw_frames[:,:,1::2,0::2] 28 | bottom_right = raw_frames[:,:,1::2,1::2] 29 | 30 | # figure out color channels 31 | if color_filter_arrangement == 0: # RGGB 32 | R, G1, G2, B = top_left, top_right, bottom_left, bottom_right 33 | elif color_filter_arrangement == 1: # GRBG 34 | G1, R, B, G2 = top_left, top_right, bottom_left, bottom_right 35 | elif color_filter_arrangement == 2: # GBRG 36 | G1, B, R, G2 = top_left, top_right, bottom_left, bottom_right 37 | elif color_filter_arrangement == 3: # BGGR 38 | B, G1, G2, R = top_left, top_right, bottom_left, bottom_right 39 | 40 | # apply color correction gains, flip to portrait 41 | R = ((R - blacklevel[:,0:1]) / (whitelevel - blacklevel[:,0:1]) * 
color_correction_gains[0]) 42 | R *= shade_maps[:,0:1] 43 | G1 = ((G1 - blacklevel[:,1:2]) / (whitelevel - blacklevel[:,1:2]) * color_correction_gains[1]) 44 | G1 *= shade_maps[:,1:2] 45 | G2 = ((G2 - blacklevel[:,2:3]) / (whitelevel - blacklevel[:,2:3]) * color_correction_gains[2]) 46 | G2 *= shade_maps[:,2:3] 47 | B = ((B - blacklevel[:,3:4]) / (whitelevel - blacklevel[:,3:4]) * color_correction_gains[3]) 48 | B *= shade_maps[:,3:4] 49 | 50 | rgb_volume = torch.zeros(raw_frames.shape[0], 3, raw_frames.shape[-2], raw_frames.shape[-1], dtype=torch.float32) 51 | 52 | # Fill gaps in blue channel 53 | rgb_volume[:, 2, 0::2, 0::2] = B.squeeze(1) 54 | rgb_volume[:, 2, 0::2, 1::2] = (B + torch.roll(B, -1, dims=3)).squeeze(1) / 2 55 | rgb_volume[:, 2, 1::2, 0::2] = (B + torch.roll(B, -1, dims=2)).squeeze(1) / 2 56 | rgb_volume[:, 2, 1::2, 1::2] = (B + torch.roll(B, -1, dims=2) + torch.roll(B, -1, dims=3) + torch.roll(B, [-1, -1], dims=[2, 3])).squeeze(1) / 4 57 | 58 | # Fill gaps in green channel 59 | rgb_volume[:, 1, 0::2, 0::2] = G1.squeeze(1) 60 | rgb_volume[:, 1, 0::2, 1::2] = (G1 + torch.roll(G1, -1, dims=3) + G2 + torch.roll(G2, 1, dims=2)).squeeze(1) / 4 61 | rgb_volume[:, 1, 1::2, 0::2] = (G1 + torch.roll(G1, -1, dims=2) + G2 + torch.roll(G2, 1, dims=3)).squeeze(1) / 4 62 | rgb_volume[:, 1, 1::2, 1::2] = G2.squeeze(1) 63 | 64 | # Fill gaps in red channel 65 | rgb_volume[:, 0, 0::2, 0::2] = R.squeeze(1) 66 | rgb_volume[:, 0, 0::2, 1::2] = (R + torch.roll(R, -1, dims=3)).squeeze(1) / 2 67 | rgb_volume[:, 0, 1::2, 0::2] = (R + torch.roll(R, -1, dims=2)).squeeze(1) / 2 68 | rgb_volume[:, 0, 1::2, 1::2] = (R + torch.roll(R, -1, dims=2) + torch.roll(R, -1, dims=3) + torch.roll(R, [-1, -1], dims=[2, 3])).squeeze(1) / 4 69 | 70 | rgb_volume = torch.flip(rgb_volume.transpose(-1,-2), [-1]) # rotate 90 degrees clockwise to portrait mode 71 | 72 | return rgb_volume 73 | 74 | def de_item(bundle): 75 | """ Call .item() on all dictionary items 76 | removes unnecessary extra 
dimension 77 | """ 78 | 79 | bundle['motion'] = bundle['motion'].item() 80 | bundle['characteristics'] = bundle['characteristics'].item() 81 | 82 | for i in range(bundle['num_raw_frames']): 83 | bundle[f'raw_{i}'] = bundle[f'raw_{i}'].item() 84 | 85 | def mask(encoding, mask_coef): 86 | mask_coef = 0.4 + 0.6*mask_coef 87 | # interpolate to size of encoding 88 | mask = torch.zeros_like(encoding[0:1]) 89 | mask_ceil = int(np.ceil(mask_coef * encoding.shape[1])) 90 | mask[:,:mask_ceil] = 1.0 91 | 92 | return encoding * mask 93 | 94 | def interpolate(signal, times): 95 | if signal.shape[-1] == 1: 96 | return signal.squeeze(-1) 97 | elif signal.shape[-1] == 2: 98 | return interpolate_linear(signal, times) 99 | else: 100 | return interpolate_cubic_hermite(signal, times) 101 | 102 | @torch.jit.script 103 | def interpolate_cubic_hermite(signal, times): 104 | # Interpolate a signal using cubic Hermite splines 105 | # signal: (B, C, T) or (B, T) 106 | # times: (B, T) 107 | 108 | if len(signal.shape) == 3: # B,C,T 109 | times = times.unsqueeze(1) 110 | times = times.repeat(1, signal.shape[1], 1) 111 | 112 | N = signal.shape[-1] 113 | 114 | times_scaled = times * (N - 1) 115 | indices = torch.floor(times_scaled).long() 116 | 117 | # Clamping to avoid out-of-bounds indices 118 | indices = torch.clamp(indices, 0, N - 2) 119 | left_indices = torch.clamp(indices - 1, 0, N - 1) 120 | right_indices = torch.clamp(indices + 1, 0, N - 1) 121 | right_right_indices = torch.clamp(indices + 2, 0, N - 1) 122 | 123 | t = (times_scaled - indices.float()) 124 | 125 | p0 = torch.gather(signal, -1, left_indices) 126 | p1 = torch.gather(signal, -1, indices) 127 | p2 = torch.gather(signal, -1, right_indices) 128 | p3 = torch.gather(signal, -1, right_right_indices) 129 | 130 | # One-sided derivatives at the boundaries 131 | m0 = torch.where(left_indices == indices, (p2 - p1), (p2 - p0) / 2) 132 | m1 = torch.where(right_right_indices == right_indices, (p2 - p1), (p3 - p1) / 2) 133 | 134 | # Hermite 
basis functions 135 | h00 = (1 + 2*t) * (1 - t)**2 136 | h10 = t * (1 - t)**2 137 | h01 = t**2 * (3 - 2*t) 138 | h11 = t**2 * (t - 1) 139 | 140 | interpolation = h00 * p1 + h10 * m0 + h01 * p2 + h11 * m1 141 | 142 | if len(signal.shape) == 3: # remove extra singleton dimension 143 | interpolation = interpolation.squeeze(-1) 144 | 145 | return interpolation 146 | 147 | 148 | @torch.jit.script 149 | def interpolate_linear(signal, times): 150 | # Interpolate a signal using linear interpolation 151 | # signal: (B, C, T) or (B, T) 152 | # times: (B, T) 153 | 154 | if len(signal.shape) == 3: # B,C,T 155 | times = times.unsqueeze(1) 156 | times = times.repeat(1, signal.shape[1], 1) 157 | 158 | # Scale times to be between 0 and N - 1 159 | times_scaled = times * (signal.shape[-1] - 1) 160 | 161 | indices = torch.floor(times_scaled).long() 162 | right_indices = (indices + 1).clamp(max=signal.shape[-1] - 1) 163 | 164 | t = (times_scaled - indices.float()) 165 | 166 | p0 = torch.gather(signal, -1, indices) 167 | p1 = torch.gather(signal, -1, right_indices) 168 | 169 | # Linear basis functions 170 | h00 = (1 - t) 171 | h01 = t 172 | 173 | interpolation = h00 * p0 + h01 * p1 174 | 175 | if len(signal.shape) == 3: # remove extra singleton dimension 176 | interpolation = interpolation.squeeze(-1) 177 | 178 | return interpolation 179 | 180 | @torch.jit.script 181 | def convert_quaternions_to_rot(quaternions): 182 | """ Convert quaternions (wxyz) to 3x3 rotation matrices. 
183 | Adapted from: https://automaticaddison.com/how-to-convert-a-quaternion-to-a-rotation-matrix 184 | """ 185 | 186 | qw, qx, qy, qz = quaternions[:,0], quaternions[:,1], quaternions[:,2], quaternions[:,3] 187 | 188 | R00 = 2 * ((qw * qw) + (qx * qx)) - 1 189 | R01 = 2 * ((qx * qy) - (qw * qz)) 190 | R02 = 2 * ((qx * qz) + (qw * qy)) 191 | 192 | R10 = 2 * ((qx * qy) + (qw * qz)) 193 | R11 = 2 * ((qw * qw) + (qy * qy)) - 1 194 | R12 = 2 * ((qy * qz) - (qw * qx)) 195 | 196 | R20 = 2 * ((qx * qz) - (qw * qy)) 197 | R21 = 2 * ((qy * qz) + (qw * qx)) 198 | R22 = 2 * ((qw * qw) + (qz * qz)) - 1 199 | 200 | R = torch.stack([R00, R01, R02, R10, R11, R12, R20, R21, R22], dim=-1) 201 | R = R.reshape(-1,3,3) 202 | 203 | return R 204 | 205 | def multi_interp(x, xp, fp): 206 | """ Simple extension of np.interp for independent 207 | linear interpolation of all axes of fp 208 | sample signal fp with timestamps xp at new timestamps x 209 | """ 210 | if torch.is_tensor(fp): 211 | out = [torch.tensor(np.interp(x, xp, fp[:,i]), dtype=fp.dtype) for i in range(fp.shape[-1])] 212 | return torch.stack(out, dim=-1) 213 | else: 214 | out = [np.interp(x, xp, fp[:,i]) for i in range(fp.shape[-1])] 215 | return np.stack(out, axis=-1) 216 | 217 | def parse_ccm(s): 218 | ccm = torch.tensor([eval(x.group()) for x in re.finditer(r"[-+]?\d+/\d+|[-+]?\d+\.\d+|[-+]?\d+", s)]) 219 | ccm = ccm.reshape(3,3) 220 | return ccm 221 | 222 | def parse_tonemap_curve(data_string): 223 | channels = re.findall(r'(R|G|B):\[(.*?)\]', data_string) 224 | result_array = np.zeros((3, len(channels[0][1].split('),')), 2)) 225 | 226 | for i, (_, channel_data) in enumerate(channels): 227 | pairs = channel_data.split('),') 228 | for j, pair in enumerate(pairs): 229 | x, y = map(float, re.findall(r'([\d\.]+)', pair)) 230 | result_array[i, j] = (x, y) 231 | return result_array 232 | 233 | def apply_tonemap_curve(image, tonemap): 234 | # apply tonemap curve to each color channel 235 | image_toned = 
image.clone().cpu().numpy() 236 | 237 | for i in range(3): 238 | x_vals, y_vals = tonemap[i][:, 0], tonemap[i][:, 1] 239 | image_toned[i] = np.interp(image_toned[i], x_vals, y_vals) 240 | 241 | # Convert back to PyTorch tensor 242 | image_toned = torch.tensor(image_toned, dtype=torch.float32) 243 | 244 | return image_toned 245 | 246 | def debatch(batch): 247 | """ Collapse batch and channel dimension together 248 | """ 249 | debatched = [] 250 | 251 | for x in batch: 252 | if len(x.shape) <=1: 253 | raise Exception("This tensor is to small to debatch.") 254 | elif len(x.shape) == 2: 255 | debatched.append(x.reshape(x.shape[0] * x.shape[1])) 256 | else: 257 | debatched.append(x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])) 258 | 259 | return debatched 260 | 261 | def colorize_tensor(value, vmin=None, vmax=None, cmap=None, colorbar=False, height=9.6, width=7.2): 262 | """ Convert tensor to 3 channel RGB array according to colors from cmap 263 | similar usage as plt.imshow 264 | """ 265 | assert len(value.shape) == 2 # H x W 266 | 267 | fig, ax = plt.subplots(1,1) 268 | fig.set_size_inches(width,height) 269 | a = ax.imshow(value.detach().cpu(), vmin=vmin, vmax=vmax, cmap=cmap) 270 | ax.set_axis_off() 271 | if colorbar: 272 | cbar = plt.colorbar(a, fraction=0.05) 273 | cbar.ax.tick_params(labelsize=30) 274 | plt.tight_layout() 275 | plt.close() 276 | 277 | # Draw figure on canvas 278 | fig.canvas.draw() 279 | 280 | # Convert the figure to numpy array, read the pixel values and reshape the array 281 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 282 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 283 | 284 | # Normalize into 0-1 range for TensorBoard(X). Swap axes for newer versions where API expects colors in first dim 285 | img = img / 255.0 286 | 287 | return torch.tensor(img).permute(2,0,1).float() 288 | --------------------------------------------------------------------------------