├── .figs
│   ├── extras.png
│   ├── occ-main.png
│   ├── occ-supp.png
│   ├── occ-wild.png
│   ├── pani-badge.svg
│   ├── ref-main.png
│   ├── ref-supp.png
│   └── ref-wild.png
├── LICENSE
├── README.md
├── checkpoints
│   └── __init__.py
├── config
│   ├── config_large.json
│   ├── config_medium.json
│   ├── config_small.json
│   └── config_tiny.json
├── data
│   └── __init__.py
├── lightning_logs
│   └── __init__.py
├── outputs
│   └── __init__.py
├── requirements.txt
├── scripts
│   ├── dehaze.sh
│   ├── fusion.sh
│   ├── occlusion-wild.sh
│   ├── occlusion.sh
│   ├── reflection-wild.sh
│   ├── reflection.sh
│   ├── segmentation.sh
│   └── shadow.sh
├── train.py
├── tutorial.ipynb
└── utils
    └── utils.py
/.figs/extras.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/extras.png
--------------------------------------------------------------------------------
/.figs/occ-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-main.png
--------------------------------------------------------------------------------
/.figs/occ-supp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-supp.png
--------------------------------------------------------------------------------
/.figs/occ-wild.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/occ-wild.png
--------------------------------------------------------------------------------
/.figs/pani-badge.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.figs/ref-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-main.png
--------------------------------------------------------------------------------
/.figs/ref-supp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-supp.png
--------------------------------------------------------------------------------
/.figs/ref-wild.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/.figs/ref-wild.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ilya Chugunov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Neural Spline Fields for Burst Image Fusion and Layer Separation
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | This is the official code repository for the CVPR 2024 work: [Neural Spline Fields for Burst Image Fusion and Layer Separation](https://light.princeton.edu/publication/nsf/). If you use parts of this work, or otherwise take inspiration from it, please consider citing our paper:
11 | ```
12 | @inproceedings{chugunov2024neural,
13 | title={Neural spline fields for burst image fusion and layer separation},
14 | author={Chugunov, Ilya and Shustin, David and Yan, Ruyu and Lei, Chenyang and Heide, Felix},
15 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
16 | pages={25763--25773},
17 | year={2024}
18 | }
19 | ```
20 |
21 | ## Requirements:
22 | - Code was written in PyTorch 2.0 on an Ubuntu 22.04 machine.
23 | - Condensed package requirements are in `requirements.txt`. Note that this contains the exact package versions at the time of publishing. Code will most likely work with newer versions of the libraries, but you will need to watch out for changes in class/function calls.
24 | - The non-standard packages you may need are `pytorch_lightning`, `commentjson`, `rawpy`, and `tinycudann`. See [NVlabs/tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) for installation instructions. Depending on your system you might just be able to do `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`, or might have to cmake and build it from source. A quick import check is sketched below.
25 |
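Once these are installed, a quick sanity check (a minimal sketch; it only assumes the packages listed above and a CUDA-capable GPU, which `tinycudann` requires):

```python
# verify the non-standard dependencies import and that a GPU is visible;
# tiny-cuda-nn is CUDA-only, so torch.cuda.is_available() should be True
import torch
import commentjson
import rawpy
import pytorch_lightning as pl
import tinycudann as tcnn

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("pytorch_lightning:", pl.__version__)
```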
26 | ## Project Structure:
27 | ```cpp
28 | NSF
29 | ├── checkpoints
30 | │ └── // folder for network checkpoints
31 | ├── config
32 | │ └── // network and encoding configurations for different sizes of MLPs
33 | ├── data
34 | │ └── // folder for long-burst data
35 | ├── lightning_logs
36 | │ └── // folder for tensorboard logs
37 | ├── outputs
38 | │ └── // folder for model outputs (e.g., final reconstructions)
39 | ├── scripts
40 | │ └── // training scripts for different tasks (e.g., occlusion/reflection/shadow separation)
41 | ├── utils
42 | │ └── utils.py // network helper functions (e.g., RAW demosaicing, spline interpolation)
43 | ├── LICENSE // legal stuff
44 | ├── README.md // <- you are here
45 | ├── requirements.txt // frozen package requirements
46 | ├── train.py // dataloader, network, visualization, and trainer code
47 | └── tutorial.ipynb // interactive tutorial for training the model
48 | ```
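The JSON files in `config/` are the `tinycudann` encoding/network configurations used by the plane models; a minimal sketch of how one is consumed (mirroring the model constructors in `train.py`):

```python
import commentjson as json
import tinycudann as tcnn

with open("config/config_tiny.json") as f:
    config = json.load(f)

# multi-resolution hash-grid encoding over 2D plane coordinates (u, v)
encoding = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"])
# fully-fused MLP mapping encoded features to, e.g., 3 RGB channels
network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=3,
                       network_config=config["network"])
```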
49 | ## Getting Started:
50 | We highly recommend you start by going through `tutorial.ipynb`, either on your own machine or [with this Google Colab link](https://colab.research.google.com/github/princeton-computational-imaging/NSF/blob/main/tutorial.ipynb).
51 |
52 | TLDR: models can be trained with:
53 |
54 | `bash scripts/{application}.sh --bundle_path {path_to_data} --name {checkpoint_name}`
55 |
56 | And reconstruction outputs will be saved to `outputs/{checkpoint_name}-final`.
57 |
58 | For a full list of training arguments, we recommend looking through the argument parser section at the bottom of `train.py`.
59 |
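To inspect the saved reconstructions programmatically, here is a minimal sketch (assuming a hypothetical run named `example` and the filenames written by `save_outputs` in `train.py`):

```python
import numpy as np
import matplotlib.pyplot as plt

out_dir = "outputs/example-final"  # hypothetical checkpoint name "example"

# layer reconstructions are saved as PNGs
transmission = plt.imread(f"{out_dir}/transmission.png")
obstruction = plt.imread(f"{out_dir}/obstruction.png")
alpha = plt.imread(f"{out_dir}/alpha.png")

# neural spline field flows are saved as H x W x 2 numpy arrays in the final high-res pass
flow_transmission = np.load(f"{out_dir}/flow_transmission.npy")
print(transmission.shape, alpha.shape, flow_transmission.shape)
```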
60 | ## Data:
61 | You can download the long-burst data used in the paper (and extra bonus scenes) via the following links:
62 |
63 | 1. Main occlusion scenes: [occlusion-main.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-main.zip) (use `scripts/occlusion.sh` to train)
64 | 
65 |
66 | 2. Supplementary occlusion scenes: [occlusion-supp.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-supp.zip) (use `scripts/occlusion.sh` to train)
67 | 
68 |
69 | 3. In-the-wild occlusion scenes: [occlusion-wild.zip](https://soap.cs.princeton.edu/nsf/data/occlusion-wild.zip) (use `scripts/occlusion-wild.sh` to train)
70 | 
71 |
72 | 4. Main reflection scenes: [reflection-main.zip](https://soap.cs.princeton.edu/nsf/data/reflection-main.zip) (use `scripts/reflection.sh` to train)
73 | 
74 |
75 | 5. Supplementary reflection scenes: [reflection-supp.zip](https://soap.cs.princeton.edu/nsf/data/reflection-supp.zip) (use `scripts/reflection.sh` to train)
76 | 
77 |
78 | 6. In-the-wild reflection scenes: [reflection-wild.zip](https://soap.cs.princeton.edu/nsf/data/reflection-wild.zip) (use `scripts/reflection-wild.sh` to train)
79 | 
80 |
81 | 7. Extra scenes: [extras.zip](https://soap.cs.princeton.edu/nsf/data/extras.zip) (use `scripts/dehaze.sh`, `segmentation.sh`, or `shadow.sh`)
82 | 
83 |
84 | 8. Synthetic validation: [synthetic-validation.zip](https://soap.cs.princeton.edu/nsf/data/synthetic-validation.zip) (use `scripts/reflection.sh` or `occlusion.sh` with flag `--rgb_data`)
85 |
86 | We recommend you download and extract these into the `data/` folder.
87 |
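For example, to fetch and unpack one of the bundles from Python (a minimal sketch using the `occlusion-main.zip` link above; `wget` and `unzip` into `data/` work just as well):

```python
import os
import urllib.request
import zipfile

url = "https://soap.cs.princeton.edu/nsf/data/occlusion-main.zip"
zip_path = "data/occlusion-main.zip"
os.makedirs("data", exist_ok=True)

# download the archive (the long-burst bundles are large, so this can take a while)
urllib.request.urlretrieve(url, zip_path)

# extract the scene folders into data/
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("data")
```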
88 | ## App:
89 | Want to record your own long-burst data? Check out our Android RAW capture app [Pani!](https://github.com/Ilya-Muromets/Pani)
90 |
91 | Good luck have fun,
92 | Ilya
93 |
--------------------------------------------------------------------------------
/checkpoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/checkpoints/__init__.py
--------------------------------------------------------------------------------
/config/config_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "HashGrid",
4 | "n_levels": 16,
5 | "n_features_per_level": 4,
6 | "log2_hashmap_size": 17,
7 | "base_resolution": 4,
8 | "per_level_scale": 1.61,
9 | "interpolation": "Linear"
10 | },
11 | "network": {
12 | "otype": "FullyFusedMLP",
13 | "activation": "ReLU",
14 | "output_activation": "None",
15 | "n_neurons": 64,
16 | "n_hidden_layers": 5
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/config/config_medium.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "HashGrid",
4 | "n_levels": 12,
5 | "n_features_per_level": 4,
6 | "log2_hashmap_size": 15,
7 | "base_resolution": 4,
8 | "per_level_scale": 1.61,
9 | "interpolation": "Linear"
10 | },
11 | "network": {
12 | "otype": "FullyFusedMLP",
13 | "activation": "ReLU",
14 | "output_activation": "None",
15 | "n_neurons": 64,
16 | "n_hidden_layers": 4
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/config/config_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "HashGrid",
4 | "n_levels": 8,
5 | "n_features_per_level": 4,
6 | "log2_hashmap_size": 13,
7 | "base_resolution": 4,
8 | "per_level_scale": 1.61,
9 | "interpolation": "Linear"
10 | },
11 | "network": {
12 | "otype": "FullyFusedMLP",
13 | "activation": "ReLU",
14 | "output_activation": "None",
15 | "n_neurons": 64,
16 | "n_hidden_layers": 3
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/config/config_tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding": {
3 | "otype": "HashGrid",
4 | "n_levels": 6,
5 | "n_features_per_level": 4,
6 | "log2_hashmap_size": 12,
7 | "base_resolution": 4,
8 | "per_level_scale": 1.61,
9 | "interpolation": "Linear"
10 | },
11 | "network": {
12 | "otype": "FullyFusedMLP",
13 | "activation": "ReLU",
14 | "output_activation": "None",
15 | "n_neurons": 64,
16 | "n_hidden_layers": 3
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/data/__init__.py
--------------------------------------------------------------------------------
/lightning_logs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/lightning_logs/__init__.py
--------------------------------------------------------------------------------
/outputs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/NSF/168c4eac651f66fe4db105aa8eb1cb6dac5cf4e8/outputs/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | commentjson>=0.9.0
2 | matplotlib>=3.7.0
3 | natsort>=8.3.1
4 | numpy>=1.24.3
5 | opencv_python>=4.5.4.58
6 | pytorch_lightning>=2.0.1
7 | rawpy>=0.18.1
8 | torch>=2.0.0+cu118
9 | tqdm>=4.64.1
10 |
--------------------------------------------------------------------------------
/scripts/dehaze.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=tiny \
4 | --obstruction_alpha_grid_size=small \
5 | --obstruction_initial_alpha=0.5 \
6 | --obstruction_initial_depth=0.5 \
7 | --transmission_flow_grid_size=tiny \
8 | --transmission_image_grid_size=large \
9 | --transmission_initial_depth=1.0 \
10 | --alpha_weight=1e-2 \
11 | --alpha_temperature=4.0 \
12 | --translation_weight=1e-1 \
13 | "$@"
--------------------------------------------------------------------------------
/scripts/fusion.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --transmission_flow_grid_size=tiny \
3 | --transmission_image_grid_size=large \
4 | --transmission_initial_depth=0.5 \
5 | --transmission_control_points_flow=31 \
6 | --single_plane \
7 | --camera_control_points=31 \
8 | --alpha_weight=0.0 \
9 | "$@"
--------------------------------------------------------------------------------
/scripts/occlusion-wild.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=medium \
4 | --obstruction_alpha_grid_size=medium \
5 | --obstruction_initial_depth=0.5 \
6 | --transmission_flow_grid_size=tiny \
7 | --transmission_image_grid_size=large \
8 | --transmission_initial_depth=1.0 \
9 | --alpha_weight=1e-2 \
10 | --alpha_temperature=0.1 \
11 | --lr=3e-5 \
12 | "$@"
13 |
--------------------------------------------------------------------------------
/scripts/occlusion.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=medium \
4 | --obstruction_alpha_grid_size=medium \
5 | --obstruction_initial_depth=0.5 \
6 | --transmission_flow_grid_size=tiny \
7 | --transmission_image_grid_size=large \
8 | --transmission_initial_depth=1.0 \
9 | --alpha_weight=2e-2 \
10 | --alpha_temperature=3.0 \
11 | --lr=3e-5 \
12 | "$@"
13 |
--------------------------------------------------------------------------------
/scripts/reflection-wild.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=large \
4 | --obstruction_alpha_grid_size=tiny \
5 | --obstruction_initial_depth=2.0 \
6 | --transmission_flow_grid_size=tiny \
7 | --transmission_image_grid_size=large \
8 | --transmission_initial_depth=1.0 \
9 | --alpha_weight=5e-3 \
10 | --alpha_temperature=0.1 \
11 | --lr=2e-4 \
12 | --translation_weight=1e-1 \
13 | "$@"
14 |
--------------------------------------------------------------------------------
/scripts/reflection.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=large \
4 | --obstruction_alpha_grid_size=tiny \
5 | --obstruction_initial_depth=2.0 \
6 | --transmission_flow_grid_size=tiny \
7 | --transmission_image_grid_size=large \
8 | --transmission_initial_depth=1.0 \
9 | --alpha_weight=0.0 \
10 | --alpha_temperature=0.2 \
11 | --lr=2e-4 \
12 | --translation_weight=1e-1 \
13 | "$@"
14 |
--------------------------------------------------------------------------------
/scripts/segmentation.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=small \
3 | --obstruction_image_grid_size=large \
4 | --obstruction_alpha_grid_size=medium \
5 | --obstruction_control_points_flow=15 \
6 | --obstruction_initial_depth=0.5 \
7 | --obstruction_initial_alpha=0.5 \
8 | --transmission_flow_grid_size=small \
9 | --transmission_image_grid_size=large \
10 | --transmission_initial_depth=1.0 \
11 | --transmission_control_points_flow=15 \
12 | --camera_control_points=15 \
13 | --alpha_weight=2e-3 \
14 | --alpha_temperature=0.15 \
15 | --translation_weight=1e0 \
16 | "$@"
--------------------------------------------------------------------------------
/scripts/shadow.sh:
--------------------------------------------------------------------------------
1 | python3 train.py \
2 | --obstruction_flow_grid_size=tiny \
3 | --obstruction_image_grid_size=tiny \
4 | --obstruction_alpha_grid_size=medium \
5 | --obstruction_initial_alpha=0.5 \
6 | --obstruction_initial_depth=10.0 \
7 | --transmission_flow_grid_size=tiny \
8 | --transmission_image_grid_size=large \
9 | --transmission_initial_depth=0.5 \
10 | --alpha_weight=0.0 \
11 | --alpha_temperature=4.0 \
12 | --translation_weight=1e-1 \
13 | "$@"
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import commentjson as json
3 | import numpy as np
4 | import os
5 | import re
6 | import pickle
7 |
8 | import tinycudann as tcnn
9 |
10 | from utils import utils
11 | from utils.utils import debatch
12 | import matplotlib.pyplot as plt
13 |
14 | import torch
15 | from torch.nn import functional as F
16 | from torch.utils.data import Dataset
17 | from torch.utils.data import DataLoader
18 | import pytorch_lightning as pl
19 |
20 | #########################################################################################################
21 | ################################################ DATASET ################################################
22 | #########################################################################################################
23 |
24 | class BundleDataset(Dataset):
25 | def __init__(self, args, load_volume=False):
26 | self.args = args
27 | print("Loading from:", self.args.bundle_path)
28 |
29 | if self.args.rgb_data:
30 | self.init_RGB(self.args, load_volume)
31 | else:
32 | self.init_RAW(self.args, load_volume)
33 |
34 | def init_RGB(self, args, load_volume=False):
35 |
36 | self.lens_distortion = torch.tensor([0.0,0.0,0.0,0.0,0.0]).float()
37 | self.ccm = torch.tensor(np.eye(3)).float()
38 | self.tonemap_curve = torch.tensor(np.linspace(0,1,65)[None,:,None]).float().repeat(3,1,2) # identity tonemap curve
39 |
40 | bundle = dict(np.load(args.bundle_path, allow_pickle=True))
41 |
42 | if "reference_img" in bundle.keys():
43 | self.reference_img = torch.tensor(bundle["reference_img"]).float()
44 | if "translations" in bundle.keys():
45 | self.translations = torch.tensor(bundle["translations"]).float() # T,3
46 |
47 | # not efficient, load twice, but we're missing metadata for img sizes
48 | self.rgb_volume = torch.tensor(bundle["rgb_volume"]).float()[:,:3] # T,C,H,W
49 |
50 | self.intrinsics = torch.tensor(bundle["intrinsics"]).float() # T,3,3
51 | self.intrinsics = self.intrinsics.transpose(1, 2)
52 | self.intrinsics_inv = torch.inverse(self.intrinsics)
53 | self.rotations = torch.tensor(bundle["rotations"]).float() # T,3,3
54 | self.reference_rotation = self.rotations[0]
55 | self.camera_to_world = self.reference_rotation.T @ self.rotations
56 |
57 | self.num_frames = self.rgb_volume.shape[0]
58 | self.img_channels = self.rgb_volume.shape[1]
59 | self.img_height = self.rgb_volume.shape[2]
60 | self.img_width = self.rgb_volume.shape[3]
61 |
62 | if args.frames is not None:
63 | # subsample frames
64 | self.num_frames = len(args.frames)
65 | self.rotations = self.rotations[args.frames]
66 | self.camera_to_world = self.camera_to_world[args.frames]
67 | self.intrinsics = self.intrinsics[args.frames]
68 | self.intrinsics_inv = self.intrinsics_inv[args.frames]
69 |
70 | self.load_volume()
71 |
72 | self.frame_batch_size = 2 * (self.args.point_batch_size // self.num_frames // 2) # nearest even cut
73 | self.point_batch_size = self.frame_batch_size * self.num_frames # nearest multiple of num_frames
74 | self.num_batches = self.args.num_batches
75 |
76 | self.sin_epoch = 0.0 # fraction of training complete
77 | self.frame_cutoff = self.num_frames
78 | print("Frame Count: ", self.num_frames)
79 |
80 |
81 | def init_RAW(self, args, load_volume=False):
82 | bundle = np.load(args.bundle_path, allow_pickle=True)
83 |
84 | self.characteristics = bundle['characteristics'].item() # camera characteristics
85 | self.motion = bundle['motion'].item()
86 | self.frame_timestamps = torch.tensor([bundle[f'raw_{i}'].item()['timestamp'] for i in range(bundle['num_raw_frames'])])
87 | self.motion_timestamps = torch.tensor(self.motion['timestamp'])
88 |
89 | self.quaternions = torch.tensor(self.motion['quaternion']).float() # T',4, has different timestamps from frames
90 | # our scene is +z towards scene convention, but phone is +z towards face convention
91 | # so we need to rotate 180 degrees around y axis, or equivalently flip over z,y
92 | self.quaternions[:,2] = -self.quaternions[:,2] # invert y
93 | self.quaternions[:,3] = -self.quaternions[:,3] # invert z
94 |
95 | self.quaternions = utils.multi_interp(self.frame_timestamps, self.motion_timestamps, self.quaternions)
96 | self.rotations = utils.convert_quaternions_to_rot(self.quaternions)
97 |
98 | self.reference_quaternion = self.quaternions[0]
99 | self.reference_rotation = self.rotations[0]
100 |
101 | self.camera_to_world = self.reference_rotation.T @ self.rotations
102 |
103 | self.intrinsics = torch.tensor(np.array([bundle[f'raw_{i}'].item()['intrinsics'] for i in range(bundle['num_raw_frames'])])).float() # T,3,3
104 | # swap cx,cy -> landscape to portrait
105 | cx, cy = self.intrinsics[:, 2, 1].clone(), self.intrinsics[:, 2, 0].clone()
106 | self.intrinsics[:, 2, 0], self.intrinsics[:, 2, 1] = cx, cy
107 | # transpose to put cx,cy in right column
108 | self.intrinsics = self.intrinsics.transpose(1, 2)
109 | self.intrinsics_inv = torch.inverse(self.intrinsics)
110 |
111 | self.lens_distortion = bundle['raw_0'].item()['lens_distortion']
112 | self.tonemap_curve = torch.tensor(bundle['raw_0'].item()['tonemap_curve'])
113 | self.ccm = utils.parse_ccm(bundle['raw_0'].item()['android']['colorCorrection.transform'])
114 |
115 | self.num_frames = bundle['num_raw_frames'].item()
116 | self.img_channels = 3
117 | self.img_height = bundle['raw_0'].item()['width'] # rotated 90
118 | self.img_width = bundle['raw_0'].item()['height']
119 | self.rgb_volume = torch.ones([self.num_frames, self.img_channels, 3,3]).float() # T,C,3,3, tiny fake volume for lazy loading
120 |
121 | if args.frames is not None:
122 | # subsample frames
123 | self.num_frames = len(args.frames)
124 | self.frame_timestamps = self.frame_timestamps[args.frames]
125 | self.quaternions = self.quaternions[args.frames]
126 | self.rotations = self.rotations[args.frames]
127 | self.camera_to_world = self.camera_to_world[args.frames]
128 | self.intrinsics = self.intrinsics[args.frames]
129 | self.intrinsics_inv = self.intrinsics_inv[args.frames]
130 |
131 | if load_volume:
132 | self.load_volume()
133 |
134 | self.frame_batch_size = 2 * (self.args.point_batch_size // self.num_frames // 2) # nearest even cut
135 | self.point_batch_size = self.frame_batch_size * self.num_frames # nearest multiple of num_frames
136 | self.num_batches = self.args.num_batches
137 |
138 | self.sin_epoch = 0.0 # fraction of training complete
139 | self.frame_cutoff = self.num_frames
140 | print("Frame Count: ", self.num_frames)
141 |
142 | def load_volume(self):
143 | if self.args.rgb_data:
144 | bundle = dict(np.load(self.args.bundle_path, allow_pickle=True))
145 | self.rgb_volume = torch.tensor(bundle["rgb_volume"]).float()[:,:3]
146 | else: # need to unpack RAW data
147 | bundle = dict(np.load(self.args.bundle_path, allow_pickle=True))
148 | utils.de_item(bundle)
149 |
150 | self.rgb_volume = (utils.raw_to_rgb(bundle)) # T,C,H,W
151 |
152 | if self.args.max_percentile < 100: # cut off highlights for scaling (long-tail-distribution)
153 | self.rgb_volume = self.rgb_volume/np.percentile(self.rgb_volume, self.args.max_percentile)
154 |
155 | self.rgb_volume = self.rgb_volume.clamp(0,1)
156 |
157 | if self.args.frames is not None:
158 | self.rgb_volume = self.rgb_volume[self.args.frames] # subsample frames
159 |
160 |
161 | def __len__(self):
162 | return self.num_batches # arbitrary as we continuously generate random samples
163 |
164 | def __getitem__(self, idx):
165 | if self.args.frame_cutoff:
166 | self.frame_cutoff = min(int((0.1 + 2*self.sin_epoch) * self.num_frames), self.num_frames) # gradually increase frame cutoff
167 | else:
168 | self.frame_cutoff = self.num_frames
169 |
170 | uv = torch.rand((self.frame_batch_size * self.frame_cutoff), 2)*0.98 + 0.01 # uniform random in [0.01,0.99]
171 |
172 | # t is time for all frames, looks like [0, 0,... 0, 1/N, 1/N, ..., 1/N, 2/N, 2/N, ..., 2/N, etc.]
173 | t = (torch.linspace(0,1,self.num_frames)[:self.frame_cutoff]).repeat_interleave(self.frame_batch_size)[:,None] # point_batch_size, 1
174 |
175 | return self.generate_samples(t, uv)
176 |
177 | def generate_samples(self, t, uv):
178 | """ generate samples from dataset and camera parameters for training
179 | """
180 |
181 |         # create frame_batch_size of quaternions for each frame
182 | camera_to_world = (self.camera_to_world[:self.frame_cutoff]).repeat_interleave(self.frame_batch_size, dim=0)
183 | # create frame_batch_size of intrinsics for each frame
184 | intrinsics = (self.intrinsics[:self.frame_cutoff]).repeat_interleave(self.frame_batch_size, dim=0)
185 | intrinsics_inv = (self.intrinsics_inv[:self.frame_cutoff]).repeat_interleave(self.frame_batch_size, dim=0)
186 |
187 | # sample grid
188 | grid_uv = ((uv - 0.5) * 2).reshape(self.frame_cutoff,self.frame_batch_size,1,2)
189 | rgb_samples = F.grid_sample(self.rgb_volume[:self.frame_cutoff], grid_uv, mode="bilinear", padding_mode="border", align_corners=True)
190 | # samples get returned in shape: num_frames x channels x frame_batch_size x 1 for some reason
191 | rgb_samples = rgb_samples.permute(0,2,1,3).squeeze().flatten(0,1) # point_batch_size x channels
192 |
193 | return t, uv, camera_to_world, intrinsics, intrinsics_inv, rgb_samples
194 |
195 | def sample_frame(self, uv, frame):
196 | """ sample frame [frame] at coordinates u,v
197 | """
198 |
199 | grid_uv = ((uv - 0.5) * 2)[None,:,None,:] # 1,point_batch_size,1,2
200 | rgb_samples = F.grid_sample(self.rgb_volume[frame:frame+1], grid_uv, mode="bilinear", padding_mode="border", align_corners=True)
201 | rgb_samples = rgb_samples.squeeze().permute(1,0) # point_batch_size, C
202 |
203 | return rgb_samples
204 |
205 | #########################################################################################################
206 | ################################################ MODELS #################################################
207 | #########################################################################################################
208 |
209 | class RotationModel(pl.LightningModule):
210 | def __init__(self, args):
211 | super().__init__()
212 | self.args = args
213 | self.stabilize = False
214 | self.delta_control_points = torch.nn.Parameter(data=torch.zeros(1, 3, self.args.camera_control_points, dtype=torch.float32), requires_grad=True)
215 |
216 | def forward(self, camera_to_world, t):
217 | delta_control_points = self.delta_control_points.repeat(t.shape[0],1,1)
218 | rotation_deltas = utils.interpolate(delta_control_points, t)
219 | rx, ry, rz = rotation_deltas[:,0], rotation_deltas[:,1], rotation_deltas[:,2]
220 | r0 = torch.zeros_like(rx)
221 |
222 | rotation_offsets = torch.stack([torch.stack([ r0, -rz, ry], dim=-1),
223 | torch.stack([ rz, r0, -rx], dim=-1),
224 | torch.stack([-ry, rx, r0], dim=-1)], dim=-1)
225 |
226 | return camera_to_world + self.args.rotation_weight * rotation_offsets
227 |
228 | class TranslationModel(pl.LightningModule):
229 | def __init__(self, args):
230 | super().__init__()
231 | self.args = args
232 | self.stabilize = False
233 | self.delta_control_points = torch.nn.Parameter(data=torch.zeros(1, 3, self.args.camera_control_points, dtype=torch.float32), requires_grad=True)
234 |
235 | def forward(self, t):
236 | control_points = self.args.translation_weight * self.delta_control_points.repeat(t.shape[0],1,1)
237 | translation = utils.interpolate(control_points, t)
238 |
239 | return translation
240 |
241 | class PlaneModel(pl.LightningModule):
242 | """ Plane reprojection model with learnable z-depth
243 | """
244 | def __init__(self, args, depth):
245 | super().__init__()
246 |
247 | self.args = args
248 | self.depth = torch.nn.Parameter(data=torch.tensor([depth/1.0], dtype=torch.float32), requires_grad=True)
249 |
250 | self.u0 = torch.nn.Parameter(data=torch.tensor([1.0, 0.0], dtype=torch.float32), requires_grad=False)
251 | self.v0 = torch.nn.Parameter(data=torch.tensor([0.0, 1.0], dtype=torch.float32), requires_grad=False)
252 |
253 | def forward(self, ray_origins, ray_directions):
254 | # termination is just plane depth - ray origin z
255 | termination = ((1.0 * self.depth) - ray_origins[:, 2]).unsqueeze(1)
256 |
257 | # compute intersection points (N x 3)
258 | intersection_points = ray_origins + (termination * ray_directions)
259 |
260 | # project to (u, v) coordinates (N x 1 for each), avoid zero div
261 | u = 0.5 + 0.4 * torch.sum(intersection_points[:, :2] * (self.u0 / (torch.abs(termination) + 1e-6)), dim=1)
262 | v = 0.5 + 0.4 * torch.sum(intersection_points[:, :2] * (self.v0 / (torch.abs(termination) + 1e-6)), dim=1)
263 | uv = torch.stack((u, v), dim=1)
264 |
265 | return uv.clamp(0, 1) # ensure UV coordinates stay within neural field bounds
266 |
267 | class PlaneTransmissionModel(pl.LightningModule):
268 |
269 | def __init__(self, args):
270 | super().__init__()
271 | with open(f"config/config_{args.transmission_image_grid_size}.json") as config_image:
272 | config_image = json.load(config_image)
273 | with open(f"config/config_{args.transmission_flow_grid_size}.json" ) as config_flow:
274 | config_flow = json.load(config_flow)
275 |
276 | self.args = args
277 |
278 | self.encoding_image = tcnn.Encoding(n_input_dims=2, encoding_config=config_image["encoding"])
279 | self.encoding_flow = tcnn.Encoding(n_input_dims=2, encoding_config=config_flow["encoding"])
280 |
281 | self.network_image = tcnn.Network(n_input_dims=self.encoding_image.n_output_dims, n_output_dims=3, network_config=config_image["network"])
282 | self.network_flow = tcnn.Network(n_input_dims=self.encoding_flow.n_output_dims,
283 | n_output_dims=2*(self.args.transmission_control_points_flow), network_config=config_flow["network"])
284 |
285 | self.model_plane = PlaneModel(args, args.transmission_initial_depth)
286 | self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float32), requires_grad=True)
287 |
288 | def forward(self, t, ray_origins, ray_directions, sin_epoch):
289 | uv_plane = self.model_plane(ray_origins, ray_directions)
290 |
291 |
292 | flow = self.network_flow(utils.mask(self.encoding_flow(uv_plane), sin_epoch)) # B x 2
293 |
294 | flow = flow.reshape(-1,2,self.args.transmission_control_points_flow)
295 | flow = 0.01 * utils.interpolate(flow, t)
296 |
297 | rgb = self.network_image(utils.mask(self.encoding_image(uv_plane + flow), sin_epoch)).float()
298 | rgb = (self.initial_rgb + rgb).clamp(0,1)
299 |
300 | return rgb, flow
301 |
302 | class PlaneObstructionModel(pl.LightningModule):
303 |
304 | def __init__(self, args):
305 | super().__init__()
306 | with open(f"config/config_{args.obstruction_image_grid_size}.json") as config_image:
307 | config_image = json.load(config_image)
308 | with open(f"config/config_{args.obstruction_alpha_grid_size}.json" ) as config_alpha:
309 | config_alpha = json.load(config_alpha)
310 | with open(f"config/config_{args.obstruction_flow_grid_size}.json" ) as config_flow:
311 | config_flow = json.load(config_flow)
312 |
313 | self.args = args
314 |
315 | self.encoding_image = tcnn.Encoding(n_input_dims=2, encoding_config=config_image["encoding"])
316 | self.encoding_alpha = tcnn.Encoding(n_input_dims=2, encoding_config=config_alpha["encoding"])
317 | self.encoding_flow = tcnn.Encoding(n_input_dims=2, encoding_config=config_flow["encoding"])
318 |
319 | self.network_image = tcnn.Network(n_input_dims=self.encoding_image.n_output_dims, n_output_dims=3, network_config=config_image["network"])
320 | self.network_alpha = tcnn.Network(n_input_dims=self.encoding_alpha.n_output_dims, n_output_dims=1, network_config=config_alpha["network"])
321 | self.network_flow = tcnn.Network(n_input_dims=self.encoding_flow.n_output_dims, n_output_dims=2*(self.args.obstruction_control_points_flow), network_config=config_flow["network"])
322 |
323 | self.model_plane = PlaneModel(args, args.obstruction_initial_depth)
324 | self.initial_alpha = torch.nn.Parameter(data=torch.tensor(args.obstruction_initial_alpha, dtype=torch.float32), requires_grad=True)
325 | self.initial_rgb = torch.nn.Parameter(data=torch.zeros([1,3], dtype=torch.float32), requires_grad=True)
326 |
327 | def forward(self, t, ray_origins, ray_directions, sin_epoch):
328 | uv_plane = self.model_plane(ray_origins, ray_directions)
329 |
330 | flow = self.network_flow(utils.mask(self.encoding_flow(uv_plane), sin_epoch)) # B x 2
331 | flow = flow.reshape(-1,2,self.args.obstruction_control_points_flow)
332 | flow = 0.01 * utils.interpolate(flow, t)
333 |
334 | rgb = self.network_image(utils.mask(self.encoding_image(uv_plane + flow), sin_epoch)).float()
335 | rgb = (self.initial_rgb + rgb).clamp(0,1)
336 |
337 | alpha = self.network_alpha(utils.mask(self.encoding_alpha(uv_plane + flow), sin_epoch)).float()
338 | alpha = torch.sigmoid((-torch.log(1/self.initial_alpha - 1) + self.args.alpha_temperature * alpha))
339 |
340 | return rgb, flow, alpha
341 |
342 | #########################################################################################################
343 | ################################################ NETWORK ################################################
344 | #########################################################################################################
345 |
346 | class BundleMLP(pl.LightningModule):
347 | def __init__(self, args, cached_bundle=None):
348 | super().__init__()
349 | # load network configs
350 |
351 | self.args = args
352 | if cached_bundle is None:
353 | self.bundle = BundleDataset(self.args)
354 | else:
355 | with open(cached_bundle, 'rb') as file:
356 | self.bundle = pickle.load(file)
357 |
358 | self.img_width = self.bundle.img_width
359 | self.img_height = self.bundle.img_height
360 | self.lens_distortion = self.bundle.lens_distortion
361 | self.num_frames = args.num_frames = self.bundle.num_frames
362 | if args.frames is None:
363 | self.args.frames = list(range(self.num_frames))
364 |
365 | self.model_transmission = PlaneTransmissionModel(args)
366 | self.model_obstruction = PlaneObstructionModel(args)
367 | self.model_translation = TranslationModel(args)
368 | self.model_rotation = RotationModel(args)
369 |
370 | self.sin_epoch = 1.0
371 | self.save_hyperparameters()
372 |
373 | def load_volume(self):
374 | self.bundle.load_volume()
375 |
376 | def configure_optimizers(self):
377 |
378 | optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
379 | #constant lr
380 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0)
381 |
382 | return [optimizer], [scheduler]
383 |
384 |
385 | def forward(self, t, ray_origins, ray_directions):
386 | """ Forward model pass, estimate motion, implicit depth + image.
387 | """
388 |
389 | rgb_transmission, flow_transmission = self.model_transmission(t, ray_origins, ray_directions, self.sin_epoch)
390 | rgb_obstruction, flow_obstruction, alpha_obstruction = self.model_obstruction(t, ray_origins, ray_directions, self.sin_epoch)
391 |
392 | if self.args.single_plane:
393 | rgb_combined = rgb_transmission
394 | else:
395 | rgb_combined = rgb_transmission * (1 - alpha_obstruction) + rgb_obstruction * alpha_obstruction
396 |
397 | return rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction
398 |
399 | def generate_ray_directions(self, uv, camera_to_world, intrinsics_inv):
400 | u, v = uv[:,0:1] * self.img_width, uv[:,1:2] * self.img_height
401 | uv1 = torch.cat([u, v, torch.ones_like(uv[:,0:1])], dim=1) # N x 3
402 | # scale by image width/height
403 | xy1 = torch.bmm(intrinsics_inv, uv1.unsqueeze(2)).squeeze(2) # N x 3
404 | xy = xy1[:,0:2]
405 |
406 | f_div_cx = -1 / intrinsics_inv[:,0,2]
407 | f_div_cy = -1 / intrinsics_inv[:,1,2]
408 |
409 | r2 = torch.sum(xy**2, dim=1, keepdim=True) # N x 1
410 | r4 = r2**2
411 | r6 = r2**3
412 | kappa1, kappa2, kappa3 = self.lens_distortion[0:3]
413 |
414 | # apply lens distortion correction
415 | xy = xy * (1 + kappa1*r2 + kappa2*r4 + kappa3*r6)
416 |
417 | xy = xy * torch.min(f_div_cx[:, None], f_div_cy[:, None]) # scale long dimension to -1, 1
418 | ray_directions = torch.cat([xy, torch.ones_like(xy[:,0:1])], dim=1) # N x 3
419 | ray_directions = torch.bmm(camera_to_world, ray_directions.unsqueeze(2)).squeeze(2) # apply camera rotation
420 | ray_directions = ray_directions / ray_directions[:,2:3] # normalize by z
421 | return ray_directions
422 |
423 | def training_step(self, train_batch, batch_idx):
424 | t, uv, camera_to_world, intrinsics, intrinsics_inv, rgb_reference = debatch(train_batch) # collapse batch + point dimensions
425 |
426 | camera_to_world = self.model_rotation(camera_to_world, t) # apply rotation offset
427 | ray_origins = self.model_translation(t) # camera center in world coordinates
428 | ray_directions = self.generate_ray_directions(uv, camera_to_world, intrinsics_inv)
429 |
430 | rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = self.forward(t, ray_origins, ray_directions)
431 |
432 | loss = 0.0
433 |
434 | rgb_loss = torch.abs((rgb_combined - rgb_reference)/(rgb_combined.detach() + 0.001))
435 | self.log('loss/rgb', rgb_loss.mean())
436 | loss += rgb_loss.mean()
437 |
438 | self.log(f'plane_depth/image', self.model_transmission.model_plane.depth)
439 | self.log(f'plane_depth/obstruction', self.model_obstruction.model_plane.depth)
440 |
441 |         if np.abs(self.args.alpha_weight) > 0 and self.sin_epoch > 0.6:
442 | alpha_loss = self.args.alpha_weight * self.sin_epoch * alpha_obstruction
443 | self.log('loss/alpha', alpha_loss.mean())
444 | loss += alpha_loss.mean()
445 |
446 | return loss
447 |
448 | def color_and_tone(self, rgb_samples, height, width):
449 | """ Apply CCM and tone curve to raw samples
450 | """
451 |
452 | img = self.bundle.ccm.to(rgb_samples.device) @ rgb_samples.T
453 | img = img.reshape(3, height, width)
454 | img = utils.apply_tonemap_curve(img, self.bundle.tonemap_curve)
455 |
456 | return img
457 |
458 | def make_grid(self, height, width, u_lims, v_lims):
459 | """ Create (u,v) meshgrid with size (height,width) extent (u_lims, v_lims)
460 | """
461 | u = torch.linspace(u_lims[0], u_lims[1], width)
462 | v = torch.linspace(v_lims[0], v_lims[1], height)
463 | u_grid, v_grid = torch.meshgrid([u, v], indexing="xy") # u/v grid
464 | return torch.stack((u_grid.flatten(), v_grid.flatten())).t()
465 |
466 | def generate_img(self, frame, height=960, width=720, u_lims=[0,1], v_lims=[0,1]):
467 | """ Produce reference image for tensorboard/visualization
468 | """
469 | device = self.device
470 | uv = self.make_grid(height, width, u_lims, v_lims)
471 |
472 | rgb_samples = self.bundle.sample_frame(uv, frame).to(device)
473 | img = self.color_and_tone(rgb_samples, height, width)
474 |
475 | return img
476 |
477 | def generate_outputs(self, frame=0, height=720, width=540, u_lims=[0,1], v_lims=[0,1], time=None):
478 | """ Use forward model to sample implicit image I(u,v), depth D(u,v) and raw images
479 | at reprojected u,v, coordinates. Results should be aligned (sampled at (u',v'))
480 | """
481 | device = self.device
482 | uv = self.make_grid(height, width, u_lims, v_lims)
483 | if time is None:
484 | t = torch.tensor(frame/(self.bundle.num_frames - 1), dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1
485 | else:
486 | t = torch.tensor(time, dtype=torch.float32).repeat(uv.shape[0])[:,None] # num_points x 1
487 | frame = int(np.floor(time * (self.bundle.num_frames - 1)))
488 |
489 | rgb_reference = self.bundle.sample_frame(uv, frame).to(device)
490 | intrinsics_inv = self.bundle.intrinsics_inv[frame:frame+2] # 2 x 3 x 3
491 | camera_to_world = self.bundle.camera_to_world[frame:frame+2] # 2 x 3 x 3
492 |
493 | if time is None or time >= 1.0: # select exact frame timestamp
494 | intrinsics_inv = intrinsics_inv[0:1]
495 | camera_to_world = camera_to_world[0:1]
496 | else: # interpolate between frames
497 | fraction = time * (self.bundle.num_frames - 1) - frame
498 | intrinsics_inv = intrinsics_inv[0:1] * (1 - fraction) + intrinsics_inv[1:2] * fraction
499 | camera_to_world = camera_to_world[0:1] * (1 - fraction) + camera_to_world[1:2] * fraction
500 |
501 | intrinsics_inv = intrinsics_inv.repeat(uv.shape[0],1,1) # num_points x 3 x 3
502 | camera_to_world = camera_to_world.repeat(uv.shape[0],1,1) # num_points x 3 x 3
503 |
504 | with torch.no_grad():
505 | rgb_combined_chunks = []
506 | rgb_transmission_chunks = []
507 | rgb_obstruction_chunks = []
508 | flow_transmission_chunks = []
509 | flow_obstruction_chunks = []
510 | alpha_obstruction_chunks = []
511 |
512 | chunk_size = 42 * self.args.point_batch_size
513 | for i in range((t.shape[0] // chunk_size) + 1):
514 | t_chunk, uv_chunk = t[i*chunk_size:(i+1)*chunk_size].to(device), uv[i*chunk_size:(i+1)*chunk_size].to(device)
515 | intrinsics_inv_chunk = intrinsics_inv[i*chunk_size:(i+1)*chunk_size].to(device)
516 | camera_to_world_chunk = camera_to_world[i*chunk_size:(i+1)*chunk_size].to(device)
517 |
518 | camera_to_world_chunk = self.model_rotation(camera_to_world_chunk, t_chunk) # apply rotation offset
519 | ray_origins = self.model_translation(t_chunk) # camera center in world coordinates
520 | ray_directions = self.generate_ray_directions(uv_chunk, camera_to_world_chunk, intrinsics_inv_chunk)
521 | rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = self.forward(t_chunk, ray_origins, ray_directions)
522 |
523 | rgb_combined_chunks.append(rgb_combined.detach().cpu())
524 | rgb_transmission_chunks.append(rgb_transmission.detach().cpu())
525 | rgb_obstruction_chunks.append(rgb_obstruction.detach().cpu())
526 | flow_transmission_chunks.append(flow_transmission.detach().cpu())
527 | flow_obstruction_chunks.append(flow_obstruction.detach().cpu())
528 | alpha_obstruction_chunks.append(alpha_obstruction.detach().cpu())
529 |
530 | rgb_combined = torch.cat(rgb_combined_chunks, dim=0)
531 |
532 | rgb_reference = self.color_and_tone(rgb_reference, height, width)
533 | rgb_combined = self.color_and_tone(rgb_combined, height, width)
534 | rgb_transmission = self.color_and_tone(torch.cat(rgb_transmission_chunks, dim=0), height, width)
535 | rgb_obstruction = self.color_and_tone(torch.cat(rgb_obstruction_chunks, dim=0), height, width)
536 | flow_transmission = torch.cat(flow_transmission_chunks, dim=0).reshape(height, width, 2)
537 | flow_obstruction = torch.cat(flow_obstruction_chunks, dim=0).reshape(height, width, 2)
538 | alpha_obstruction = torch.cat(alpha_obstruction_chunks, dim=0).reshape(height, width)
539 |
540 | return rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction
541 |
542 | def save_outputs(self, path, high_res=False):
543 | os.makedirs(f"outputs/{self.args.name + path}", exist_ok=True)
544 | if high_res:
545 |             rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_occlusion = self.generate_outputs(frame=0, height=2560, width=1920, u_lims=[0,1], v_lims=[0,1], time=0.0)
546 | np.save(f"outputs/{self.args.name + path}/flow_transmission.npy", flow_transmission.detach().cpu().numpy())
547 | np.save(f"outputs/{self.args.name + path}/flow_obstruction.npy", flow_obstruction.detach().cpu().numpy())
548 | else:
549 |             rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_occlusion = self.generate_outputs(frame=0, height=1080, width=810, u_lims=[0,1], v_lims=[0,1], time=0.0)
550 |
551 | plt.imsave(f"outputs/{self.args.name + path}/reference.png", rgb_reference.permute(1,2,0).detach().cpu().numpy())
552 | plt.imsave(f"outputs/{self.args.name + path}/alpha.png", alpha_occlusion.detach().cpu().numpy(), cmap="gray")
553 | plt.imsave(f"outputs/{self.args.name + path}/transmission.png", rgb_transmission.permute(1,2,0).detach().cpu().numpy())
554 | plt.imsave(f"outputs/{self.args.name + path}/obstruction.png", rgb_obstruction.permute(1,2,0).detach().cpu().numpy())
555 | plt.imsave(f"outputs/{self.args.name + path}/combined.png", rgb_combined.permute(1,2,0).detach().cpu().numpy())
556 |
557 |
558 | #########################################################################################################
559 | ############################################### VALIDATION ##############################################
560 | #########################################################################################################
561 |
562 | class ValidationCallback(pl.Callback):
563 | def __init__(self):
564 | super().__init__()
565 |
566 | def on_train_epoch_start(self, trainer, model):
567 | model.sin_epoch = min(1.0, 0.05 + np.sin(model.current_epoch/(model.args.max_epochs - 1) * np.pi/2)) # progression of training
568 | trainer.train_dataloader.dataset.sin_epoch = model.sin_epoch
569 | print(f" Sin of Current Epoch: {model.sin_epoch:.3f}")
570 |
571 | if model.sin_epoch > 0.4:
572 | # unlock flow model
573 | model.model_transmission.encoding_flow.requires_grad_(True)
574 | model.model_transmission.network_flow.requires_grad_(True)
575 | model.model_obstruction.encoding_flow.requires_grad_(True)
576 | model.model_obstruction.network_flow.requires_grad_(True)
577 |
578 | model.model_transmission.network_flow.train(True)
579 | model.model_transmission.encoding_flow.train(True)
580 | model.model_obstruction.encoding_flow.train(True)
581 | model.model_obstruction.network_flow.train(True)
582 |
583 | if model.args.fast: # skip tensorboarding except for beginning and end
584 | if model.current_epoch == model.args.max_epochs - 1 or model.current_epoch == 0:
585 | pass
586 | else:
587 | return
588 |
589 | # for i, frame in enumerate([0, model.bundle.num_frames//2, model.bundle.num_frames-1]): # can sample more frames
590 | for i, frame in enumerate([0]):
591 | rgb_reference, rgb_combined, rgb_transmission, rgb_obstruction, flow_transmission, flow_obstruction, alpha_obstruction = model.generate_outputs(frame)
592 | model.logger.experiment.add_image(f'pred/{i}_rgb_combined', rgb_combined, global_step=trainer.global_step)
593 | model.logger.experiment.add_image(f'pred/{i}_rgb_transmission', rgb_transmission, global_step=trainer.global_step)
594 | model.logger.experiment.add_image(f'pred/{i}_rgb_obstruction', rgb_obstruction, global_step=trainer.global_step)
595 | model.logger.experiment.add_image(f'pred/{i}_rgb_obstruction_alpha', rgb_obstruction * alpha_obstruction, global_step=trainer.global_step)
596 | model.logger.experiment.add_image(f'pred/{i}_alpha_obstruction', utils.colorize_tensor(alpha_obstruction, vmin=0, vmax=1, cmap="gray"), global_step=trainer.global_step)
597 |
598 |
599 | if model.args.save_video: # save the evolution of the model
600 | model.save_outputs(path=f"/{model.current_epoch}")
601 |
602 | def on_train_start(self, trainer, model):
603 | pl.seed_everything(42) # the answer to life, the universe, and everything
604 |
605 | # initialize rgb as average color of first frame of data (minimize the amount the rgb models have to learn)
606 | model.model_transmission.initial_rgb.data = torch.mean(model.bundle.rgb_volume[0], dim=(1,2))[None,:].to(model.device)
607 | model.model_obstruction.initial_rgb.data = torch.mean(model.bundle.rgb_volume[0], dim=(1,2))[None,:].to(model.device)
608 |
609 | model.logger.experiment.add_text("args", str(model.args))
610 |
611 | for i, frame in enumerate([0, model.bundle.num_frames//2, model.bundle.num_frames-1]):
612 | rgb_raw = model.generate_img(frame)
613 | model.logger.experiment.add_image(f'gt/{i}_rgb_raw', rgb_raw, global_step=trainer.global_step)
614 |
615 |
616 | def on_train_end(self, trainer, model):
617 | checkpoint_dir = os.path.join("checkpoints", model.args.name, "last.ckpt")
618 | bundle_dir = os.path.join("checkpoints", model.args.name, "bundle.pkl")
619 | trainer.save_checkpoint(checkpoint_dir)
620 |
621 | model.save_outputs(path=f"-final", high_res=True)
622 |
623 | with open(bundle_dir, 'wb') as file:
624 | model.bundle.rgb_volume = torch.ones([model.bundle.num_frames, model.bundle.img_channels, 3,3]).float()
625 | pickle.dump(model.bundle, file)
626 |
627 | if __name__ == "__main__":
628 |
629 | # argparse
630 | parser = argparse.ArgumentParser()
631 |
632 | # data
633 | parser.add_argument('--point_batch_size', type=int, default=2**18, help="Number of points to sample per dataloader index.")
634 | parser.add_argument('--num_batches', type=int, default=80, help="Number of training batches.")
635 | parser.add_argument('--max_percentile', type=float, default=100, help="Percentile of brightest pixels to cut.")
636 | parser.add_argument('--frames', type=str, help="Which subset of frames to use for training, e.g. 0,10,20,30,40")
637 | parser.add_argument('--rgb_data', action='store_true', help="Input data is pre-processed RGB.")
638 |
639 | # model
640 | parser.add_argument('--camera_control_points', type=int, default=22, help="Spline control points for translation/rotation model.")
641 | parser.add_argument('--alpha_weight', type=float, default=1e-2, help="Alpha regularization weight.")
642 | parser.add_argument('--rotation_weight', type=float, default=1e-3, help="Scale learned rotation.")
643 | parser.add_argument('--translation_weight', type=float, default=1e-2, help="Scale learned translation.")
644 | parser.add_argument('--alpha_temperature', type=float, default=1.0, help="Temperature for sigmoid in alpha matte calculation.")
645 |
646 | # planes
647 | parser.add_argument('--obstruction_control_points_flow', type=int, default=11, help="Spline control points for flow models.")
648 | parser.add_argument('--obstruction_flow_grid_size', type=str, default="tiny", help="Obstruction flow grid size (small, medium, large).")
649 | parser.add_argument('--obstruction_image_grid_size', type=str, default="large", help="Obstruction image grid size (small, medium, large).")
650 | parser.add_argument('--obstruction_alpha_grid_size', type=str, default="large", help="Obstruction alpha grid size (small, medium, large).")
651 | parser.add_argument('--obstruction_initial_depth', type=float, default=1.0, help="Obstruction initial plane depth.")
652 | parser.add_argument('--obstruction_initial_alpha', type=float, default=0.5, help="Obstruction initial alpha.")
653 | parser.add_argument('--transmission_control_points_flow', type=int, default=11, help="Spline control points for flow models.")
654 | parser.add_argument('--transmission_flow_grid_size', type=str, default="tiny", help="Transmission flow grid size (small, medium, large).")
655 | parser.add_argument('--transmission_image_grid_size', type=str, default="large", help="Transmission image grid size (small, medium, large).")
656 | parser.add_argument('--transmission_initial_depth', type=float, default=0.4, help="Transmission initial plane depth.")
657 | parser.add_argument('--single_plane', action='store_true', help="Use single plane model.")
658 |
659 | # training
660 | parser.add_argument('--bundle_path', type=str, required=True, help="Path to frame_bundle.npz")
661 | parser.add_argument('--name', type=str, required=True, help="Experiment name for logs and checkpoints.")
662 | parser.add_argument('--max_epochs', type=int, default=75, help="Number of training epochs.")
663 | parser.add_argument('--lr', type=float, default=3e-5, help="Learning rate.")
664 | parser.add_argument('--save_video', action='store_true', help="Store training outputs at each epoch for visualization.")
665 | parser.add_argument('--num_workers', type=int, default=4, help="Number of dataloader workers.")
666 | parser.add_argument('--debug', action='store_true', help="Debug mode, only use 1 batch.")
667 | parser.add_argument('--frame_cutoff', action='store_true', help="Use frame cutoff.")
668 | parser.add_argument('--fast', action='store_true', help="Fast mode.")
669 |
670 |
671 | args = parser.parse_args()
672 | # parse plane args
673 | print(args)
674 | if args.frames is not None:
675 | args.frames = [int(x) for x in re.findall(r"[-+]?\d*\.\d+|\d+", args.frames)]
676 |
677 | # model
678 | model = BundleMLP(args)
679 | model.load_volume()
680 |
681 | # freeze flow at the start of training as it will otherwise fight the camera model during early image fitting
682 | # these can be omitted at the cost of learning really weird camera translations
683 | model.model_transmission.encoding_flow.requires_grad_(False)
684 | model.model_transmission.encoding_flow.train(False)
685 | model.model_transmission.network_flow.requires_grad_(False)
686 | model.model_transmission.network_flow.train(False)
687 | model.model_obstruction.encoding_flow.requires_grad_(False)
688 | model.model_obstruction.encoding_flow.train(False)
689 | model.model_obstruction.network_flow.requires_grad_(False)
690 | model.model_obstruction.network_flow.train(False)
691 |
692 | # dataset
693 | bundle = model.bundle
694 | train_loader = DataLoader(bundle, batch_size=1, num_workers=args.num_workers, shuffle=False, pin_memory=True, prefetch_factor=1)
695 |
696 |
697 | torch.set_float32_matmul_precision('high')
698 |
699 | # training
700 | lr_callback = pl.callbacks.LearningRateMonitor()
701 | logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), version=args.name, name="lightning_logs")
702 | validation_callback = ValidationCallback()
703 | trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), num_nodes=1, strategy="auto", max_epochs=args.max_epochs,
704 | logger=logger, callbacks=[validation_callback, lr_callback], enable_checkpointing=False, fast_dev_run=args.debug)
705 | trainer.fit(model, train_loader)
706 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import re
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | @torch.no_grad()
8 | def raw_to_rgb(bundle):
9 | """ Convert RAW mosaic into three-channel RGB volume
10 | by only in-filling empty pixels.
11 | Returns volume of shape: (T, C, H, W)
12 | """
13 |
14 | raw_frames = torch.tensor(np.array([bundle[f'raw_{i}']['raw'] for i in range(bundle['num_raw_frames'])]).astype(np.int32), dtype=torch.float32)[None] # C,T,H,W
15 | raw_frames = raw_frames.permute(1,0,2,3) # T,C,H,W
16 | color_correction_gains = bundle['raw_0']['android']['colorCorrection.gains']
17 | color_correction_gains = np.array([float(el) for el in re.sub(r'[^0-9.,]', '', color_correction_gains).split(',')]) # RGGB gains
18 | color_filter_arrangement = bundle['characteristics']['color_filter_arrangement']
19 | blacklevel = torch.tensor(np.array([bundle[f'raw_{i}']['blacklevel'] for i in range(bundle['num_raw_frames'])]))[:,:,None,None]
20 | whitelevel = torch.tensor(np.array([bundle[f'raw_{i}']['whitelevel'] for i in range(bundle['num_raw_frames'])]))[:,None,None,None]
21 | shade_maps = torch.tensor(np.array([bundle[f'raw_{i}']['shade_map'] for i in range(bundle['num_raw_frames'])])).permute(0,3,1,2) # T,C,H,W
22 |     # interpolate shade maps to the Bayer-quadrant resolution (half the image size)
23 | shade_maps = F.interpolate(shade_maps, size=(raw_frames.shape[-2]//2, raw_frames.shape[-1]//2), mode='bilinear', align_corners=False)
24 |
25 | top_left = raw_frames[:,:,0::2,0::2]
26 | top_right = raw_frames[:,:,0::2,1::2]
27 | bottom_left = raw_frames[:,:,1::2,0::2]
28 | bottom_right = raw_frames[:,:,1::2,1::2]
29 |
30 | # figure out color channels
31 | if color_filter_arrangement == 0: # RGGB
32 | R, G1, G2, B = top_left, top_right, bottom_left, bottom_right
33 | elif color_filter_arrangement == 1: # GRBG
34 | G1, R, B, G2 = top_left, top_right, bottom_left, bottom_right
35 | elif color_filter_arrangement == 2: # GBRG
36 | G1, B, R, G2 = top_left, top_right, bottom_left, bottom_right
37 | elif color_filter_arrangement == 3: # BGGR
38 | B, G1, G2, R = top_left, top_right, bottom_left, bottom_right
39 |
40 | # apply color correction gains, flip to portrait
41 | R = ((R - blacklevel[:,0:1]) / (whitelevel - blacklevel[:,0:1]) * color_correction_gains[0])
42 | R *= shade_maps[:,0:1]
43 | G1 = ((G1 - blacklevel[:,1:2]) / (whitelevel - blacklevel[:,1:2]) * color_correction_gains[1])
44 | G1 *= shade_maps[:,1:2]
45 | G2 = ((G2 - blacklevel[:,2:3]) / (whitelevel - blacklevel[:,2:3]) * color_correction_gains[2])
46 | G2 *= shade_maps[:,2:3]
47 | B = ((B - blacklevel[:,3:4]) / (whitelevel - blacklevel[:,3:4]) * color_correction_gains[3])
48 | B *= shade_maps[:,3:4]
49 |
50 | rgb_volume = torch.zeros(raw_frames.shape[0], 3, raw_frames.shape[-2], raw_frames.shape[-1], dtype=torch.float32)
51 |
52 | # Fill gaps in blue channel
53 | rgb_volume[:, 2, 0::2, 0::2] = B.squeeze(1)
54 | rgb_volume[:, 2, 0::2, 1::2] = (B + torch.roll(B, -1, dims=3)).squeeze(1) / 2
55 | rgb_volume[:, 2, 1::2, 0::2] = (B + torch.roll(B, -1, dims=2)).squeeze(1) / 2
56 | rgb_volume[:, 2, 1::2, 1::2] = (B + torch.roll(B, -1, dims=2) + torch.roll(B, -1, dims=3) + torch.roll(B, [-1, -1], dims=[2, 3])).squeeze(1) / 4
57 |
58 | # Fill gaps in green channel
59 | rgb_volume[:, 1, 0::2, 0::2] = G1.squeeze(1)
60 | rgb_volume[:, 1, 0::2, 1::2] = (G1 + torch.roll(G1, -1, dims=3) + G2 + torch.roll(G2, 1, dims=2)).squeeze(1) / 4
61 | rgb_volume[:, 1, 1::2, 0::2] = (G1 + torch.roll(G1, -1, dims=2) + G2 + torch.roll(G2, 1, dims=3)).squeeze(1) / 4
62 | rgb_volume[:, 1, 1::2, 1::2] = G2.squeeze(1)
63 |
64 | # Fill gaps in red channel
65 | rgb_volume[:, 0, 0::2, 0::2] = R.squeeze(1)
66 | rgb_volume[:, 0, 0::2, 1::2] = (R + torch.roll(R, -1, dims=3)).squeeze(1) / 2
67 | rgb_volume[:, 0, 1::2, 0::2] = (R + torch.roll(R, -1, dims=2)).squeeze(1) / 2
68 | rgb_volume[:, 0, 1::2, 1::2] = (R + torch.roll(R, -1, dims=2) + torch.roll(R, -1, dims=3) + torch.roll(R, [-1, -1], dims=[2, 3])).squeeze(1) / 4
69 |
70 | rgb_volume = torch.flip(rgb_volume.transpose(-1,-2), [-1]) # rotate 90 degrees clockwise to portrait mode
71 |
72 | return rgb_volume
73 |
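# Illustrative note (not part of the original file): the in-filling above is a simple
# nearest-neighbor average rather than a full demosaic. For example, a blue value at an
# odd-row/odd-column site is reconstructed as the mean of the four surrounding blue
# samples (the current one plus its right, down, and down-right neighbors via torch.roll),
# and green gaps are likewise filled with the mean of the four nearest G1/G2 samples.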
74 | def de_item(bundle):
75 |     """ Call .item() on all dictionary entries
76 |         to remove the unnecessary extra dimension
77 | """
78 |
79 | bundle['motion'] = bundle['motion'].item()
80 | bundle['characteristics'] = bundle['characteristics'].item()
81 |
82 | for i in range(bundle['num_raw_frames']):
83 | bundle[f'raw_{i}'] = bundle[f'raw_{i}'].item()
84 |
85 | def mask(encoding, mask_coef):
86 | mask_coef = 0.4 + 0.6*mask_coef
87 |     # build a binary mask that keeps only the leading fraction of encoding channels
88 | mask = torch.zeros_like(encoding[0:1])
89 | mask_ceil = int(np.ceil(mask_coef * encoding.shape[1]))
90 | mask[:,:mask_ceil] = 1.0
91 |
92 | return encoding * mask
93 |
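# Worked example (illustrative, not part of the original file): with a 16-channel
# encoding and mask_coef = 0.5, the effective coefficient is 0.4 + 0.6 * 0.5 = 0.7, so
# mask_ceil = ceil(0.7 * 16) = 12 and only the first 12 channels are kept; at
# mask_coef = 1.0 every channel passes through unmasked.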
94 | def interpolate(signal, times):
95 | if signal.shape[-1] == 1:
96 | return signal.squeeze(-1)
97 | elif signal.shape[-1] == 2:
98 | return interpolate_linear(signal, times)
99 | else:
100 | return interpolate_cubic_hermite(signal, times)
101 |
102 | @torch.jit.script
103 | def interpolate_cubic_hermite(signal, times):
104 | # Interpolate a signal using cubic Hermite splines
105 | # signal: (B, C, T) or (B, T)
106 | # times: (B, T)
107 |
108 | if len(signal.shape) == 3: # B,C,T
109 | times = times.unsqueeze(1)
110 | times = times.repeat(1, signal.shape[1], 1)
111 |
112 | N = signal.shape[-1]
113 |
114 | times_scaled = times * (N - 1)
115 | indices = torch.floor(times_scaled).long()
116 |
117 | # Clamping to avoid out-of-bounds indices
118 | indices = torch.clamp(indices, 0, N - 2)
119 | left_indices = torch.clamp(indices - 1, 0, N - 1)
120 | right_indices = torch.clamp(indices + 1, 0, N - 1)
121 | right_right_indices = torch.clamp(indices + 2, 0, N - 1)
122 |
123 | t = (times_scaled - indices.float())
124 |
125 | p0 = torch.gather(signal, -1, left_indices)
126 | p1 = torch.gather(signal, -1, indices)
127 | p2 = torch.gather(signal, -1, right_indices)
128 | p3 = torch.gather(signal, -1, right_right_indices)
129 |
130 | # One-sided derivatives at the boundaries
131 | m0 = torch.where(left_indices == indices, (p2 - p1), (p2 - p0) / 2)
132 | m1 = torch.where(right_right_indices == right_indices, (p2 - p1), (p3 - p1) / 2)
133 |
134 | # Hermite basis functions
135 | h00 = (1 + 2*t) * (1 - t)**2
136 | h10 = t * (1 - t)**2
137 | h01 = t**2 * (3 - 2*t)
138 | h11 = t**2 * (t - 1)
139 |
140 | interpolation = h00 * p1 + h10 * m0 + h01 * p2 + h11 * m1
141 |
142 | if len(signal.shape) == 3: # remove extra singleton dimension
143 | interpolation = interpolation.squeeze(-1)
144 |
145 | return interpolation
146 |
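# Sanity check (illustrative, not part of the original file): at t = 0 the basis weights
# are h00 = 1 and h10 = h01 = h11 = 0, so the interpolant returns p1 (the left sample);
# at t = 1 they are h01 = 1 and h00 = h10 = h11 = 0, returning p2 (the right sample),
# as expected for a Hermite spline that passes through its sample points.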
147 |
148 | @torch.jit.script
149 | def interpolate_linear(signal, times):
150 | # Interpolate a signal using linear interpolation
151 | # signal: (B, C, T) or (B, T)
152 | # times: (B, T)
153 |
154 | if len(signal.shape) == 3: # B,C,T
155 | times = times.unsqueeze(1)
156 | times = times.repeat(1, signal.shape[1], 1)
157 |
158 | # Scale times to be between 0 and N - 1
159 | times_scaled = times * (signal.shape[-1] - 1)
160 |
161 | indices = torch.floor(times_scaled).long()
162 | right_indices = (indices + 1).clamp(max=signal.shape[-1] - 1)
163 |
164 | t = (times_scaled - indices.float())
165 |
166 | p0 = torch.gather(signal, -1, indices)
167 | p1 = torch.gather(signal, -1, right_indices)
168 |
169 | # Linear basis functions
170 | h00 = (1 - t)
171 | h01 = t
172 |
173 | interpolation = h00 * p0 + h01 * p1
174 |
175 | if len(signal.shape) == 3: # remove extra singleton dimension
176 | interpolation = interpolation.squeeze(-1)
177 |
178 | return interpolation
179 |
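# Example (illustrative, not part of the original file): times are normalized to [0, 1]
# and rescaled by the signal length, so for a (B=1, T=5) ramp the midpoint query lands
# exactly on the middle sample:
#
#   sig = torch.arange(5, dtype=torch.float32)[None]   # tensor([[0., 1., 2., 3., 4.]])
#   interpolate_linear(sig, torch.tensor([[0.5]]))     # -> tensor([[2.]])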
180 | @torch.jit.script
181 | def convert_quaternions_to_rot(quaternions):
182 | """ Convert quaternions (wxyz) to 3x3 rotation matrices.
183 | Adapted from: https://automaticaddison.com/how-to-convert-a-quaternion-to-a-rotation-matrix
184 | """
185 |
186 | qw, qx, qy, qz = quaternions[:,0], quaternions[:,1], quaternions[:,2], quaternions[:,3]
187 |
188 | R00 = 2 * ((qw * qw) + (qx * qx)) - 1
189 | R01 = 2 * ((qx * qy) - (qw * qz))
190 | R02 = 2 * ((qx * qz) + (qw * qy))
191 |
192 | R10 = 2 * ((qx * qy) + (qw * qz))
193 | R11 = 2 * ((qw * qw) + (qy * qy)) - 1
194 | R12 = 2 * ((qy * qz) - (qw * qx))
195 |
196 | R20 = 2 * ((qx * qz) - (qw * qy))
197 | R21 = 2 * ((qy * qz) + (qw * qx))
198 | R22 = 2 * ((qw * qw) + (qz * qz)) - 1
199 |
200 | R = torch.stack([R00, R01, R02, R10, R11, R12, R20, R21, R22], dim=-1)
201 | R = R.reshape(-1,3,3)
202 |
203 | return R
204 |
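# Sanity check (illustrative, not part of the original file): the identity quaternion
# (w, x, y, z) = (1, 0, 0, 0) maps to the identity matrix, since each diagonal term
# reduces to 2 * qw * qw - 1 = 1 and every off-diagonal term vanishes:
#
#   convert_quaternions_to_rot(torch.tensor([[1., 0., 0., 0.]]))  # -> 1x3x3 identity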
205 | def multi_interp(x, xp, fp):
206 |     """ Simple extension of np.interp: independently linearly
207 |         interpolate each axis (column) of fp, sampled at
208 |         timestamps xp, at the new timestamps x
209 | """
210 | if torch.is_tensor(fp):
211 | out = [torch.tensor(np.interp(x, xp, fp[:,i]), dtype=fp.dtype) for i in range(fp.shape[-1])]
212 | return torch.stack(out, dim=-1)
213 | else:
214 | out = [np.interp(x, xp, fp[:,i]) for i in range(fp.shape[-1])]
215 | return np.stack(out, axis=-1)
216 |
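# Example (illustrative, not part of the original file): each column of fp is
# interpolated independently along xp.
#
#   xp = np.array([0.0, 1.0])
#   fp = np.array([[0.0, 10.0],
#                  [1.0, 20.0]])
#   multi_interp(np.array([0.5]), xp, fp)   # -> array([[ 0.5, 15. ]])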
217 | def parse_ccm(s):
218 | ccm = torch.tensor([eval(x.group()) for x in re.finditer(r"[-+]?\d+/\d+|[-+]?\d+\.\d+|[-+]?\d+", s)])
219 | ccm = ccm.reshape(3,3)
220 | return ccm
221 |
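# Example (illustrative, not part of the original file): the metadata string may contain
# rational entries, which the regex captures and eval() reduces to floats, e.g.
#
#   parse_ccm("1/2, 0, 0, 0, 1, 0, 0, 0, 3/4")
#   # -> tensor([[0.50, 0.00, 0.00],
#   #            [0.00, 1.00, 0.00],
#   #            [0.00, 0.00, 0.75]])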
222 | def parse_tonemap_curve(data_string):
223 | channels = re.findall(r'(R|G|B):\[(.*?)\]', data_string)
224 | result_array = np.zeros((3, len(channels[0][1].split('),')), 2))
225 |
226 | for i, (_, channel_data) in enumerate(channels):
227 | pairs = channel_data.split('),')
228 | for j, pair in enumerate(pairs):
229 | x, y = map(float, re.findall(r'([\d\.]+)', pair))
230 | result_array[i, j] = (x, y)
231 | return result_array
232 |
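# Example (illustrative, not part of the original file; the exact metadata string format
# is assumed here): a tonemap string of the form
#
#   "R:[(0.0, 0.0), (1.0, 1.0)], G:[(0.0, 0.0), (1.0, 1.0)], B:[(0.0, 0.0), (1.0, 1.0)]"
#
# parses into a (3, num_points, 2) array of per-channel (input, output) control points,
# here the identity curve for all three channels.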
233 | def apply_tonemap_curve(image, tonemap):
234 | # apply tonemap curve to each color channel
235 | image_toned = image.clone().cpu().numpy()
236 |
237 | for i in range(3):
238 | x_vals, y_vals = tonemap[i][:, 0], tonemap[i][:, 1]
239 | image_toned[i] = np.interp(image_toned[i], x_vals, y_vals)
240 |
241 | # Convert back to PyTorch tensor
242 | image_toned = torch.tensor(image_toned, dtype=torch.float32)
243 |
244 | return image_toned
245 |
246 | def debatch(batch):
247 | """ Collapse batch and channel dimension together
248 | """
249 | debatched = []
250 |
251 | for x in batch:
252 |         if len(x.shape) <= 1:
253 |             raise Exception("This tensor is too small to debatch.")
254 | elif len(x.shape) == 2:
255 | debatched.append(x.reshape(x.shape[0] * x.shape[1]))
256 | else:
257 | debatched.append(x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]))
258 |
259 | return debatched
260 |
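# Example (illustrative, not part of the original file): debatch folds the batch
# dimension into the following dimension, e.g. a (2, 3, 4) tensor in the input list
# becomes (6, 4); tensors with one or zero dimensions are rejected since there is
# nothing to fold.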
261 | def colorize_tensor(value, vmin=None, vmax=None, cmap=None, colorbar=False, height=9.6, width=7.2):
262 |     """ Convert a 2D tensor to a 3-channel RGB tensor colored according to cmap;
263 |         similar usage to plt.imshow
264 | """
265 | assert len(value.shape) == 2 # H x W
266 |
267 | fig, ax = plt.subplots(1,1)
268 | fig.set_size_inches(width,height)
269 | a = ax.imshow(value.detach().cpu(), vmin=vmin, vmax=vmax, cmap=cmap)
270 | ax.set_axis_off()
271 | if colorbar:
272 | cbar = plt.colorbar(a, fraction=0.05)
273 | cbar.ax.tick_params(labelsize=30)
274 | plt.tight_layout()
275 | plt.close()
276 |
277 | # Draw figure on canvas
278 | fig.canvas.draw()
279 |
280 | # Convert the figure to numpy array, read the pixel values and reshape the array
281 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
282 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
283 |
284 | # Normalize into 0-1 range for TensorBoard(X). Swap axes for newer versions where API expects colors in first dim
285 | img = img / 255.0
286 |
287 | return torch.tensor(img).permute(2,0,1).float()
288 |
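# Usage sketch (illustrative, not part of the original file): colorize_tensor renders a
# single-channel map through matplotlib and returns a (3, H', W') float tensor in [0, 1]
# (H', W' being the figure canvas size), suitable for image loggers that expect
# channels-first input, e.g.
#
#   img = colorize_tensor(depth.squeeze(), vmin=0.0, vmax=1.0, cmap='turbo', colorbar=True)
#   # logger.experiment.add_image('depth', img, step)   # hypothetical TensorBoard call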
--------------------------------------------------------------------------------