├── README.md
├── conv_modules.py
├── data
│   ├── KITTI.py
│   ├── __pycache__
│   │   ├── Ego4D.cpython-310.pyc
│   │   ├── KITTI.cpython-310.pyc
│   │   ├── KITTI.cpython-39.pyc
│   │   ├── MultiFrameShapenet.cpython-310.pyc
│   │   ├── co3d.cpython-310.pyc
│   │   ├── co3d.cpython-39.pyc
│   │   ├── co3dv1.cpython-310.pyc
│   │   ├── realestate10k_dataio.cpython-310.pyc
│   │   └── realestate10k_dataio.cpython-39.pyc
│   ├── co3d.py
│   └── realestate10k_dataio.py
├── demo.py
├── eval.py
├── geometry.py
├── mlp_modules.py
├── models.py
├── renderer.py
├── run.py
├── train.py
├── vis_scripts.py
└── wandb_logging.py
/README.md:
--------------------------------------------------------------------------------
1 | # FlowCam: Training Generalizable 3D Radiance Fields without Camera Poses via Pixel-Aligned Scene Flow
2 | ### [Project Page](https://cameronosmith.github.io/flowcam) | [Paper](https://arxiv.org/abs/2306.00180) | [Pretrained Models](https://drive.google.com/drive/folders/1t7vmvBg9OAo4S8I2zjwfqhL656H1r2JP?usp=sharing)
3 |
4 | [Cameron Smith](https://cameronosmith.github.io/),
5 | [Yilun Du](https://yilundu.github.io/),
6 | [Ayush Tewari](https://ayushtewari.com),
7 | [Vincent Sitzmann](https://vsitzmann.github.io/)
8 |
9 | MIT
10 |
11 | This is the official implementation of the paper "FlowCam: Training Generalizable 3D Radiance Fields without Camera Poses via Pixel-Aligned Scene Flow".
12 |
13 |
14 |
15 | ## High-level structure
16 | The code is organized as follows:
17 | * models.py contains the model definition
18 | * run.py contains a generic argument parser that creates the model and dataloaders for both training and evaluation
19 | * train.py and eval.py contain the training and evaluation loops
20 | * mlp_modules.py and conv_modules.py contain common MLP and CNN blocks
21 | * vis_scripts.py contains plotting and wandb logging code
22 | * renderer.py implements volume rendering helper functions
23 | * geometry.py implements various geometric operations (projections, 3D lifting, rigid transforms, etc.)
24 | * data contains the dataset scripts
25 | * demo.py contains a script to run our model on any image directory to estimate poses. See the file header for an example of how to run it.
26 |
27 | ## Reproducing experiments
28 |
29 | See `python run.py --help` for a list of command line arguments.
30 | An example training command for CO3D-Hydrants is `python train.py --dataset hydrant --vid_len 8 --batch_size 2 --online --name hydrants_flowcam --n_skip 1 2`.
31 | Similarly, replace `--dataset hydrant` with any of `[realestate,kitti,10cat]` for training on RealEstate10K, KITTI, or CO3D-10Category.
32 |
33 | Example training commands for each dataset are listed below:
34 | `python train.py --dataset hydrant --vid_len 8 --batch_size 2 --online --name hydrant_flowcam --n_skip 1 2`
35 | `python train.py --dataset 10cat --vid_len 8 --batch_size 2 --online --name 10cat_flowcam --n_skip 1`
36 | `python train.py --dataset realestate --vid_len 8 --batch_size 2 --online --name realestate_flowcam --n_skip 9`
37 | `python train.py --dataset kitti --vid_len 8 --batch_size 2 --online --name kitti_flowcam --n_skip 0`
38 |
39 | Use the `--online` flag to log summaries to your wandb account, or omit it to disable wandb logging.
40 |
41 | ## Environment variables
42 |
43 | We use environment variables to set the dataset and logging paths, though you can easily hardcode the paths in each respective dataset script. Specifically, we use the environment variables `CO3D_ROOT`, `RE10K_IMG_ROOT`, `RE10K_POSE_ROOT`, `KITTI_ROOT`, and `LOGDIR`. For instance, you can add the line `export CO3D_ROOT="/nobackup/projects/public/facebook-co3dv2"` to your `.bashrc`.
44 |
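For reference, the dataset scripts read these variables directly (see `data/KITTI.py`, `data/co3d.py`, and `data/realestate10k_dataio.py`); a minimal sketch of those lookups is below (the inline comments describe the expected contents, not additional requirements):

```python
import os

# These mirror the lookups in the dataset scripts; a KeyError here means the
# corresponding variable is not exported in your shell.
kitti_root = os.environ["KITTI_ROOT"]        # KITTI raw recordings
co3d_root = os.environ["CO3D_ROOT"]          # CO3D-v2 download
re10k_imgs = os.environ["RE10K_IMG_ROOT"]    # RealEstate10K frames (train/ and test/ subfolders)
re10k_poses = os.environ["RE10K_POSE_ROOT"]  # RealEstate10K camera poses
```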
45 | ## Data
46 |
47 | The KITTI dataset we use can be downloaded here: https://www.cvlibs.net/datasets/kitti/raw_data.php
48 |
49 | Instructions for downloading the RealEstate10K dataset can be found here: https://github.com/yilundu/cross_attention_renderer/blob/master/data_download/README.md
50 |
51 | We use the V2 version of the CO3D dataset, which can be downloaded here: https://github.com/facebookresearch/co3d
52 |
53 | ## Using FlowCam to estimate poses for your own scenes
54 |
55 | You can run FlowCam on any set of images using the script in `demo.py`, specifying the rgb path, the intrinsics (fx,fy,cx,cy), the pretrained checkpoint, whether to render out the reconstructed images (slower, but illustrates how accurately the model estimates the geometry), and the image resolution to resize to during preprocessing (around 128 pixels wide to avoid memory issues).
56 | For example: `python demo.py --demo_rgb /nobackup/projects/public/facebook-co3dv2/hydrant/615_99120_197713/images --intrinsics 1.7671e+03,3.1427e+03,5.3550e+02,9.5150e+02 -c pretrained_models/co3d_hydrant.pt --render_imgs --low_res 144 128`. The script will write the poses, a rendered pose plot, and the re-rendered rgb and depth (if requested) to the folder `demo_output`.
57 | The pretrained RealEstate10K model (`pretrained_models/re10k.pt`) probably has the most general prior for use on your own scenes. We plan to train and release a model on all datasets for a more general prior, so stay tuned.
58 |
59 | ### Coordinate and camera parameter conventions
60 | This code uses an "OpenCV" style camera coordinate system, where the Y-axis points downwards (the up-vector points in the negative Y-direction), the X-axis points right, and the Z-axis points into the image plane.
61 |
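As a minimal sketch of this convention (mirroring the `lift` helper in the dataset scripts), a pixel at (u, v) with depth z backprojects to camera space as follows; the function name here is just for illustration:

```python
import torch

# OpenCV-style pinhole backprojection: +x right, +y down, +z into the image plane.
def backproject(u, v, z, fx, fy, cx, cy):
    x_cam = (u - cx) / fx * z
    y_cam = (v - cy) / fy * z
    return torch.stack((x_cam, y_cam, z), dim=-1)
```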
62 | ### Citation
63 | If you find our work useful in your research, please cite:
64 | ```
65 | @misc{smith2023flowcam,
66 | title={FlowCam: Training Generalizable 3D Radiance Fields without Camera Poses via Pixel-Aligned Scene Flow},
67 | author={Cameron Smith and Yilun Du and Ayush Tewari and Vincent Sitzmann},
68 | year={2023},
69 | eprint={2306.00180},
70 | archivePrefix={arXiv},
71 | primaryClass={cs.CV}
72 | }
73 | ```
74 |
75 | ### Contact
76 | If you have any questions, please email Cameron Smith at omid.smith.cameron@gmail.com or open an issue.
77 |
--------------------------------------------------------------------------------
/conv_modules.py:
--------------------------------------------------------------------------------
1 | import torch, torchvision
2 | from torch import nn
3 | import functools
4 | from einops import rearrange, repeat
5 | from torch.nn import functional as F
6 | import numpy as np
7 |
8 | def get_norm_layer(norm_type="instance", group_norm_groups=32):
9 | """Return a normalization layer
10 | Parameters:
11 | norm_type (str) -- the name of the normalization layer: batch | instance | none
12 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
13 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
14 | """
15 | if norm_type == "batch":
16 | norm_layer = functools.partial(
17 | nn.BatchNorm2d, affine=True, track_running_stats=True
18 | )
19 | elif norm_type == "instance":
20 | norm_layer = functools.partial(
21 | nn.InstanceNorm2d, affine=False, track_running_stats=False
22 | )
23 | elif norm_type == "group":
24 | norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
25 | elif norm_type == "none":
26 | norm_layer = None
27 | else:
28 | raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
29 | return norm_layer
30 |
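# Usage sketch (not part of the original file): get_norm_layer returns a factory,
# e.g. get_norm_layer("instance")(64) -> nn.InstanceNorm2d(64, affine=False), and
# get_norm_layer("group")(64) -> nn.GroupNorm(32, 64).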
31 |
32 | class PixelNeRFEncoder(nn.Module):
33 | def __init__(
34 | self,
35 | backbone="resnet34",
36 | pretrained=True,
37 | num_layers=4,
38 | index_interp="bilinear",
39 | index_padding="border",
40 | upsample_interp="bilinear",
41 | feature_scale=1.0,
42 | use_first_pool=True,
43 | norm_type="batch",
44 | in_ch=3,
45 | ):
46 | super().__init__()
47 |
48 | self.use_custom_resnet = backbone == "custom"
49 | self.feature_scale = feature_scale
50 | self.use_first_pool = use_first_pool
51 | norm_layer = get_norm_layer(norm_type)
52 |
53 | print("Using torchvision", backbone, "encoder")
54 | self.model = getattr(torchvision.models, backbone)(
55 | pretrained=pretrained, norm_layer=norm_layer
56 | )
57 |
58 | if in_ch != 3:
59 | self.model.conv1 = nn.Conv2d(
60 | in_ch,
61 | self.model.conv1.weight.shape[0],
62 | self.model.conv1.kernel_size,
63 | self.model.conv1.stride,
64 | self.model.conv1.padding,
65 | padding_mode=self.model.conv1.padding_mode,
66 | )
67 |
68 |         # Replace the classification head (fc and avgpool) with no-ops; only the convolutional feature maps are used.
69 | self.model.fc = nn.Sequential()
70 | self.model.avgpool = nn.Sequential()
71 | self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]
72 |
73 | self.num_layers = num_layers
74 | self.index_interp = index_interp
75 | self.index_padding = index_padding
76 | self.upsample_interp = upsample_interp
77 | self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
78 | self.register_buffer(
79 | "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
80 | )
81 |
82 | self.out = nn.Sequential(
83 | nn.Conv2d(self.latent_size, 512, 1),
84 | )
85 |
86 | def forward(self, x, custom_size=None):
87 |
88 |
89 | if len(x.shape)>4: return self(x.flatten(0,1),custom_size).unflatten(0,x.shape[:2])
90 |
91 | if self.feature_scale != 1.0:
92 | x = F.interpolate(
93 | x,
94 | scale_factor=self.feature_scale,
95 | mode="bilinear" if self.feature_scale > 1.0 else "area",
96 | align_corners=True if self.feature_scale > 1.0 else None,
97 | recompute_scale_factor=True,
98 | )
99 | x = x.to(device=self.latent.device)
100 |
101 | if self.use_custom_resnet:
102 | self.latent = self.model(x)
103 | else:
104 | x = self.model.conv1(x)
105 | x = self.model.bn1(x)
106 | x = self.model.relu(x)
107 |
108 | latents = [x]
109 | if self.num_layers > 1:
110 | if self.use_first_pool:
111 | x = self.model.maxpool(x)
112 | x = self.model.layer1(x)
113 | latents.append(x)
114 | if self.num_layers > 2:
115 | x = self.model.layer2(x)
116 | latents.append(x)
117 | if self.num_layers > 3:
118 | x = self.model.layer3(x)
119 | latents.append(x)
120 | if self.num_layers > 4:
121 | x = self.model.layer4(x)
122 | latents.append(x)
123 |
124 | self.latents = latents
125 |         align_corners = None if self.index_interp == "nearest" else True
126 | latent_sz = latents[0].shape[-2:]
127 | for i in range(len(latents)):
128 | latents[i] = F.interpolate(
129 | latents[i],
130 | latent_sz if custom_size is None else custom_size,
131 | mode=self.upsample_interp,
132 | align_corners=align_corners,
133 | )
134 | self.latent = torch.cat(latents, dim=1)
135 | self.latent_scaling[0] = self.latent.shape[-1]
136 | self.latent_scaling[1] = self.latent.shape[-2]
137 | self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
138 | return self.out(self.latent)
139 |
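# --- Usage sketch (not part of the original file) ---------------------------
# Encode a dummy batch and upsample the pixel-aligned features to the input
# resolution; with num_layers=4 the concatenated latent has 64+64+128+256=512
# channels, which `self.out` maps to 512.
if __name__ == "__main__":
    enc = PixelNeRFEncoder(backbone="resnet34", pretrained=False, num_layers=4)
    x = torch.randn(2, 3, 128, 144)            # (batch, channels, H, W)
    feats = enc(x, custom_size=(128, 144))     # pixel-aligned feature map
    print(feats.shape)                         # torch.Size([2, 512, 128, 144])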
--------------------------------------------------------------------------------
/data/KITTI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing as mp
3 | import torch.nn.functional as F
4 | import torch
5 | import random
6 | import imageio
7 | import numpy as np
8 | from glob import glob
9 | from collections import defaultdict
10 | from pdb import set_trace as pdb
11 | from itertools import combinations
12 | from random import choice
13 | import matplotlib.pyplot as plt
14 |
15 | from torchvision import transforms
16 | from einops import rearrange, repeat
17 |
18 | import sys
19 |
20 | # Geometry functions below used for calculating depth, ignore
21 | def glob_imgs(path):
22 | imgs = []
23 | for ext in ["*.png", "*.jpg", "*.JPEG", "*.JPG"]:
24 | imgs.extend(glob(os.path.join(path, ext)))
25 | return imgs
26 |
27 |
28 | def pick(list, item_idcs):
29 | if not list:
30 | return list
31 | return [list[i] for i in item_idcs]
32 |
33 |
34 | def parse_intrinsics(intrinsics):
35 | fx = intrinsics[..., 0, :1]
36 | fy = intrinsics[..., 1, 1:2]
37 | cx = intrinsics[..., 0, 2:3]
38 | cy = intrinsics[..., 1, 2:3]
39 | return fx, fy, cx, cy
40 |
41 |
42 | hom = lambda x, i=-1: torch.cat((x, torch.ones_like(x.unbind(i)[0].unsqueeze(i))), i)
43 | ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
44 |
45 | def expand_as(x, y):
46 | if len(x.shape) == len(y.shape):
47 | return x
48 |
49 | for i in range(len(y.shape) - len(x.shape)):
50 | x = x.unsqueeze(-1)
51 |
52 | return x
53 |
54 |
55 | def lift(x, y, z, intrinsics, homogeneous=False):
56 |     """Backproject pixel coordinates to camera-space 3D points.
57 |
58 |     :param x: pixel x-coordinates, shape (batch_size, num_points)
59 |     :param y: pixel y-coordinates, same shape as x
60 |     :param z: depth at each pixel, same shape as x
61 |     :param intrinsics: camera intrinsics, shape (..., 3, 3) or (..., 4, 4)
62 |     :param homogeneous: if True, append a homogeneous coordinate of 1
63 |     :return: points of shape (..., num_points, 4) if homogeneous else (..., num_points, 3)
64 |     """
65 | fx, fy, cx, cy = parse_intrinsics(intrinsics)
66 |
67 | x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
68 | y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
69 |
70 | if homogeneous:
71 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
72 | else:
73 | return torch.stack((x_lift, y_lift, z), dim=-1)
74 |
75 |
76 | def world_from_xy_depth(xy, depth, cam2world, intrinsics):
77 | batch_size, *_ = cam2world.shape
78 |
79 | x_cam = xy[..., 0]
80 | y_cam = xy[..., 1]
81 | z_cam = depth
82 |
83 | pixel_points_cam = lift(
84 | x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True
85 | )
86 | world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[
87 | ..., :3
88 | ]
89 |
90 | return world_coords
91 |
92 |
93 | def get_ray_directions(xy, cam2world, intrinsics, normalize=True):
94 | z_cam = torch.ones(xy.shape[:-1]).to(xy.device)
95 | pixel_points = world_from_xy_depth(
96 | xy, z_cam, intrinsics=intrinsics, cam2world=cam2world
97 | ) # (batch, num_samples, 3)
98 |
99 | cam_pos = cam2world[..., :3, 3]
100 | ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
101 | if normalize:
102 | ray_dirs = F.normalize(ray_dirs, dim=-1)
103 | return ray_dirs
104 |
105 |
106 | class SceneInstanceDataset(torch.utils.data.Dataset):
107 | """This creates a dataset class for a single object instance (such as a single car)."""
108 |
109 | def __init__(
110 | self,
111 | instance_idx,
112 | instance_dir,
113 | specific_observation_idcs=None,
114 | input_img_sidelength=None,
115 | img_sidelength=None,
116 | num_images=None,
117 | cache=None,
118 | raft=None,
119 | low_res=(64,208),
120 | ):
121 | self.instance_idx = instance_idx
122 | self.img_sidelength = img_sidelength
123 | self.input_img_sidelength = input_img_sidelength
124 | self.instance_dir = instance_dir
125 | self.cache = {}
126 |
127 | self.low_res=low_res
128 |
129 | pose_dir = os.path.join(instance_dir, "pose")
130 | color_dir = os.path.join(instance_dir, "image")
131 |
132 | import pykitti
133 |
134 | drive = self.instance_dir.strip("/").split("/")[-1].split("_")[-2]
135 | date = self.instance_dir.strip("/").split("/")[-2]
136 | self.kitti_raw = pykitti.raw(
137 | "/".join(self.instance_dir.rstrip("/").split("/")[:-2]), date, drive
138 | )
139 | self.num_img = len(
140 | os.listdir(
141 |                 os.path.join(self.instance_dir, "image_02/data")
142 | )
143 | )
144 |
145 | self.color_paths = sorted(glob_imgs(color_dir))
146 | self.pose_paths = sorted(glob(os.path.join(pose_dir, "*.txt")))
147 | self.instance_name = os.path.basename(os.path.dirname(self.instance_dir))
148 |
149 | if specific_observation_idcs is not None:
150 | self.color_paths = pick(self.color_paths, specific_observation_idcs)
151 | self.pose_paths = pick(self.pose_paths, specific_observation_idcs)
152 | elif num_images is not None:
153 | idcs = np.linspace(
154 | 0, stop=len(self.color_paths), num=num_images, endpoint=False, dtype=int
155 | )
156 | self.color_paths = pick(self.color_paths, idcs)
157 | self.pose_paths = pick(self.pose_paths, idcs)
158 |
159 | def set_img_sidelength(self, new_img_sidelength):
160 |         """For multi-resolution training: Updates the image sidelength with which images are loaded."""
161 | self.img_sidelength = new_img_sidelength
162 |
163 | def __len__(self):
164 | return self.num_img
165 |
166 | def __getitem__(self, idx, context=False, input_context=True):
167 | # print("trgt load")
168 |
169 | rgb = transforms.ToTensor()(self.kitti_raw.get_cam2(idx)) * 2 - 1
170 |
171 | K = torch.from_numpy(self.kitti_raw.calib.K_cam2.copy())
172 | cam2imu = torch.from_numpy(self.kitti_raw.calib.T_cam2_imu).inverse()
173 | imu2world = torch.from_numpy(self.kitti_raw.oxts[idx].T_w_imu)
174 | cam2world = (imu2world @ cam2imu).float()
175 |
176 | uv = np.mgrid[0 : rgb.size(1), 0 : rgb.size(2)].astype(float).transpose(1, 2, 0)
177 | uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
178 |
179 | # Downsample
180 | h, w = rgb.shape[-2:]
181 |
182 | K = torch.stack((K[0] / w, K[1] / h, K[2])) #normalize intrinsics to be resolution independent
183 |
184 | scale = 2;
185 | lowh, loww = int(64 * scale), int(208 * scale)
186 | med_rgb = F.interpolate( rgb[None], (lowh, loww), mode="bilinear", align_corners=True)[0]
187 | scale = 3;
188 | lowh, loww = int(64 * scale), int(208 * scale)
189 | large_rgb = F.interpolate( rgb[None], (lowh, loww), mode="bilinear", align_corners=True)[0]
190 | uv_large = np.mgrid[0:lowh, 0:loww].astype(float).transpose(1, 2, 0)
191 | uv_large = torch.from_numpy(np.flip(uv_large, axis=-1).copy()).long()
192 | uv_large = uv_large / torch.tensor([loww, lowh]) # uv in [0,1]
193 |
194 | #scale = 1;
195 | lowh, loww = self.low_res#int(64 * scale), int(208 * scale)
196 | rgb = F.interpolate(
197 | rgb[None], (lowh, loww), mode="bilinear", align_corners=True
198 | )[0]
199 | uv = np.mgrid[0:lowh, 0:loww].astype(float).transpose(1, 2, 0)
200 | uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
201 | uv = uv / torch.tensor([loww, lowh]) # uv in [0,1]
202 |
203 | tmp = torch.eye(4)
204 | tmp[:3, :3] = K
205 | K = tmp
206 |
207 | sample = {
208 | "instance_name": self.instance_name,
209 | "instance_idx": torch.Tensor([self.instance_idx]).squeeze().long(),
210 | "cam2world": cam2world,
211 | "img_idx": torch.Tensor([idx]).squeeze().long(),
212 | "img_id": "%s_%02d_%02d" % (self.instance_name, self.instance_idx, idx),
213 | "rgb": rgb,
214 | "large_rgb": large_rgb,
215 | "med_rgb": med_rgb,
216 | "intrinsics": K.float(),
217 | "uv": uv,
218 | "uv_large": uv_large,
219 | }
220 |
221 | return sample
222 |
223 |
224 | def get_instance_datasets(
225 | root,
226 | max_num_instances=None,
227 | specific_observation_idcs=None,
228 | cache=None,
229 | sidelen=None,
230 | max_observations_per_instance=None,
231 | ):
232 | instance_dirs = sorted(glob(os.path.join(root, "*/")))
233 | assert len(instance_dirs) != 0, f"No objects in the directory {root}"
234 |
235 | if max_num_instances != None:
236 | instance_dirs = instance_dirs[:max_num_instances]
237 |
238 | all_instances = [
239 | SceneInstanceDataset(
240 | instance_idx=idx,
241 | instance_dir=dir,
242 | specific_observation_idcs=specific_observation_idcs,
243 | img_sidelength=sidelen,
244 | cache=cache,
245 | num_images=max_observations_per_instance,
246 | )
247 | for idx, dir in enumerate(instance_dirs)
248 | ]
249 | return all_instances
250 |
251 |
252 | class KittiDataset(torch.utils.data.Dataset):
253 | """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""
254 |
255 | def __init__(
256 | self,
257 | num_context=2,
258 | num_trgt=1,
259 | vary_context_number=False,
260 | query_sparsity=None,
261 | img_sidelength=None,
262 | input_img_sidelength=None,
263 | max_num_instances=None,
264 | max_observations_per_instance=None,
265 | specific_observation_idcs=None,
266 | val=False,
267 | test_context_idcs=None,
268 | context_is_last=False,
269 | context_is_first=False,
270 | cache=None,
271 | video=True,
272 | low_res=(64,208),
273 | n_skip=0,
274 | ):
275 |
276 | max_num_instances = None
277 |
278 | root_dir = os.environ['KITTI_ROOT']
279 |
280 | basedirs = list(
281 | filter(lambda x: "20" in x and "zip" not in x, os.listdir(root_dir))
282 | )
283 | drive_paths = []
284 | for basedir in basedirs:
285 | dirs = list(
286 | filter(
287 | lambda x: "txt" not in x,
288 | os.listdir(os.path.join(root_dir, basedir)),
289 | )
290 | )
291 | drive_paths += [
292 | os.path.abspath(os.path.join(root_dir, basedir, dir_)) for dir_ in dirs
293 | ]
294 | self.instance_dirs = sorted(drive_paths)
295 |
296 | if type(n_skip)==type([]):n_skip=n_skip[0]
297 | self.n_skip = n_skip+1
298 | self.num_context = num_context
299 | self.num_trgt = num_trgt
300 | self.query_sparsity = query_sparsity
301 | self.img_sidelength = img_sidelength
302 | self.vary_context_number = vary_context_number
303 | self.cache = {}
304 | self.test = val
305 | self.test_context_idcs = test_context_idcs
306 | self.context_is_last = context_is_last
307 | self.context_is_first = context_is_first
308 |
309 | print(f"Root dir {root_dir}, {len(self.instance_dirs)} instances")
310 |
311 | assert len(self.instance_dirs) != 0, "No objects in the data directory"
312 |
313 | self.max_num_instances = max_num_instances
314 | if max_num_instances == 1:
315 | self.instance_dirs = [
316 | x for x in self.instance_dirs if "2011_09_26_drive_0027_sync" in x
317 | ]
318 | print("note testing single dir") # testing dir
319 |
320 | self.all_instances = [
321 | SceneInstanceDataset(
322 | instance_idx=idx,
323 | instance_dir=dir,
324 | specific_observation_idcs=specific_observation_idcs,
325 | img_sidelength=img_sidelength,
326 | input_img_sidelength=input_img_sidelength,
327 | num_images=max_observations_per_instance,
328 | cache=cache,
329 | low_res=low_res,
330 | )
331 | for idx, dir in enumerate(self.instance_dirs)
332 | ]
333 | self.all_instances = [x for x in self.all_instances if len(x) > 40]
334 | if max_num_instances is not None:
335 | self.all_instances = self.all_instances[:max_num_instances]
336 |
337 | test_idcs = list(range(len(self.all_instances)))[::8]
338 | self.all_instances = [x for i,x in enumerate(self.all_instances) if (i in test_idcs and val) or (i not in test_idcs and not val)]
339 | print("validation: ",val,len(self.all_instances))
340 |
341 | self.num_per_instance_observations = [len(obj) for obj in self.all_instances]
342 | self.num_instances = len(self.all_instances)
343 |
344 | self.instance_img_pairs = []
345 | for i,instance_dir in enumerate(self.all_instances):
346 | for j in range(len(instance_dir)-n_skip*(num_trgt+1)):
347 | self.instance_img_pairs.append((i,j))
348 |
349 | def sparsify(self, dict, sparsity):
350 | new_dict = {}
351 | if sparsity is None:
352 | return dict
353 | else:
354 | # Sample upper_limit pixel idcs at random.
355 | rand_idcs = np.random.choice(
356 | self.img_sidelength ** 2, size=sparsity, replace=False
357 | )
358 | for key in ["rgb", "uv"]:
359 | new_dict[key] = dict[key][rand_idcs]
360 |
361 | for key, v in dict.items():
362 | if key not in ["rgb", "uv"]:
363 | new_dict[key] = dict[key]
364 |
365 | return new_dict
366 |
367 | def set_img_sidelength(self, new_img_sidelength):
368 | """For multi-resolution training: Updates the image sidelength with which images are loaded."""
369 | self.img_sidelength = new_img_sidelength
370 | for instance in self.all_instances:
371 | instance.set_img_sidelength(new_img_sidelength)
372 |
373 | def __len__(self):
374 | return len(self.instance_img_pairs)
375 |
376 | def get_instance_idx(self, idx):
377 | if self.test:
378 | obj_idx = 0
379 | while idx >= 0:
380 | idx -= self.num_per_instance_observations[obj_idx]
381 | obj_idx += 1
382 | return (
383 | obj_idx - 1,
384 | int(idx + self.num_per_instance_observations[obj_idx - 1]),
385 | )
386 | else:
387 | return np.random.randint(self.num_instances), 0
388 |
389 | def collate_fn(self, batch_list):
390 | keys = batch_list[0].keys()
391 | result = defaultdict(list)
392 |
393 | for entry in batch_list:
394 | # make them all into a new dict
395 | for key in keys:
396 | result[key].append(entry[key])
397 |
398 | for key in keys:
399 | try:
400 | result[key] = torch.stack(result[key], dim=0)
401 | except:
402 | continue
403 |
404 | return result
405 |
406 | def getframe(self, obj_idx, x):
407 | return (
408 | self.all_instances[obj_idx].__getitem__(
409 | x, context=True, input_context=True
410 | ),
411 | x,
412 | )
413 |
414 | def __getitem__(self, idx, sceneidx=None):
415 |
416 | context = []
417 | trgt = []
418 | post_input = []
419 | #obj_idx,det_idx= np.random.randint(self.num_instances), 0
420 | of=0
421 |
422 | obj_idx, i = self.instance_img_pairs[idx]
423 |
424 | if of: obj_idx = 0
425 |
426 | if sceneidx is not None:
427 | obj_idx, det_idx = sceneidx[0], sceneidx[0]
428 |
429 | if len(self.all_instances[obj_idx])<=i+self.num_trgt*self.n_skip:
430 | i=0
431 | if sceneidx is not None:
432 | i=sceneidx[1]
433 | for _ in range(self.num_trgt):
434 | if sceneidx is not None:
435 | print(i)
436 | i += self.n_skip
437 | sample = self.all_instances[obj_idx].__getitem__(
438 | i, context=True, input_context=True
439 | )
440 | post_input.append(sample)
441 | post_input[-1]["mask"] = torch.Tensor([1.0])
442 | sub_sample = self.sparsify(sample, self.query_sparsity)
443 | trgt.append(sub_sample)
444 |
445 | post_input = self.collate_fn(post_input)
446 | trgt = self.collate_fn(trgt)
447 |
448 | out_dict = {"query": trgt, "post_input": post_input, "context": None}, trgt
449 |
450 | imgs = trgt["rgb"]
451 | imgs_large = (trgt["large_rgb"]*.5+.5)*255
452 |         imgs_med = (trgt["med_rgb"]*.5+.5)*255
453 | Ks = trgt["intrinsics"][:,:3,:3]
454 | uv = trgt["uv"].flatten(1,2)
455 |
456 | #imgs large in [0,255],
457 | #imgs in [-1,1],
458 | #gt_rgb in [0,1],
459 | model_input = {
460 | "trgt_rgb": imgs[1:],
461 | "ctxt_rgb": imgs[:-1],
462 | "trgt_rgb_large": imgs_large[1:],
463 | "ctxt_rgb_large": imgs_large[:-1],
464 | "trgt_rgb_med": imgs_med[1:],
465 | "ctxt_rgb_med": imgs_med[:-1],
466 | "intrinsics": Ks[1:],
467 | "x_pix": uv[1:],
468 | "trgt_c2w": trgt["cam2world"][1:],
469 | "ctxt_c2w": trgt["cam2world"][:-1],
470 | }
471 | gt = {
472 | "trgt_rgb": ch_sec(imgs[1:])*.5+.5,
473 | "ctxt_rgb": ch_sec(imgs[:-1])*.5+.5,
474 | "intrinsics": Ks[1:],
475 | "x_pix": uv[1:],
476 | }
477 | return model_input,gt
478 |
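# --- Usage sketch (not part of the original file) ---------------------------
# Load one training clip; assumes KITTI_ROOT points at the extracted KITTI raw
# recordings and that pykitti is installed.
if __name__ == "__main__":
    dset = KittiDataset(num_trgt=8, low_res=(64, 208), n_skip=0)
    model_input, gt = dset[0]
    print(model_input["trgt_rgb"].shape)   # expected: torch.Size([7, 3, 64, 208])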
--------------------------------------------------------------------------------
/data/co3d.py:
--------------------------------------------------------------------------------
1 | # note for davis dataloader later: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
2 | # note for cool idea of not even downloading data and just streaming from youtube:https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
3 | import matplotlib.pyplot as plt;
4 | import cv2
5 | import os
6 | import multiprocessing as mp
7 | import torch.nn.functional as F
8 | import torch
9 | import random
10 | import imageio
11 | import numpy as np
12 | from glob import glob
13 | from collections import defaultdict
14 | from pdb import set_trace as pdb
15 | from itertools import combinations
16 | from random import choice
17 | import matplotlib.pyplot as plt
18 | import imageio.v3 as iio
19 |
20 | from torchvision import transforms
21 |
22 | import sys
23 |
24 | from glob import glob
25 | import os
26 | import gzip
27 | import json
28 | import numpy as np
29 |
30 | from PIL import Image
31 | def _load_16big_png_depth(depth_png) -> np.ndarray:
32 | with Image.open(depth_png) as depth_pil:
33 | # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
34 | # we cast it to uint16, then reinterpret as float16, then cast to float32
35 | depth = (
36 | np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
37 | .astype(np.float32)
38 | .reshape((depth_pil.size[1], depth_pil.size[0]))
39 | )
40 | return depth
41 | def _load_depth(path, scale_adjustment) -> np.ndarray:
42 | d = _load_16big_png_depth(path) * scale_adjustment
43 | d[~np.isfinite(d)] = 0.0
44 | return d[None] # fake feature channel
45 |
46 | # Geometry functions below used for calculating depth, ignore
47 | def glob_imgs(path):
48 | imgs = []
49 | for ext in ["*.png", "*.jpg", "*.JPEG", "*.JPG"]:
50 | imgs.extend(glob(os.path.join(path, ext)))
51 | return imgs
52 |
53 |
54 | def pick(list, item_idcs):
55 | if not list:
56 | return list
57 | return [list[i] for i in item_idcs]
58 |
59 |
60 | def parse_intrinsics(intrinsics):
61 | fx = intrinsics[..., 0, :1]
62 | fy = intrinsics[..., 1, 1:2]
63 | cx = intrinsics[..., 0, 2:3]
64 | cy = intrinsics[..., 1, 2:3]
65 | return fx, fy, cx, cy
66 |
67 |
68 | from einops import rearrange, repeat
69 | ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
70 | hom = lambda x, i=-1: torch.cat((x, torch.ones_like(x.unbind(i)[0].unsqueeze(i))), i)
71 |
72 |
73 | def expand_as(x, y):
74 | if len(x.shape) == len(y.shape):
75 | return x
76 |
77 | for i in range(len(y.shape) - len(x.shape)):
78 | x = x.unsqueeze(-1)
79 |
80 | return x
81 |
82 |
83 | def lift(x, y, z, intrinsics, homogeneous=False):
84 |     """Backproject pixel coordinates to camera-space 3D points.
85 |
86 |     :param x: pixel x-coordinates, shape (batch_size, num_points)
87 |     :param y: pixel y-coordinates, same shape as x
88 |     :param z: depth at each pixel, same shape as x
89 |     :param intrinsics: camera intrinsics, shape (..., 3, 3) or (..., 4, 4)
90 |     :param homogeneous: if True, append a homogeneous coordinate of 1
91 |     :return: points of shape (..., num_points, 4) if homogeneous else (..., num_points, 3)
92 |     """
93 | fx, fy, cx, cy = parse_intrinsics(intrinsics)
94 |
95 | x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
96 | y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
97 |
98 | if homogeneous:
99 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
100 | else:
101 | return torch.stack((x_lift, y_lift, z), dim=-1)
102 |
103 |
104 | def world_from_xy_depth(xy, depth, cam2world, intrinsics):
105 | batch_size, *_ = cam2world.shape
106 |
107 | x_cam = xy[..., 0]
108 | y_cam = xy[..., 1]
109 | z_cam = depth
110 |
111 | pixel_points_cam = lift(
112 | x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True
113 | )
114 | world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[
115 | ..., :3
116 | ]
117 |
118 | return world_coords
119 |
120 |
121 | def get_ray_directions(xy, cam2world, intrinsics, normalize=True):
122 | z_cam = torch.ones(xy.shape[:-1]).to(xy.device)
123 | pixel_points = world_from_xy_depth(
124 | xy, z_cam, intrinsics=intrinsics, cam2world=cam2world
125 | ) # (batch, num_samples, 3)
126 |
127 | cam_pos = cam2world[..., :3, 3]
128 | ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
129 | if normalize:
130 | ray_dirs = F.normalize(ray_dirs, dim=-1)
131 | return ray_dirs
132 |
148 |
149 | # NOTE currently using CO3D V1 because they switch to NDC cameras in 2. TODO is to make conversion code (different intrinsics), verify pointclouds, and switch.
150 |
151 | class Co3DNoCams(torch.utils.data.Dataset):
152 | """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""
153 |
154 | def __init__(
155 | self,
156 | num_context=2,
157 | n_skip=1,
158 | num_trgt=1,
159 | low_res=(128,144),
160 | depth_scale=1,#1.8/5,
161 | val=False,
162 | num_cat=1000,
163 | overfit=False,
164 | category=None,
165 | use_mask=False,
166 | use_v1=True,
167 | # delete below, not used
168 | vary_context_number=False,
169 | query_sparsity=None,
170 | img_sidelength=None,
171 | input_img_sidelength=None,
172 | max_num_instances=None,
173 | max_observations_per_instance=None,
174 | specific_observation_idcs=None,
175 | test=False,
176 | test_context_idcs=None,
177 | context_is_last=False,
178 | context_is_first=False,
179 | cache=None,
180 | video=True,
181 | ):
182 |
183 | if num_cat is None: num_cat=1000
184 |
185 | self.n_trgt=num_trgt
186 | self.use_mask=use_mask
187 | self.depth_scale=depth_scale
188 | self.of=overfit
189 | self.val=val
190 |
191 | self.num_skip=n_skip
192 | self.low_res=low_res
193 | max_num_instances = None
194 |
195 | self.base_path=os.environ['CO3D_ROOT']
196 | print(self.base_path)
197 |
198 | # Get sequences!
199 | from collections import defaultdict
200 | sequences = defaultdict(list)
201 | self.total_num_data=0
202 | self.all_frame_names=[]
203 | all_cats = [ "hydrant","teddybear","apple", "ball", "bench", "cake", "donut", "plant", "suitcase", "vase","backpack", "banana", "baseballbat", "baseballglove", "bicycle", "book", "bottle", "bowl", "broccoli", "car", "carrot", "cellphone", "chair", "couch", "cup", "frisbee", "hairdryer", "handbag", "hotdog", "keyboard", "kite", "laptop", "microwave", "motorcycle", "mouse", "orange", "parkingmeter", "pizza", "remote", "sandwich", "skateboard", "stopsign", "toaster", "toilet", "toybus", "toyplane", "toytrain", "toytruck", "tv", "umbrella", "wineglass", ]
204 |
205 | for cat in (all_cats[:num_cat]) if category is None else [category]:
206 | print(cat)
207 | dataset = json.loads(gzip.GzipFile(os.path.join(self.base_path,cat,"frame_annotations.jgz"),"rb").read().decode("utf8"))
208 | val_amt = int(len(dataset)*.03)
209 | dataset = dataset[:-val_amt] if not val else dataset[-val_amt:]
210 | self.total_num_data+=len(dataset)
211 | for i,data in enumerate(dataset):
212 | self.all_frame_names.append((data["sequence_name"],data["frame_number"]))
213 | sequences[data["sequence_name"]].append(data)
214 |
215 | sorted_seq={}
216 | for k,v in sequences.items():
217 | sorted_seq[k]=sorted(sequences[k],key=lambda x:x["frame_number"])
218 | #for k,v in sequences.items(): sequences[k]=v[:-(max(self.num_skip) if type(self.num_skip)==list else self.num_skip)*self.n_trgt]
219 | self.seqs = sorted_seq
220 |
221 | print("done with dataloader init")
222 |
223 | def sparsify(self, dict, sparsity):
224 | new_dict = {}
225 | if sparsity is None:
226 | return dict
227 | else:
228 | # Sample upper_limit pixel idcs at random.
229 | rand_idcs = np.random.choice(
230 | self.img_sidelength ** 2, size=sparsity, replace=False
231 | )
232 | for key in ["rgb", "uv"]:
233 | new_dict[key] = dict[key][rand_idcs]
234 |
235 | for key, v in dict.items():
236 | if key not in ["rgb", "uv"]:
237 | new_dict[key] = dict[key]
238 |
239 | return new_dict
240 |
241 | def set_img_sidelength(self, new_img_sidelength):
242 | """For multi-resolution training: Updates the image sidelength with which images are loaded."""
243 | self.img_sidelength = new_img_sidelength
244 | for instance in self.all_instances:
245 | instance.set_img_sidelength(new_img_sidelength)
246 |
247 | def __len__(self):
248 | return self.total_num_data
249 |
250 | def collate_fn(self, batch_list):
251 | keys = batch_list[0].keys()
252 | result = defaultdict(list)
253 |
254 | for entry in batch_list:
255 | # make them all into a new dict
256 | for key in keys:
257 | result[key].append(entry[key])
258 |
259 | for key in keys:
260 | try:
261 | result[key] = torch.stack(result[key], dim=0)
262 | except:
263 | continue
264 |
265 | return result
266 |
267 | def __getitem__(self, idx,seq_query=None):
268 |
269 | context = []
270 | trgt = []
271 | post_input = []
272 |
273 | n_skip = (random.choice(self.num_skip) if type(self.num_skip)==list else self.num_skip) + 1
274 |
275 | if seq_query is None:
276 | try:
277 | seq_name,frame_idx=self.all_frame_names[idx]
278 | except:
279 |                 print(f"Out of bounds error at {idx}. Investigate.")
280 | return self[-2*n_skip*self.n_trgt if self.val else np.random.randint(len(self))]
281 |
282 | if seq_query is not None:
283 | frame_idx=idx
284 | seq_name = list(self.seqs.keys())[seq_query]
285 | all_frames= self.seqs[seq_name]
286 | else:
287 | all_frames=self.seqs[seq_name] if not self.of else self.seqs[random.choice(list(self.seqs.keys())[:int(self.of)])]
288 |
289 | if len(all_frames)<=self.n_trgt*n_skip or frame_idx >= (len(all_frames)-self.n_trgt*n_skip):
290 | frame_idx=0
291 | if len(all_frames)<=self.n_trgt*n_skip or frame_idx >= (len(all_frames)-self.n_trgt*n_skip):
292 | if len(all_frames)<=self.n_trgt*n_skip:
293 | print(len(all_frames) ," frames < ",self.n_trgt*n_skip," queries")
294 | print("returning low/high")
295 | return self[-2*n_skip*self.n_trgt if self.val else np.random.randint(len(self))]
296 | start_idx = frame_idx
297 |
298 | if self.of and 1: start_idx=0
299 |
300 | frames = all_frames[start_idx:start_idx+self.n_trgt*n_skip:n_skip]
301 | if np.random.rand()<.5 and not self.of and not self.val: frames=frames[::-1]
302 |
303 | paths = [os.path.join(self.base_path,x["image"]["path"]) for x in frames]
304 | for path in paths:
305 | if not os.path.exists(path):
306 | print("path missing")
307 | return self[np.random.randint(len(self))]
308 |
309 | #masks=[torch.from_numpy(plt.imread(os.path.join(self.base_path,x["mask"]["path"]))) for x in frames]
310 | imgs=[torch.from_numpy(plt.imread(path)) for path in paths]
311 |
312 | Ks=[]
313 | c2ws=[]
314 | depths=[]
315 | for data in frames:
316 |
317 | #depths.append(torch.from_numpy(_load_depth(os.path.join(self.base_path,data["depth"]["path"]), data["depth"]["scale_adjustment"])[0])) # commenting out since slow to load; uncomment when needed
318 |
319 | # Below pose processing taken from co3d github issue
320 | p = data["viewpoint"]["principal_point"]
321 | f = data["viewpoint"]["focal_length"]
322 | h, w = data["image"]["size"]
323 | K = np.eye(3)
324 | s = (min(h, w)) / 2
325 | K[0, 0] = f[0] * (w) / 2
326 | K[1, 1] = f[1] * (h) / 2
327 | K[0, 2] = -p[0] * s + (w) / 2
328 | K[1, 2] = -p[1] * s + (h) / 2
329 |
330 | # Normalize intrinsics to [-1,1]
331 | #print(K)
332 | raw_K=[torch.from_numpy(K).clone(),[h,w]]
333 | K[:2] /= torch.tensor([w, h])[:, None]
334 | Ks.append(torch.from_numpy(K).float())
335 |
336 | R = np.asarray(data["viewpoint"]["R"]).T # note the transpose here
337 | T = np.asarray(data["viewpoint"]["T"]) * self.depth_scale
338 | pose = np.concatenate([R,T[:,None]],1)
339 | pose = torch.from_numpy( np.diag([-1,-1,1]).astype(np.float32) @ pose )# flip the direction of x,y axis
340 | tmp=torch.eye(4)
341 | tmp[:3,:4]=pose
342 | c2ws.append(tmp.inverse())
343 |
344 | Ks=torch.stack(Ks)
345 | c2w=torch.stack(c2ws).float()
346 |
347 | no_mask=0
348 | if no_mask:
349 | masks=[x*0+1 for x in masks]
350 |
351 | low_res=self.low_res#(128,144)#(108,144)
352 | minx,miny=min([x.size(0) for x in imgs]),min([x.size(1) for x in imgs])
353 |
354 | imgs=[x[:minx,:miny].float() for x in imgs]
355 |
356 | if self.use_mask: # mask images and depths
357 | imgs = [x*y.unsqueeze(-1)+(255*(1-y).unsqueeze(-1)) for x,y in zip(imgs,masks)]
358 | depths = [x*y for x,y in zip(depths,masks)]
359 |
360 | large_scale=2
361 | imgs_large = F.interpolate(torch.stack([x.permute(2,0,1) for x in imgs]),(int(256*large_scale),int(288*large_scale)),antialias=True,mode="bilinear")
362 | imgs_med = F.interpolate(torch.stack([x.permute(2,0,1) for x in imgs]),(int(256),int(288)),antialias=True,mode="bilinear")
363 | imgs = F.interpolate(torch.stack([x.permute(2,0,1) for x in imgs]),low_res,antialias=True,mode="bilinear")
364 |
365 | if self.use_mask:
366 | imgs = imgs*masks[:,None]+255*(1-masks[:,None])
367 |
368 | imgs = imgs/255 * 2 - 1
369 |
370 | uv = np.mgrid[0:low_res[0], 0:low_res[1]].astype(float).transpose(1, 2, 0)
371 | uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
372 | uv = uv/ torch.tensor([low_res[1]-1, low_res[0]-1]) # uv in [0,1]
373 | uv = uv[None].expand(len(imgs),-1,-1,-1).flatten(1,2)
374 |
375 | model_input = {
376 | "trgt_rgb": imgs[1:],
377 | "ctxt_rgb": imgs[:-1],
378 | "trgt_rgb_large": imgs_large[1:],
379 | "ctxt_rgb_large": imgs_large[:-1],
380 | "trgt_rgb_med": imgs_med[1:],
381 | "ctxt_rgb_med": imgs_med[:-1],
382 | #"ctxt_depth": depths.squeeze(1)[:-1],
383 | #"trgt_depth": depths.squeeze(1)[1:],
384 | "intrinsics": Ks[1:],
385 | "trgt_c2w": c2w[1:],
386 | "ctxt_c2w": c2w[:-1],
387 | "x_pix": uv[1:],
388 | #"trgt_mask": masks[1:],
389 | #"ctxt_mask": masks[:-1],
390 | }
391 |
392 | gt = {
393 | #"paths": paths,
394 | #"raw_K": raw_K,
395 | #"seq_name": seq_name,
396 | "trgt_rgb": ch_sec(imgs[1:])*.5+.5,
397 | "ctxt_rgb": ch_sec(imgs[:-1])*.5+.5,
398 | #"ctxt_depth": depths.squeeze(1)[:-1].flatten(1,2).unsqueeze(-1),
399 | #"trgt_depth": depths.squeeze(1)[1:].flatten(1,2).unsqueeze(-1),
400 | "intrinsics": Ks[1:],
401 | "x_pix": uv[1:],
402 | #"seq_name": [seq_name],
403 | #"trgt_mask": masks[1:].flatten(1,2).unsqueeze(-1),
404 | #"ctxt_mask": masks[:-1].flatten(1,2).unsqueeze(-1),
405 | }
406 |
407 | return model_input,gt
408 |
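# --- Usage sketch (not part of the original file) ---------------------------
# Load one hydrant clip; assumes CO3D_ROOT points at the CO3D-v2 download
# (with per-category frame_annotations.jgz files).
if __name__ == "__main__":
    dset = Co3DNoCams(num_trgt=8, n_skip=1, low_res=(128, 144), category="hydrant")
    model_input, gt = dset[0]
    print(model_input["trgt_rgb"].shape)   # expected: torch.Size([7, 3, 128, 144])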
--------------------------------------------------------------------------------
/data/realestate10k_dataio.py:
--------------------------------------------------------------------------------
1 | import random
2 | from torch.nn import functional as F
3 | import os
4 | import torch
5 | import numpy as np
6 | from glob import glob
7 | import json
8 | from collections import defaultdict
9 | import os.path as osp
10 | from imageio import imread
11 | from torch.utils.data import Dataset
12 | from pathlib import Path
13 | import cv2
14 | from tqdm import tqdm
15 | from scipy.io import loadmat
16 |
17 | import functools
18 | import cv2
19 | import numpy as np
20 | import imageio
21 | from glob import glob
22 | import os
23 | import shutil
24 | import io
25 |
26 | not_of=1
27 |
28 | def load_rgb(path, sidelength=None):
29 | img = imageio.imread(path)[:, :, :3]
30 | img = skimage.img_as_float32(img)
31 |
32 | img = square_crop_img(img)
33 |
34 | if sidelength is not None:
35 | img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_NEAREST)
36 |
37 | img -= 0.5
38 | img *= 2.
39 | return img
40 |
41 | def load_depth(path, sidelength=None):
42 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
43 |
44 | if sidelength is not None:
45 | img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_NEAREST)
46 |
47 | img *= 1e-4
48 |
49 | if len(img.shape) == 3:
50 | img = img[:, :, :1]
51 | img = img.transpose(2, 0, 1)
52 | else:
53 | img = img[None, :, :]
54 | return img
55 |
56 |
57 | def load_pose(filename):
58 | lines = open(filename).read().splitlines()
59 | if len(lines) == 1:
60 | pose = np.zeros((4, 4), dtype=np.float32)
61 | for i in range(16):
62 | pose[i // 4, i % 4] = lines[0].split(" ")[i]
63 | return pose.squeeze()
64 | else:
65 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines[:4])]
66 | return np.asarray(lines).astype(np.float32).squeeze()
67 |
68 |
69 | def load_numpy_hdf5(instance_ds, key):
70 | rgb_ds = instance_ds['rgb']
71 | raw = rgb_ds[key][...]
72 |     s = raw.tobytes()
73 | f = io.BytesIO(s)
74 |
75 | img = imageio.imread(f)[:, :, :3]
76 | img = skimage.img_as_float32(img)
77 |
78 | img = square_crop_img(img)
79 |
80 | img -= 0.5
81 | img *= 2.
82 |
83 | return img
84 |
85 |
86 | def load_rgb_hdf5(instance_ds, key, sidelength=None):
87 | rgb_ds = instance_ds['rgb']
88 | raw = rgb_ds[key][...]
89 |     s = raw.tobytes()
90 | f = io.BytesIO(s)
91 |
92 | img = imageio.imread(f)[:, :, :3]
93 | img = skimage.img_as_float32(img)
94 |
95 | img = square_crop_img(img)
96 |
97 | if sidelength is not None:
98 | img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_AREA)
99 |
100 | img -= 0.5
101 | img *= 2.
102 |
103 | return img
104 |
105 |
106 | def load_pose_hdf5(instance_ds, key):
107 | pose_ds = instance_ds['pose']
108 | raw = pose_ds[key][...]
109 | ba = bytearray(raw)
110 | s = ba.decode('ascii')
111 |
112 | lines = s.splitlines()
113 |
114 | if len(lines) == 1:
115 | pose = np.zeros((4, 4), dtype=np.float32)
116 | for i in range(16):
117 | pose[i // 4, i % 4] = lines[0].split(" ")[i]
118 | # processed_pose = pose.squeeze()
119 | return pose.squeeze()
120 | else:
121 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines[:4])]
122 | return np.asarray(lines).astype(np.float32).squeeze()
123 |
124 |
125 | def cond_mkdir(path):
126 | if not os.path.exists(path):
127 | os.makedirs(path)
128 |
129 |
130 | def square_crop_img(img):
131 | min_dim = np.amin(img.shape[:2])
132 | center_coord = np.array(img.shape[:2]) // 2
133 | img = img[center_coord[0] - min_dim // 2:center_coord[0] + min_dim // 2,
134 | center_coord[1] - min_dim // 2:center_coord[1] + min_dim // 2]
135 | return img
136 |
137 |
138 | def glob_imgs(path):
139 | imgs = []
140 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
141 | imgs.extend(glob(os.path.join(path, ext)))
142 | return imgs
143 |
144 | def augment(rgb, intrinsics, c2w_mat):
145 |
146 | # Horizontal Flip with 50% Probability
147 | if np.random.uniform(0, 1) < 0.5:
148 | rgb = rgb[:, ::-1, :]
149 | tf_flip = np.array([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
150 | c2w_mat = c2w_mat @ tf_flip
151 |
152 | # Crop by aspect ratio
153 | if np.random.uniform(0, 1) < 0.5:
154 | py = np.random.randint(1, 32)
155 | rgb = rgb[py:-py, :, :]
156 | else:
157 | py = 0
158 |
159 | if np.random.uniform(0, 1) < 0.5:
160 | px = np.random.randint(1, 32)
161 | rgb = rgb[:, px:-px, :]
162 | else:
163 | px = 0
164 |
165 | H, W, _ = rgb.shape
166 | rgb = cv2.resize(rgb, (256, 256))
167 | xscale = 256 / W
168 | yscale = 256 / H
169 |
170 | intrinsics[0, 0] = intrinsics[0, 0] * xscale
171 | intrinsics[1, 1] = intrinsics[1, 1] * yscale
172 |
173 | return rgb, intrinsics, c2w_mat
174 |
175 | class Camera(object):
176 | def __init__(self, entry):
177 | fx, fy, cx, cy = entry[1:5]
178 | self.intrinsics = np.array([[fx, 0, cx, 0],
179 | [0, fy, cy, 0],
180 | [0, 0, 1, 0],
181 | [0, 0, 0, 1]])
182 | w2c_mat = np.array(entry[7:]).reshape(3, 4)
183 | w2c_mat_4x4 = np.eye(4)
184 | w2c_mat_4x4[:3, :] = w2c_mat
185 | self.w2c_mat = w2c_mat_4x4
186 | self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
187 |
188 |
189 | def unnormalize_intrinsics(intrinsics, h, w):
190 | intrinsics = intrinsics.copy()
191 | intrinsics[0] *= w
192 | intrinsics[1] *= h
193 | return intrinsics
194 |
195 |
196 | def parse_pose_file(file):
197 | f = open(file, 'r')
198 | cam_params = {}
199 | for i, line in enumerate(f):
200 | if i == 0:
201 | continue
202 | entry = [float(x) for x in line.split()]
203 | id = int(entry[0])
204 | cam_params[id] = Camera(entry)
205 | return cam_params
206 |
207 |
208 | def parse_pose(pose, timestep):
209 | timesteps = pose[:, :1]
210 | timesteps = np.around(timesteps)
211 | mask = (timesteps == timestep)[:, 0]
212 | pose_entry = pose[mask][0]
213 | camera = Camera(pose_entry)
214 |
215 | return camera
216 |
217 |
218 | def get_camera_pose(scene_path, all_pose_dir, uv, views=1):
219 | npz_files = sorted(scene_path.glob("*.npz"))
220 | npz_file = npz_files[0]
221 | data = np.load(npz_file)
222 | all_pose_dir = Path(all_pose_dir)
223 |
224 | rgb_files = list(data.keys())
225 |
226 | timestamps = [int(rgb_file.split('.')[0]) for rgb_file in rgb_files]
227 | sorted_ids = np.argsort(timestamps)
228 |
229 | rgb_files = np.array(rgb_files)[sorted_ids]
230 | timestamps = np.array(timestamps)[sorted_ids]
231 |
232 | camera_file = all_pose_dir / (str(scene_path.name) + '.txt')
233 | cam_params = parse_pose_file(camera_file)
234 | # H, W, _ = data[rgb_files[0]].shape
235 |
236 | # Weird cropping of images
237 | H, W = 256, 456
238 |
239 | xscale = W / min(H, W)
240 | yscale = H / min(H, W)
241 |
242 |
243 | query = {}
244 | context = {}
245 |
246 | render_frame = min(128, rgb_files.shape[0])
247 |
248 | query_intrinsics = []
249 | query_c2w = []
250 | query_rgbs = []
251 | for i in range(1, render_frame):
252 | rgb = data[rgb_files[i]]
253 | timestep = timestamps[i]
254 |
255 | # rgb = cv2.resize(rgb, (W, H))
256 | intrinsics = unnormalize_intrinsics(cam_params[timestep].intrinsics, H, W)
257 |
258 | intrinsics[0, 2] = intrinsics[0, 2] / xscale
259 | intrinsics[1, 2] = intrinsics[1, 2] / yscale
260 | rgb = rgb.astype(np.float32) / 127.5 - 1
261 |
262 | query_intrinsics.append(intrinsics)
263 | query_c2w.append(cam_params[timestep].c2w_mat)
264 | query_rgbs.append(rgb)
265 |
266 | context_intrinsics = []
267 | context_c2w = []
268 | context_rgbs = []
269 |
270 | if views == 1:
271 | render_ids = [0]
272 | elif views == 2:
273 | render_ids = [0, min(len(rgb_files) - 1, 128)]
274 | else:
275 | assert False
276 |
277 | for i in render_ids:
278 | rgb = data[rgb_files[i]]
279 | timestep = timestamps[i]
280 | # print("render: ", i)
281 | # rgb = cv2.resize(rgb, (W, H))
282 | intrinsics = unnormalize_intrinsics(cam_params[timestep].intrinsics, H, W)
283 | intrinsics[0, 2] = intrinsics[0, 2] / xscale
284 | intrinsics[1, 2] = intrinsics[1, 2] / yscale
285 |
286 | rgb = rgb.astype(np.float32) / 127.5 - 1
287 |
288 | context_intrinsics.append(intrinsics)
289 | context_c2w.append(cam_params[timestep].c2w_mat)
290 | context_rgbs.append(rgb)
291 |
292 | query = {'rgb': torch.Tensor(query_rgbs)[None].float(),
293 | 'cam2world': torch.Tensor(query_c2w)[None].float(),
294 | 'intrinsics': torch.Tensor(query_intrinsics)[None].float(),
295 | 'uv': uv.view(-1, 2)[None, None].expand(1, render_frame - 1, -1, -1)}
296 | ctxt = {'rgb': torch.Tensor(context_rgbs)[None].float(),
297 | 'cam2world': torch.Tensor(context_c2w)[None].float(),
298 | 'intrinsics': torch.Tensor(context_intrinsics)[None].float()}
299 |
300 | return {'query': query, 'context': ctxt}
301 |
302 | class RealEstate10k():
303 | def __init__(self, img_root=None, pose_root=None,
304 | num_ctxt_views=2, num_query_views=2, query_sparsity=None,imsl=256,
305 | max_num_scenes=None, square_crop=True, augment=False, lpips=False, dual_view=False, val=False,n_skip=12):
306 |
307 | self.n_skip =n_skip[0] if type(n_skip)==type([]) else n_skip
308 | print(self.n_skip,"n_skip")
309 | self.val = val
310 | if img_root is None: img_root = os.path.join(os.environ['RE10K_IMG_ROOT'],"test" if val else "train")
311 | if pose_root is None: pose_root = os.path.join(os.environ['RE10K_POSE_ROOT'],"test" if val else "train")
312 | print("Loading RealEstate10k...")
313 | self.num_ctxt_views = num_ctxt_views
314 | self.num_query_views = num_query_views
315 | self.query_sparsity = query_sparsity
316 | self.dual_view = dual_view
317 |
318 | self.imsl=imsl
319 |
320 | all_im_dir = Path(img_root)
321 | #self.all_pose_dir = Path(pose_root)
322 | self.all_pose = loadmat(pose_root)
323 | self.lpips = lpips
324 |
325 | self.all_scenes = sorted(all_im_dir.glob('*/'))
326 |
327 | dummy_img_path = str(next(self.all_scenes[0].glob("*.npz")))
328 |
329 | if max_num_scenes:
330 | self.all_scenes = list(self.all_scenes)[:max_num_scenes]
331 |
332 | data = np.load(dummy_img_path)
333 | key = list(data.keys())[0]
334 | im = data[key]
335 |
336 | H, W = im.shape[:2]
337 | H, W = 256, 455
338 | self.H, self.W = H, W
339 | self.augment = augment
340 |
341 | self.square_crop = square_crop
342 | # Downsample to be 256 x 256 image
343 | # self.H, self.W = 256, 455
344 |
345 | xscale = W / min(H, W)
346 | yscale = H / min(H, W)
347 |
348 | dim = min(H, W)
349 |
350 | self.xscale = xscale
351 | self.yscale = yscale
352 |
353 | # For now the images are already square cropped
354 | self.H = 256
355 | self.W = 455
356 |
357 | print(f"Resolution is {H}, {W}.")
358 |
359 | if self.square_crop:
360 | i, j = torch.meshgrid(torch.arange(0, self.imsl), torch.arange(0, self.imsl))
361 | else:
362 | i, j = torch.meshgrid(torch.arange(0, W), torch.arange(0, H))
363 |
364 | self.uv = torch.stack([i.float(), j.float()], dim=-1).permute(1, 0, 2)
365 |
366 | # if self.square_crop:
367 | # self.uv = data_util.square_crop_img(self.uv)
368 |
369 | self.uv = self.uv[None].permute(0, -1, 1, 2).permute(0, 2, 3, 1)
370 | self.uv = self.uv.reshape(-1, 2)
371 |
372 | self.scene_path_list = list(Path(img_root).glob("*/"))
373 |
374 | def __len__(self):
375 | return len(self.all_scenes)
376 |
377 | def __getitem__(self, idx,scene_query=None):
378 | idx = idx if not_of else 0
379 | scene_path = self.all_scenes[idx if scene_query is None else scene_query]
380 | npz_files = sorted(scene_path.glob("*.npz"))
381 |
382 | name = scene_path.name
383 |
384 | def get_another():
385 | if self.val:
386 | return self[idx-1 if idx >200 else idx+1]
387 | return self.__getitem__(random.randint(0, len(self.all_scenes) - 1))
388 |
389 | if name not in self.all_pose: return get_another()
390 |
391 | pose = self.all_pose[name]
392 |
393 | if len(npz_files) == 0:
394 | print("npz get another")
395 | return get_another()
396 |
397 | npz_file = npz_files[0]
398 | try:
399 | data = np.load(npz_file)
400 | except:
401 | print("npz load error get another")
402 | return get_another()
403 |
404 | rgb_files = list(data.keys())
405 | window_size = 128
406 |
407 | if len(rgb_files) <= 20:
408 | print("<20 rgbs error get another")
409 | return get_another()
410 |
411 | timestamps = [int(rgb_file.split('.')[0]) for rgb_file in rgb_files]
412 | sorted_ids = np.argsort(timestamps)
413 |
414 | rgb_files = np.array(rgb_files)[sorted_ids]
415 | timestamps = np.array(timestamps)[sorted_ids]
416 |
417 | assert (timestamps == sorted(timestamps)).all()
418 | num_frames = len(rgb_files)
419 | left_bound = 0
420 | right_bound = num_frames - 1
421 | candidate_ids = np.arange(left_bound, right_bound)
422 |
423 | # remove windows between frame -32 to 32
424 | nframe = 1
425 | nframe_view = 140 if self.val else 92
426 |
427 | id_feats = []
428 |
429 | n_skip=self.n_skip
430 |
431 | id_feat = np.array(id_feats)
432 | low = 0
433 | high = num_frames-1-n_skip*self.num_query_views
434 |
435 | if high <= low:
436 | n_skip = int(num_frames//(self.num_query_views+1))
437 | high = num_frames-1-n_skip*self.num_query_views
438 |             print("high ...")
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
17 | ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
18 |
19 | # A quick dummy dataset for the demo rgb folder
20 | class SingleVid(Dataset):
21 |
22 | # If specified here, intrinsics should be a 4-element array of [fx,fy,cx,cy] at input image resolution
23 | def __init__(self, img_dir,intrinsics=None,n_trgt=6,num_skip=0,low_res=None,hi_res=None):
24 | self.low_res,self.intrinsics,self.n_trgt,self.num_skip,self.hi_res=low_res,intrinsics,n_trgt,num_skip,hi_res
25 | if self.hi_res is None:self.hi_res=[x*2 for x in self.low_res]
26 | self.hi_res = [(x+x%64) for x in self.hi_res]
27 |
28 | self.img_paths = glob(img_dir + '/*.png') + glob(img_dir + '/*.jpg')
29 | self.img_paths.sort()
30 |
31 | def __len__(self):
32 | return len(self.img_paths)-(1+self.n_trgt)*(1+self.num_skip)
33 |
34 | def __getitem__(self, idx):
35 |
36 | n_skip=self.num_skip+1
37 | paths = self.img_paths[idx:idx+self.n_trgt*n_skip:n_skip]
38 | imgs=torch.stack([torch.from_numpy(plt.imread(path)).permute(2,0,1) for path in paths]).float()
39 |
40 | imgs_large = F.interpolate(imgs,self.hi_res,antialias=True,mode="bilinear")
41 | frames = F.interpolate(imgs,self.low_res)
42 |
43 | frames = frames/255 * 2 - 1
44 |
45 | uv = np.mgrid[0:self.low_res[0], 0:self.low_res[1]].astype(float).transpose(1, 2, 0)
46 | uv = torch.from_numpy(np.flip(uv, axis=-1).copy()).long()
47 | uv = uv/ torch.tensor([self.low_res[1], self.low_res[0]]) # uv in [0,1]
48 | uv = uv[None].expand(len(frames),-1,-1,-1).flatten(1,2)
49 |
50 | # imgs_large values are in [0,255], frames in [-1,1], gt rgb in [0,1]
51 |
52 | model_input = {
53 | "trgt_rgb": frames[1:],
54 | "ctxt_rgb": frames[:-1],
55 | "trgt_rgb_large": imgs_large[1:],
56 | "ctxt_rgb_large": imgs_large[:-1],
57 | "x_pix": uv[1:],
58 | }
59 | gt = {
60 | "trgt_rgb": ch_sec(frames[1:])*.5+.5,
61 | "ctxt_rgb": ch_sec(frames[:-1])*.5+.5,
62 | "x_pix": uv[1:],
63 | }
64 |
65 | if self.intrinsics is not None:
66 | K = torch.eye(3)
67 | K[0,0],K[1,1],K[0,2],K[1,2]=[float(x) for x in self.intrinsics.strip().split(",")]
68 | h,w=imgs[0].shape[-2:]
69 | K[:2] /= torch.tensor([w, h])[:, None]
70 | model_input["intrinsics"] = K[None].expand(self.n_trgt-1,-1,-1)
71 |
72 | return model_input,gt
73 |
74 | dataset=SingleVid(args.demo_rgb,args.intrinsics,args.vid_len,args.n_skip,args.low_res)
75 |
76 | all_poses = torch.tensor([]).cuda()
77 | all_render_rgb=torch.tensor([]).cuda()
78 | all_render_depth=torch.tensor([])
79 | for seq_i in range(len(dataset)//(dataset.n_trgt)):
80 | print(seq_i*(dataset.n_trgt),"/",len(dataset))
81 | model_input = {k:to_gpu(v)[None] for k,v in dataset.__getitem__(seq_i*(dataset.n_trgt-1))[0].items()}
82 | with torch.no_grad(): out = (model.forward if not args.render_imgs else model.render_full_img)(model_input)
83 | curr_transfs = out["poses"][0]
84 | if len(all_poses): curr_transfs = all_poses[[-1]] @ curr_transfs # integrate poses
85 | all_poses = torch.cat((all_poses,curr_transfs))
86 | all_render_rgb = torch.cat((all_render_rgb,out["rgb"][0]))
87 | all_render_depth = torch.cat((all_render_depth,out["depth"][0]))
88 |
89 | out_dir="demo_output/"+args.demo_rgb.replace("/","_")
90 | os.makedirs(out_dir,exist_ok=True)
91 | fig = plt.figure()
92 | ax = fig.add_subplot(111, projection='3d')
93 | ax.plot(*all_poses[:,:3,-1].T.cpu().numpy())
94 | ax.xaxis.set_tick_params(labelbottom=False);ax.yaxis.set_tick_params(labelleft=False);ax.zaxis.set_tick_params(labelleft=False)
95 | ax.view_init(elev=10., azim=45)
96 | plt.tight_layout()
97 | fp = os.path.join(out_dir,f"pose_plot.png");plt.savefig(fp,bbox_inches='tight');plt.close()
98 |
99 | fp = os.path.join(out_dir,f"poses.npy");np.save(fp,all_poses.cpu())
100 | if args.render_imgs:
101 | out_dir=os.path.join(out_dir,"renders")
102 | os.makedirs(out_dir,exist_ok=True)
103 | for i,(rgb,depth) in enumerate(zip(all_render_rgb.unflatten(1,model_input["trgt_rgb"].shape[-2:]),all_render_depth.unflatten(1,model_input["trgt_rgb"].shape[-2:]))):
104 | plt.imsave(os.path.join(out_dir,"render_rgb_%04d.png"%i),rgb.clip(0,1).cpu().numpy())
105 | plt.imsave(os.path.join(out_dir,"render_depth_%04d.png"%i),depth.clip(0,1).cpu().numpy())
106 |
107 |
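# Example invocation (a sketch, not the canonical command; the paths and checkpoint
# below are placeholders, and the flags are the ones defined in run.py's argparser):
#
#   python demo.py --demo_rgb /path/to/image_folder \
#                  --init_ckpt /path/to/checkpoint.pt \
#                  --vid_len 8 --low_res 128 128 \
#                  --intrinsics "fx,fy,cx,cy" \
#                  --render_imgs
#
# Leaving out --intrinsics uses predicted intrinsics, and leaving out --render_imgs
# only estimates and saves poses (faster), per the flag descriptions in run.py.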
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from run import *
2 |
3 | # Evaluation script
4 | import piqa,lpips
5 | from torchvision.utils import make_grid
6 | import matplotlib.pyplot as plt
7 | loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
8 | lpips,psnr,ate=0,0,0
9 |
10 | eval_dir = save_dir+"/"+args.name+datetime.datetime.now().strftime("%b%d%Y_")+str(random.randint(0,1e3))
11 | try: os.mkdir(eval_dir)
12 | except: pass
13 | torch.set_grad_enabled(False)
14 |
15 | model.n_samples=128
16 |
17 | val_dataset = get_dataset(val=True,)
18 |
19 | for eval_idx,eval_dataset_idx in enumerate(tqdm(torch.linspace(0,len(val_dataset)-1,min(args.n_eval,len(val_dataset))).int())):
20 | model_input,ground_truth = val_dataset[eval_dataset_idx]
21 |
22 | for x in (model_input,ground_truth):
23 | for k,v in x.items(): x[k] = v[None].cuda() # collate
24 |
25 | model_out = model.render_full_img(model_input)
26 |
27 | # remove last frame since used as ctxt when n_ctxt=2
28 | rgb_est,rgb_gt = [rearrange(img[:,:-1].clip(0,1),"b trgt (x y) c -> (b trgt) c x y",x=model_input["trgt_rgb"].size(-2))
29 | for img in (model_out["fine_rgb" if "fine_rgb" in model_out else "rgb"],ground_truth["trgt_rgb"])]
30 | depth_est = rearrange(model_out["depth"][:,:-1],"b trgt (x y) c -> (b trgt) c x y",x=model_input["trgt_rgb"].size(-2))
31 |
32 | psnr += piqa.PSNR()(rgb_est.clip(0,1).contiguous(),rgb_gt.clip(0,1).contiguous())
33 | lpips += loss_fn_vgg(rgb_est*2-1,rgb_gt*2-1).mean()
34 |
35 | print(args.save_imgs)
36 | if args.save_imgs:
37 | fp = os.path.join(eval_dir,f"{eval_idx}_est.png");plt.imsave(fp,make_grid(rgb_est).permute(1,2,0).clip(0,1).cpu().numpy())
38 | if depth_est.size(1)==3: fp = os.path.join(eval_dir,f"{eval_idx}_depth.png");plt.imsave(fp,make_grid(depth_est).clip(0,1).permute(1,2,0).cpu().numpy())
39 | fp = os.path.join(eval_dir,f"{eval_idx}_gt.png");plt.imsave(fp,make_grid(rgb_gt).permute(1,2,0).cpu().numpy())
40 | print(fp)
41 |
42 |
43 | if args.save_imgs and args.save_ind: # save individual images separately
44 | eval_idx_dir = os.path.join(eval_dir,f"dir_{eval_idx}")
45 |
46 | try: os.mkdir(eval_idx_dir)
47 | except: pass
48 | ctxt_rgbs = torch.cat((model_input["ctxt_rgb"][:,0],model_input["trgt_rgb"][:,model_input["trgt_rgb"].size(1)//2],model_input["trgt_rgb"][:,-1]))*.5+.5
49 | fp = os.path.join(eval_idx_dir,f"ctxt0.png");plt.imsave(fp,ctxt_rgbs[0].clip(0,1).permute(1,2,0).cpu().numpy())
50 | fp = os.path.join(eval_idx_dir,f"ctxt1.png");plt.imsave(fp,ctxt_rgbs[1].clip(0,1).permute(1,2,0).cpu().numpy())
51 | fp = os.path.join(eval_idx_dir,f"ctxt2.png");plt.imsave(fp,ctxt_rgbs[2].clip(0,1).permute(1,2,0).cpu().numpy())
52 | for i,(rgb_est,rgb_gt,depth) in enumerate(zip(rgb_est,rgb_gt,depth_est)):
53 | fp = os.path.join(eval_idx_dir,f"{i}_est.png");plt.imsave(fp,rgb_est.clip(0,1).permute(1,2,0).cpu().numpy())
54 | print(fp)
55 | fp = os.path.join(eval_idx_dir,f"{i}_gt.png");plt.imsave(fp,rgb_gt.clip(0,1).permute(1,2,0).cpu().numpy())
56 | if depth_est.size(1)==3: fp = os.path.join(eval_idx_dir,f"{i}_depth.png");plt.imsave(fp,depth.permute(1,2,0).cpu().clip(1e-4,1-1e-4).numpy())
57 |
58 | # Pose plotting/evaluation
59 | if "poses" in model_out:
60 | import scipy.spatial
61 | pose_est,pose_gt = model_out["poses"][0][:,:3,-1].cpu(),model_input["trgt_c2w"][0][:,:3,-1].cpu()
62 | pose_gt,pose_est,_ = scipy.spatial.procrustes(pose_gt.numpy(),pose_est.numpy())
63 | ate += ((pose_est-pose_gt)**2).mean()
64 | if args.save_imgs:
65 | fig = plt.figure()
66 | ax = fig.add_subplot(111, projection='3d')
67 | ax.plot(*pose_gt.T)
68 | ax.plot(*pose_est.T)
69 | ax.xaxis.set_tick_params(labelbottom=False)
70 | ax.yaxis.set_tick_params(labelleft=False)
71 | ax.zaxis.set_tick_params(labelleft=False)
72 | ax.view_init(elev=10., azim=45)
73 | plt.tight_layout()
74 | fp = os.path.join(eval_dir,f"{eval_idx}_pose_plot.png");plt.savefig(fp,bbox_inches='tight');plt.close()
75 | if args.save_ind:
76 | for i in range(len(pose_est)):
77 | fig = plt.figure()
78 | ax = fig.add_subplot(111, projection='3d')
79 | ax.plot(*pose_gt.T,color="black")
80 | ax.plot(*pose_est.T,alpha=0)
81 | ax.plot(*pose_est[:i].T,alpha=1,color="red")
82 | ax.xaxis.set_tick_params(labelbottom=False)
83 | ax.yaxis.set_tick_params(labelleft=False)
84 | ax.zaxis.set_tick_params(labelleft=False)
85 | ax.view_init(elev=10., azim=45)
86 | plt.tight_layout()
87 | fp = os.path.join(eval_idx_dir,f"pose_{i}.png"); plt.savefig(fp,bbox_inches='tight');plt.close()
88 |
89 | print(f"psnr {psnr/(1+eval_idx)}, lpips {lpips/(1+eval_idx)}, ate {(ate/(1+eval_idx))**.5}, eval_idx {eval_idx}", flush=True)
90 |
91 |
--------------------------------------------------------------------------------
/geometry.py:
--------------------------------------------------------------------------------
1 | """Multi-view geometry & proejction code.."""
2 | import torch
3 | from einops import rearrange, repeat
4 | from torch.nn import functional as F
5 | import numpy as np
6 |
7 | def d6_to_rotmat(d6):
8 | a1, a2 = d6[..., :3], d6[..., 3:]
9 | b1 = F.normalize(a1, dim=-1)
10 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
11 | b2 = F.normalize(b2, dim=-1)
12 | b3 = torch.cross(b1, b2, dim=-1)
13 | return torch.stack((b1, b2, b3), dim=-2)
14 |
15 | def time_interp_poses(pose_inp,time_i,n_trgt,eye_pts):
16 | i,j = max(0,int(time_i*(n_trgt-1))-1),int(time_i*(n_trgt-1))
17 | pose_interp = camera_interp(*pose_inp[:,[i,j]].unbind(1),time_i)
18 | if i==j: pose_interp=pose_inp[:,0]
19 | pose_interp = repeat(pose_interp,"b x y -> b trgt x y",trgt=n_trgt)
20 | return pose_interp
21 | eye_pts = torch.cat((eye_pts,torch.ones_like(eye_pts[...,[0]])),-1)
22 | query_pts = torch.einsum("bcij,bcdkj->bcdki",pose_interp,eye_pts)[...,:3]
23 | return query_pts
24 |
25 | def pixel_aligned_features(
26 | coords_3d_world, cam2world, intrinsics, img_features, interp="bilinear",padding_mode="border",
27 | ):
28 | # Args:
29 | # coords_3d_world: shape (b, n, 3)
30 | # cam2world: camera pose of shape (..., 4, 4)
31 |
32 | # project 3d points to 2D
33 | c3d_world_hom = homogenize_points(coords_3d_world)
34 | c3d_cam_hom = transform_world2cam(c3d_world_hom, cam2world)
35 | c2d_cam, depth = project(c3d_cam_hom, intrinsics.unsqueeze(1))
36 |
37 | # now between 0 and 1. Map to -1 and 1
38 | c2d_norm = (c2d_cam - 0.5) * 2
39 | c2d_norm = rearrange(c2d_norm, "b n ch -> b n () ch")
40 | c2d_norm = c2d_norm[..., :2]
41 |
42 | # grid_sample
43 | feats = F.grid_sample(
44 | img_features, c2d_norm, align_corners=True, padding_mode=padding_mode, mode=interp
45 | )
46 | feats = feats.squeeze(-1) # b ch n
47 |
48 | feats = rearrange(feats, "b ch n -> b n ch")
49 | return feats, c3d_cam_hom[..., :3], c2d_cam
50 |
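# Minimal usage sketch for pixel_aligned_features (toy tensors only; the shapes and
# values below are illustrative assumptions, not what the model actually feeds in):
#
#   import torch
#   b, n, c, h, w = 1, 10, 16, 32, 32
#   coords = torch.rand(b, n, 3) + torch.tensor([0., 0., 2.])      # points in front of the camera
#   cam2world = torch.eye(4)[None]                                  # identity pose
#   intrinsics = torch.tensor([[[1., 0., .5],
#                               [0., 1., .5],
#                               [0., 0., 1.]]])                     # normalized intrinsics
#   feat_map = torch.randn(b, c, h, w)
#   feats, pts_cam, uv = pixel_aligned_features(coords, cam2world, intrinsics, feat_map)
#   # feats: (1, 10, 16) sampled features, pts_cam: (1, 10, 3), uv: (1, 10, 2) in [0, 1]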
51 | # https://gist.github.com/mkocabas/54ea2ff3b03260e3fedf8ad22536f427
52 | def procrustes(S1, S2,weights=None):
53 |
54 | if len(S1.shape)==4:
55 | out = procrustes(S1.flatten(0,1),S2.flatten(0,1),weights.flatten(0,1) if weights is not None else None)
56 | return out[0],out[1].unflatten(0,S1.shape[:2])
57 | '''
58 | Computes a similarity transform (sR, t) that maps a set of 3D points S1 (BxNx3)
59 | as closely as possible onto a set of 3D points S2, where R is a 3x3 rotation matrix,
60 | t a 3x1 translation, and s a scale. / mod: scale is assumed to be 1.
61 | i.e. solves the orthogonal Procrustes problem.
62 | '''
63 | with torch.autocast(device_type='cuda', dtype=torch.float32):
64 | S1 = S1.permute(0,2,1)
65 | S2 = S2.permute(0,2,1)
66 | if weights is not None:
67 | weights=weights.permute(0,2,1)
68 | transposed = True
69 |
70 | if weights is None:
71 | weights = torch.ones_like(S1[:,:1])
72 |
73 | # 1. Remove mean.
74 | weights_norm = weights/(weights.sum(-1,keepdim=True)+1e-6)
75 | mu1 = (S1*weights_norm).sum(2,keepdim=True)
76 | mu2 = (S2*weights_norm).sum(2,keepdim=True)
77 |
78 | X1 = S1 - mu1
79 | X2 = S2 - mu2
80 |
81 | diags = torch.stack([torch.diag(w.squeeze(0)) for w in weights]) # note: torch.diag_embed(weights.squeeze(1)) is the batched equivalent
82 |
83 | # 3. The outer product of X1 and X2.
84 | K = (X1@diags).bmm(X2.permute(0,2,1))
85 |
86 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
87 | U, s, V = torch.svd(K)
88 |
89 | # Construct Z that fixes the orientation of R to get det(R)=1.
90 | Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
91 | Z = Z.repeat(U.shape[0],1,1)
92 | Z[:,-1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0,2,1))))
93 |
94 | # Construct R.
95 | R = V.bmm(Z.bmm(U.permute(0,2,1)))
96 |
97 | # 6. Recover translation.
98 | t = mu2 - ((R.bmm(mu1)))
99 |
100 | # 7. Error:
101 | S1_hat = R.bmm(S1) + t
102 |
103 | # Combine recovered transformation as single matrix
104 | R_=torch.eye(4)[None].expand(S1.size(0),-1,-1).to(S1)
105 | R_[:,:3,:3]=R
106 | T_=torch.eye(4)[None].expand(S1.size(0),-1,-1).to(S1)
107 | T_[:,:3,-1]=t.squeeze(-1)
108 | S_=torch.eye(4)[None].expand(S1.size(0),-1,-1).to(S1)
109 | transf = T_@S_@R_
110 |
111 | return (S1_hat-S2).square().mean(),transf
112 |
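# Usage sketch for the weighted Procrustes solve above (toy data; uses d6_to_rotmat
# from this file to build a valid rotation, correspondences are exact so the
# recovered transform should match up to numerical precision):
#
#   import torch
#   src = torch.randn(2, 512, 3)                        # (B, N, 3) source points
#   R = d6_to_rotmat(torch.randn(1, 6))[0]              # a random valid rotation
#   t = torch.tensor([0.1, -0.2, 0.5])
#   trgt = src @ R.T + t                                # same points, rigidly moved
#   err, transf = procrustes(src, trgt)                 # transf: (B, 4, 4) mapping src -> trgt
#   assert torch.allclose(transf[0, :3, :3], R, atol=1e-3)
#   assert torch.allclose(transf[0, :3, -1], t, atol=1e-3)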
113 | def symmetric_orthogonalization(x):
114 | # https://github.com/amakadia/svd_for_pose
115 | m = x.view(-1, 3, 3).type(torch.float)
116 | u, s, v = torch.svd(m)
117 | vt = torch.transpose(v, 1, 2)
118 | det = torch.det(torch.matmul(u, vt))
119 | det = det.view(-1, 1, 1)
120 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
121 | r = torch.matmul(u, vt)
122 | return r
123 |
124 | def rigidity_loss(ctx_xyz,trgt_xyz):
125 |
126 | x_points = ctx_xyz #.view(-1, 3)
127 | y_points = trgt_xyz #.view(-1, 3)
128 |
129 | x_mean = x_points.mean(1, keepdim=True) # x_mean and y_mean define the global translation
130 | y_mean = y_points.mean(1, keepdim=True)
131 |
132 | x_points_centered = x_points - x_mean
133 | y_points_centered = y_points - y_mean
134 |
135 | x_scale = torch.sqrt(x_points_centered.pow(2).sum(2, keepdim=True)).mean(1, keepdim=True)
136 | x_points_normalized = x_points_centered / x_scale # x_scale and y_scale define the global scales
137 |
138 | y_scale = torch.sqrt(y_points_centered.pow(2).sum(2, keepdim=True)).mean(1, keepdim=True)
139 | y_points_normalized = y_points_centered / y_scale
140 |
141 | M = torch.einsum('b i k, b i j -> b k j', x_points_normalized, y_points_normalized) # M is the covariance matrix
142 | R = symmetric_orthogonalization(M) #this is the rotation matrix
143 |
144 | # Compute the transformed ctxt points
145 | x_points_transformed = torch.matmul(x_points_normalized, R)
146 |
147 | loss = (x_points_transformed - y_points_normalized).pow(2).mean()
148 | return loss
149 |
150 |
151 | def homogenize_points(points: torch.Tensor):
152 | """Appends a "1" to the coordinates of a (batch of) points of dimension DIM.
153 |
154 | Args:
155 | points: points of shape (..., DIM)
156 |
157 | Returns:
158 | points_hom: points with appended "1" dimension.
159 | """
160 | ones = torch.ones_like(points[..., :1], device=points.device)
161 | return torch.cat((points, ones), dim=-1)
162 |
163 |
164 | def homogenize_vecs(vectors: torch.Tensor):
165 | """Appends a "0" to the coordinates of a (batch of) vectors of dimension DIM.
166 |
167 | Args:
168 | vectors: vectors of shape (..., DIM)
169 |
170 | Returns:
171 | vectors_hom: vectors with appended "0" dimension.
172 | """
173 | zeros = torch.zeros_like(vectors[..., :1], device=vectors.device)
174 | return torch.cat((vectors, zeros), dim=-1)
175 |
176 |
177 | def unproject(
178 | xy_pix: torch.Tensor, z: torch.Tensor, intrinsics: torch.Tensor
179 | ) -> torch.Tensor:
180 | """Unproject (lift) 2D pixel coordinates x_pix and per-pixel z coordinate
181 | to 3D points in camera coordinates.
182 |
183 | Args:
184 | xy_pix: 2D pixel coordinates of shape (..., 2)
185 | z: per-pixel depth, defined as z coordinate of shape (..., 1)
186 | intrinsics: camera intrinsics of shape (..., 3, 3)
187 |
188 | Returns:
189 | xyz_cam: points in 3D camera coordinates.
190 | """
191 | xy_pix_hom = homogenize_points(xy_pix)
192 | xyz_cam = torch.einsum("...ij,...kj->...ki", intrinsics.inverse(), xy_pix_hom)
193 | xyz_cam *= z
194 | return xyz_cam
195 |
196 |
197 | def transform_world2cam(
198 | xyz_world_hom: torch.Tensor, cam2world: torch.Tensor
199 | ) -> torch.Tensor:
200 | """Transforms points from 3D world coordinates to 3D camera coordinates.
201 |
202 | Args:
203 | xyz_world_hom: homogenized 3D points of shape (..., 4)
204 | cam2world: camera pose of shape (..., 4, 4)
205 |
206 | Returns:
207 | xyz_cam: points in camera coordinates.
208 | """
209 | world2cam = torch.inverse(cam2world)
210 | return transform_rigid(xyz_world_hom, world2cam)
211 |
212 |
213 | def transform_cam2world(
214 | xyz_cam_hom: torch.Tensor, cam2world: torch.Tensor
215 | ) -> torch.Tensor:
216 | """Transforms points from 3D world coordinates to 3D camera coordinates.
217 |
218 | Args:
219 | xyz_cam_hom: homogenized 3D points of shape (..., 4)
220 | cam2world: camera pose of shape (..., 4, 4)
221 |
222 | Returns:
223 | xyz_world: points in world coordinates.
224 | """
225 | return transform_rigid(xyz_cam_hom, cam2world)
226 |
227 |
228 | def transform_rigid(xyz_hom: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
229 | """Apply a rigid-body transform to a (batch of) points / vectors.
230 |
231 | Args:
232 | xyz_hom: homogenized 3D points of shape (..., 4)
233 | T: rigid-body transform matrix of shape (..., 4, 4)
234 |
235 | Returns:
236 | xyz_trans: transformed points.
237 | """
238 | return torch.einsum("...ij,...kj->...ki", T, xyz_hom)
239 |
240 |
241 | def get_unnormalized_cam_ray_directions(
242 | xy_pix: torch.Tensor, intrinsics: torch.Tensor
243 | ) -> torch.Tensor:
244 | return unproject(
245 | xy_pix,
246 | torch.ones_like(xy_pix[..., :1], device=xy_pix.device),
247 | intrinsics=intrinsics,
248 | )
249 |
250 |
251 | def get_world_rays_(
252 | xy_pix: torch.Tensor,
253 | intrinsics: torch.Tensor,
254 | cam2world: torch.Tensor,
255 | ) -> torch.Tensor:
256 |
257 | if cam2world is None:
258 | cam2world = torch.eye(4)[None].expand(xy_pix.size(0),-1,-1).to(xy_pix)
259 |
260 | # Get camera origin of camera 1
261 | cam_origin_world = cam2world[..., :3, -1]
262 |
263 | # Get ray directions in cam coordinates
264 | ray_dirs_cam = get_unnormalized_cam_ray_directions(xy_pix, intrinsics)
265 | ray_dirs_cam = ray_dirs_cam / ray_dirs_cam.norm(dim=-1, keepdim=True)
266 |
267 | # Homogenize ray directions
268 | rd_cam_hom = homogenize_vecs(ray_dirs_cam)
269 |
270 | # Transform ray directions to world coordinates
271 | rd_world_hom = transform_cam2world(rd_cam_hom, cam2world)
272 |
273 | # Tile the ray origins to have the same shape as the ray directions.
275 | # Currently, ray origins have shape (batch, 3), while ray directions have shape (batch, num_rays, 3).
275 | cam_origin_world = repeat(
276 | cam_origin_world, "b ch -> b num_rays ch", num_rays=ray_dirs_cam.shape[1]
277 | )
278 |
279 | # Return tuple of cam_origins, ray_world_directions
280 | return cam_origin_world, rd_world_hom[..., :3]
281 |
282 | def get_world_rays(
283 | xy_pix: torch.Tensor,
284 | intrinsics: torch.Tensor,
285 | cam2world: torch.Tensor,
286 | ) -> torch.Tensor:
287 | if len(xy_pix.shape)==4:
288 | out = get_world_rays_(xy_pix.flatten(0,1),intrinsics.flatten(0,1),cam2world.flatten(0,1) if cam2world is not None else None)
289 | return [x.unflatten(0,xy_pix.shape[:2]) for x in out]
290 | return get_world_rays_(xy_pix,intrinsics,cam2world)
291 |
292 |
293 |
294 |
295 | def get_opencv_pixel_coordinates(
296 | y_resolution: int,
297 | x_resolution: int,
298 | device='cpu'
299 | ):
300 | """For an image with y_resolution and x_resolution, return a tensor of pixel coordinates
301 | normalized to lie in [0, 1], with the origin (0, 0) in the top left corner,
302 | the x-axis pointing right, the y-axis pointing down, and the bottom right corner
303 | being at (1, 1).
304 |
305 | Returns:
306 | xy_pix: a meshgrid of values from [0, 1] of shape
307 | (y_resolution, x_resolution, 2)
308 | """
309 | i, j = torch.meshgrid(
310 | torch.linspace(0, 1, steps=x_resolution, device=device),
311 | torch.linspace(0, 1, steps=y_resolution, device=device),
312 | )
313 |
314 | xy_pix = torch.stack([i.float(), j.float()], dim=-1).permute(1, 0, 2)
315 | return xy_pix
316 |
317 |
318 | def project(xyz_cam_hom: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
319 | """Projects homogenized 3D points xyz_cam_hom in camera coordinates
320 | to pixel coordinates.
321 |
322 | Args:
323 | xyz_cam_hom: 3D points of shape (..., 4)
324 | intrinsics: camera intrinsics of shape (..., 3, 3)
325 |
326 | Returns:
327 | xy: pixel coordinates of shape (..., 2) and z: depth of shape (..., 1)
328 | """
329 | if len(intrinsics.shape)==len(xyz_cam_hom.shape): intrinsics=intrinsics.unsqueeze(1)
330 | xyw = torch.einsum("...ij,...j->...i", intrinsics, xyz_cam_hom[..., :3])
331 | z = xyw[..., -1:]
332 | xyw = xyw / (z + 1e-5) # z-divide
333 | return xyw[..., :2], z
334 |
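# Round-trip sketch for unproject/project above (toy values; intrinsics are
# normalized to [0, 1] pixel coordinates, as elsewhere in this repo):
#
#   import torch
#   K = torch.tensor([[[1., 0., .5], [0., 1., .5], [0., 0., 1.]]])   # (1, 3, 3)
#   uv = torch.rand(1, 100, 2)                                       # pixels in [0, 1]
#   z = 1.0 + torch.rand(1, 100, 1)                                  # positive depths
#   xyz_cam = unproject(uv, z, K)                                    # (1, 100, 3) camera-frame points
#   uv_back, z_back = project(homogenize_points(xyz_cam), K)
#   assert torch.allclose(uv_back, uv, atol=1e-3) and torch.allclose(z_back, z, atol=1e-3)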
335 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
336 | # from pytorch3d
337 | """
338 | Returns torch.sqrt(torch.max(0, x))
339 | but with a zero subgradient where x is 0.
340 | """
341 | ret = torch.zeros_like(x)
342 | positive_mask = x > 0
343 | ret[positive_mask] = torch.sqrt(x[positive_mask])
344 | return ret
345 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
346 | # from pytorch3d
347 | """
348 | Convert rotations given as rotation matrices to quaternions.
349 | Args:
350 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
351 | Returns:
352 | quaternions with real part first, as tensor of shape (..., 4).
353 | """
354 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
355 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
356 |
357 | batch_dim = matrix.shape[:-2]
358 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
359 | matrix.reshape(batch_dim + (9,)), dim=-1
360 | )
361 |
362 | q_abs = _sqrt_positive_part(
363 | torch.stack(
364 | [
365 | 1.0 + m00 + m11 + m22,
366 | 1.0 + m00 - m11 - m22,
367 | 1.0 - m00 + m11 - m22,
368 | 1.0 - m00 - m11 + m22,
369 | ],
370 | dim=-1,
371 | )
372 | )
373 |
374 | # we produce the desired quaternion multiplied by each of r, i, j, k
375 | quat_by_rijk = torch.stack(
376 | [
377 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
378 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
379 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
380 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
381 | ],
382 | dim=-2,
383 | )
384 |
385 | # We floor here at 0.1 but the exact level is not important; if q_abs is small,
386 | # the candidate won't be picked.
387 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
388 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
389 |
390 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
391 | # forall i; we pick the best-conditioned one (with the largest denominator)
392 |
393 | return quat_candidates[
394 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
395 | ].reshape(batch_dim + (4,))
396 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
397 | # from pytorch3d
398 | """
399 | Convert rotations given as quaternions to rotation matrices.
400 | Args:
401 | quaternions: quaternions with real part first,
402 | as tensor of shape (..., 4).
403 | Returns:
404 | Rotation matrices as tensor of shape (..., 3, 3).
405 | """
406 | r, i, j, k = torch.unbind(quaternions, -1)
407 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
408 |
409 | o = torch.stack(
410 | (
411 | 1 - two_s * (j * j + k * k),
412 | two_s * (i * j - k * r),
413 | two_s * (i * k + j * r),
414 | two_s * (i * j + k * r),
415 | 1 - two_s * (i * i + k * k),
416 | two_s * (j * k - i * r),
417 | two_s * (i * k - j * r),
418 | two_s * (j * k + i * r),
419 | 1 - two_s * (i * i + j * j),
420 | ),
421 | -1,
422 | )
423 | return o.reshape(quaternions.shape[:-1] + (3, 3))
424 | def camera_interp(camera1, camera2, t):
425 | if len(camera1.shape)==3:
426 | return torch.stack([camera_interp(cam1,cam2,t) for cam1,cam2 in zip(camera1,camera2)])
427 | # Extract the rotation component from the camera matrices
428 | q1 = matrix_to_quaternion(camera1[:3, :3])
429 | q2 = matrix_to_quaternion(camera2[:3, :3])
430 |
431 | # todo add negative quaternion check to not go long way around
432 |
433 | # Interpolate the quaternions using slerp
434 | cos_angle = (q1 * q2).sum(dim=0)
435 | angle = torch.acos(cos_angle.clamp(-1, 1))
436 | q_interpolated = (q1 * torch.sin((1 - t) * angle) + q2 * torch.sin(t * angle)) / torch.sin(angle)
437 | rotation_interpolated = quaternion_to_matrix(q_interpolated)
438 |
439 | # Interpolate the translation component
440 | translation_interpolated = torch.lerp(camera1[:3,-1], camera2[:3,-1], t)
441 |
442 | cam_interpolated = torch.eye(4)
443 | cam_interpolated[:3,:3]=rotation_interpolated
444 | cam_interpolated[:3,-1]=translation_interpolated
445 |
446 | return cam_interpolated.cuda()
447 |
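# Round-trip sketch for the rotation <-> quaternion helpers above (toy rotations
# built with d6_to_rotmat from this file; no CUDA needed for this part):
#
#   import torch
#   R = d6_to_rotmat(torch.randn(4, 6))          # (4, 3, 3) random valid rotations
#   q = matrix_to_quaternion(R)                  # (4, 4), real part first
#   R_back = quaternion_to_matrix(q)             # (4, 3, 3)
#   assert torch.allclose(R, R_back, atol=1e-4)
#
# camera_interp uses these converters to slerp between two 4x4 camera poses; note
# that it returns its result with .cuda(), so it assumes a CUDA device.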
--------------------------------------------------------------------------------
/mlp_modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import numpy as np
3 | from collections import OrderedDict
4 | import torch
5 | from torch import nn
6 | from einops import repeat,rearrange
7 |
8 | def init_weights_normal(m):
9 | if type(m) == nn.Linear:
10 | if hasattr(m, 'weight'):
11 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
12 |
13 | class ResMLP(nn.Module):
14 | def __init__(self, ch_in, ch_mod, out_ch, num_res_block=1 ):
15 | super().__init__()
16 |
17 | self.res_blocks = nn.ModuleList([
18 | nn.Sequential(nn.Linear(ch_mod,ch_mod),nn.ReLU(),
19 | nn.LayerNorm([ch_mod], elementwise_affine=True),
20 | nn.Linear(ch_mod,ch_mod),nn.ReLU())
21 | for _ in range(num_res_block)
22 | ])
23 |
24 | self.proj_in = nn.Linear(ch_in,ch_mod)
25 | self.out = nn.Linear(ch_mod,out_ch)
26 |
27 | def forward(self,x):
28 |
29 | x = self.proj_in(x)
30 |
31 | for i,block in enumerate(self.res_blocks):
32 |
33 | x_in = x
34 |
35 | x = block(x)
36 |
37 | if i!=len(self.res_blocks)-1: x = x + x_in
38 |
39 | return self.out(x)
40 |
41 | # FiLM-style conditioning, but with only the biases, not the scalings - feature-wise additive modulation.
42 | # "x" is the input coordinate and "y" is the conditioning feature (image features, for example).
43 | class ResFAMLP(nn.Module):
44 | def __init__(self, ch_in_x,ch_in_y, ch_mod, out_ch, num_res_block=1, last_res=False):
45 | super().__init__()
46 |
47 | self.res_blocks = nn.ModuleList([
48 | nn.Sequential(nn.Linear(ch_mod,ch_mod),nn.ReLU(),
49 | nn.LayerNorm([ch_mod], elementwise_affine=True),
50 | nn.Linear(ch_mod,ch_mod),nn.ReLU())
51 | for _ in range(num_res_block)
52 | ])
53 |
54 | self.last_res=last_res
55 | self.proj_in = nn.Linear(ch_in_x,ch_mod)
56 | self.modulators = nn.ModuleList([nn.Linear(ch_in_y,ch_mod) for _ in range(num_res_block)])
57 | self.out = nn.Linear(ch_mod,out_ch)
58 |
59 | def forward(self,x,y):
60 |
61 | x = self.proj_in(x)
62 |
63 | for i,(block,modulator) in enumerate(zip(self.res_blocks,self.modulators)):
64 |
65 | x_in = x + modulator(y)
66 |
67 | x = block(x)
68 |
69 | if i!=len(self.res_blocks)-1 or self.last_res: x = x + x_in
70 |
71 | return self.out(x)
72 |
73 |
74 |
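# Shape sketch for the modules above (toy sizes; the channel widths are arbitrary
# choices for illustration, not the ones the model uses):
#
#   import torch
#   coords = torch.rand(2, 100, 3)                  # (batch, points, xyz)
#   feats = torch.randn(2, 100, 64)                 # per-point conditioning features
#   net = ResFAMLP(ch_in_x=3, ch_in_y=64, ch_mod=128, out_ch=4, num_res_block=2)
#   print(net(coords, feats).shape)                 # torch.Size([2, 100, 4]), e.g. rgb + sigma
#
#   plain = ResMLP(ch_in=3, ch_mod=128, out_ch=1)
#   print(plain(coords).shape)                      # torch.Size([2, 100, 1])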
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """Code for pixelnerf and alternatives."""
2 | import torch, torchvision
3 | from torch import nn
4 | from einops import rearrange, repeat
5 | from torch.nn import functional as F
6 | import numpy as np
7 | import time
8 | import timm
9 | from matplotlib import cm
10 | import kornia
11 | from tqdm import tqdm
12 |
13 | import conv_modules
14 | import mlp_modules
15 | import geometry
16 | import renderer
17 | from geometry import procrustes
18 |
19 | ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
20 | ch_fst = lambda src,x=None:rearrange(src,"... (x y) c -> ... c x y",x=int(src.size(-2)**(.5)) if x is None else x)
21 | hom = lambda x: torch.cat((x,torch.ones_like(x[...,[0]])),-1)
22 |
23 | class FlowCam(nn.Module):
24 | def __init__(self, near=1.75, far=8, n_samples=64,num_view=2,logspace=True,use_trgt_crop=False,use_midas=False):
25 | super().__init__()
26 |
27 | self.raft_midas_net = RaftAndMidas(raft=True,midas=use_midas)
28 |
29 | self.near,self.far,self.n_samples = near,far,n_samples
30 | self.logspace=logspace
31 | self.use_trgt_crop=use_trgt_crop
32 |
33 | self.num_view=num_view
34 |
35 | self.nerf_enc_flow = conv_modules.PixelNeRFEncoder(in_ch=5,use_first_pool=False)
36 | #self.nerf_enc_flow = conv_modules.PixelNeRFEncoder(in_ch=3,use_first_pool=False)
37 | phi_latent=64
38 | self.renderer = MLPImplicit(latent_in=phi_latent,inner_latent=phi_latent,white_back=False,num_first_res_block=2)
39 |
40 | self.pos_encoder = PositionalEncodingNoFreqFactor(3,5)
41 |
42 | self.ray_comb = nn.Sequential(torch.nn.Conv2d(512+33,512,3,padding=1),nn.ReLU(),
43 | torch.nn.Conv2d(512,512,3,padding=1),nn.ReLU(),
44 | torch.nn.Conv2d(512,phi_latent-3,3,padding=1)
45 | )
46 | self.corr_weighter_perpoint = nn.Sequential(
47 | nn.Linear(128,128),nn.ReLU(),
48 | nn.Linear(128,128),nn.ReLU(),
49 | nn.Linear(128,128),nn.ReLU(),
50 | nn.Linear(128,1),
51 | )
52 | self.corr_weighter_perpoint.apply(mlp_modules.init_weights_normal)
53 |
54 | def encoder(self,model_input):
55 | imsize=model_input["ctxt_rgb"].shape[-2:]
56 |
57 | if "backbone_feats" in model_input: return model_input["backbone_feats"]
58 | if "bwd_flow" not in model_input: model_input = self.raft_midas_net(model_input)
59 |
60 | # ctxt[1:]==trgt[:-1], using this property to avoid redundant computation
61 | all_rgb = torch.cat((model_input["ctxt_rgb"][:,:1],model_input["trgt_rgb"]),1)
62 | all_flow = torch.cat((model_input["bwd_flow"][:,:1],model_input["bwd_flow"]),1)
63 |
64 | # Resnet rgb+flow feats
65 | rgb_flow = torch.cat((all_rgb,all_flow*4),2)
66 | rgb_flow_feats = self.nerf_enc_flow(rgb_flow,imsize)
67 |
68 | # Add rays to features for some amount of focal length information
69 | rds = self.pos_encoder(geometry.get_world_rays(model_input["x_pix"], model_input["intrinsics"], None)[1])
70 | rds = ch_fst(torch.cat((rds[:,:1],rds),1),imsize[0])
71 |
72 | all_feats = self.ray_comb(torch.cat((rgb_flow_feats,rds),2).flatten(0,1)).unflatten(0,all_rgb.shape[:2])
73 | all_feats = torch.cat((all_feats,all_rgb),2)
74 | model_input["backbone_feats"] = all_feats
75 |
76 | return all_feats
77 |
78 | def forward(self, model_input, trgt_rays=None,ctxt_rays=None,poses=None):
79 |
80 | imsize=model_input["ctxt_rgb"].shape[-2:]
81 | (b,n_ctxt),n_trgt=model_input["ctxt_rgb"].shape[:2],model_input["trgt_rgb"].size(1)
82 | add_ctxt = lambda x: torch.cat((x[:,:1],x),1)
83 | if trgt_rays is None: trgt_rays,ctxt_rays = self.make_rays(model_input)
84 |
85 | # Encode images
86 | backbone_feats= self.encoder(model_input)
87 |
88 | # Expand identity camera into 3d points and render rays
89 | ros, rds = geometry.get_world_rays(add_ctxt(model_input["x_pix"]), add_ctxt(model_input["intrinsics"]), None)
90 | eye_pts, z_vals = renderer.sample_points_along_rays(self.near, self.far, self.n_samples, ros, rds, device=model_input["x_pix"].device,logspace=self.logspace)
91 | eye_render, eye_depth, eye_weights= self.renderer( backbone_feats, eye_pts[:,:,ctxt_rays], add_ctxt(model_input["intrinsics"]), z_vals,identity=True)
92 |
93 | # Render out correspondence's surface point now
94 | corresp_uv = (model_input["x_pix"]+ch_sec(model_input["bwd_flow"]))[:,:,ctxt_rays]
95 | ros, rds = geometry.get_world_rays(corresp_uv, model_input["intrinsics"], None)
96 | corresp_pts, _ = renderer.sample_points_along_rays(self.near, self.far, self.n_samples, ros, rds, device=model_input["x_pix"].device,logspace=self.logspace)
97 | _, _, corresp_weights= self.renderer( backbone_feats[:,:-1], corresp_pts, model_input["intrinsics"], z_vals,identity=True)
98 |
99 | # Predict correspondence weights as function of source feature and correspondence
100 | corresp_feat = F.grid_sample(backbone_feats[:,:-1].flatten(0,1),corresp_uv.flatten(0,1).unsqueeze(1)*2-1).squeeze(-2).permute(0,2,1).unflatten(0,(b,n_trgt))
101 | corr_weights = self.corr_weighter_perpoint(torch.cat((corresp_feat,ch_sec(backbone_feats)[:,1:,ctxt_rays]),-1)).sigmoid()
102 |
103 | # Weighted procrustes on scene flow
104 | if poses is None:
105 | adj_transf = procrustes((eye_pts[:,1:,ctxt_rays]*eye_weights[:,1:]).sum(-2), (corresp_weights*corresp_pts).sum(-2), corr_weights)[1]
106 | poses = adj_transf
107 | for i in range(n_trgt-1,0,-1):
108 | poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
109 | else: adj_transf = torch.cat((poses[:,:1],poses[:,:-1].inverse()@poses[:,1:]),1)
110 |
111 | # Render out trgt frames from [ctxt=0, ctxt=-1, ctxt=middle][: num context frames ]
112 | render = self.render(model_input,poses,trgt_rays)
113 |
114 | # Pose induced flow using ctxt depth and then multiview rendered depth (latter not used in paper experiments)
115 | corresp_surf_from_pose = (torch.einsum("bcij,bcdkj->bcdki",adj_transf,hom(eye_pts[:,1:]))[:,:,ctxt_rays,...,:3]*eye_weights[:,1:]).sum(-2)
116 | flow_from_pose = geometry.project(corresp_surf_from_pose.flatten(0,1), model_input["intrinsics"].flatten(0,1))[0].unflatten(0,(b,n_trgt))-model_input["x_pix"][:,:,ctxt_rays]
117 | corresp_surf_from_pose_render = (torch.einsum("bcij,bcdkj->bcdki",adj_transf,hom(eye_pts[:,1:]))[:,:,trgt_rays,...,:3]*render["weights"]).sum(-2)
118 | flow_from_pose_render = geometry.project(corresp_surf_from_pose_render.flatten(0,1), model_input["intrinsics"].flatten(0,1))[0].unflatten(0,(b,n_trgt))-model_input["x_pix"][:,:,trgt_rays]
119 |
120 | out= {
121 | "rgb":render["rgb"],
122 | "ctxt_rgb":eye_render[:,:-1],
123 | "poses":poses,
124 | "depth":render["depth"],
125 | "ctxt_depth":eye_depth[:,:-1],
126 | "corr_weights": corr_weights,
127 | "flow_from_pose": flow_from_pose,
128 | "flow_from_pose_render": flow_from_pose_render,
129 | "ctxt_rays": ctxt_rays.to(eye_render)[None].expand(b,-1),
130 | "trgt_rays": trgt_rays.to(eye_render)[None].expand(b,-1),
131 | "flow_inp": model_input["bwd_flow"],
132 | }
133 | if "ctxt_depth" in model_input:
134 | out["ctxt_depth_inp"]=model_input["ctxt_depth"]
135 | out["trgt_depth_inp"]=model_input["trgt_depth"]
136 | return out
137 |
138 | def render(self,model_input,poses,trgt_rays,query_pose=None):
139 | if query_pose is None: query_pose=poses
140 |
141 | ros, rds = geometry.get_world_rays(model_input["x_pix"][:,:query_pose.size(1),trgt_rays], model_input["intrinsics"][:,:query_pose.size(1)], query_pose)
142 | query_pts, z_vals = renderer.sample_points_along_rays(self.near, self.far, self.n_samples, ros, rds, device=model_input["x_pix"].device,logspace=self.logspace)
143 |
144 | ctxt_poses = torch.cat((torch.eye(4).cuda()[None].expand(poses.size(0),-1,-1)[:,None],poses),1)
145 | ctxt_idxs = [0,-1,model_input["trgt_rgb"].size(1)//2][:self.num_view]
146 | ctxt_pts = torch.einsum("bvcij,bcdkj->bvcdki",ctxt_poses[:,ctxt_idxs].inverse().unsqueeze(2),hom(query_pts))[...,:3] # in coord system of ctxt frames
147 | rgb, depth, weights = self.renderer(model_input["backbone_feats"][:,ctxt_idxs], ctxt_pts,model_input["intrinsics"][:,:query_pose.size(1)],z_vals)
148 | return {"rgb":rgb,"depth":depth,"weights":weights}
149 |
150 | def make_rays(self,model_input):
151 |
152 | imsize=model_input["ctxt_rgb"].shape[-2:]
153 |
154 | # Pick random subset of rays
155 | crop_res=32 if self.n_samples<100 else 16
156 | if self.use_trgt_crop: # choose random crop of rays instead of random set
157 | start_x,start_y = np.random.randint(0,imsize[1]-crop_res-1),np.random.randint(0,imsize[0]-crop_res-1)
158 | trgt_rays = torch.arange(imsize[0]*imsize[1]).view(imsize[0],imsize[1])[start_y:start_y+crop_res,start_x:start_x+crop_res].flatten()
159 | else:
160 | trgt_rays = torch.randperm(imsize[0]*imsize[1]-2)[:crop_res**2]
161 |
162 | ctxt_rays = torch.randperm(imsize[0]*imsize[1]-2)[:(32 if torch.is_grad_enabled() else 48)**2]
163 | return trgt_rays,ctxt_rays
164 |
165 | def render_full_img(self,model_input, query_pose=None,sample_out=None):
166 | num_chunk=8 if self.n_samples<90 else 16
167 |
168 | imsize=model_input["ctxt_rgb"].shape[-2:]
169 | (b,n_ctxt),n_trgt=model_input["ctxt_rgb"].shape[:2],model_input["trgt_rgb"].size(1)
170 |
171 | if sample_out is None: sample_out = self(model_input)
172 |
173 | # Render out image iteratively and aggregate outputs
174 | outs=[]
175 | for j,trgt_rays in enumerate(torch.arange(imsize[0]*imsize[1]).chunk(num_chunk)):
176 | with torch.no_grad():
177 | outs.append( self(model_input,trgt_rays=trgt_rays,ctxt_rays=trgt_rays,poses=sample_out["poses"]) if query_pose is None else
178 | self.render(model_input,sample_out["poses"],trgt_rays,query_pose[:,None]))
179 |
180 | out_all = {}
181 | for k,v in outs[0].items():
182 | if len(v.shape)>3 and "inp" not in k and "poses" not in k: out_all[k]=torch.cat([out[k] for out in outs],2)
183 | else:out_all[k]=v
184 | out_all["depth_raw"] = out_all["depth"]
185 | for k,v in out_all.items():
186 | if "depth" in k: out_all[k] = torch.from_numpy(cm.get_cmap('magma')(v.min().item()/v.cpu().numpy())).squeeze(-2)[...,:3]
187 |
188 | return out_all
189 |
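# Sketch of the pose-chaining step in FlowCam.forward above: the backwards loop
# turns adjacent relative transforms adj[:, i] into cumulative poses, so that
# poses[:, j] == adj[:, 0] @ adj[:, 1] @ ... @ adj[:, j]. A toy check with assumed
# shapes (run standalone, not part of the model):
#
#   import torch
#   b, n = 2, 5
#   adj = torch.eye(4)[None, None] + 0.01 * torch.randn(b, n, 4, 4)
#   poses = adj.clone()
#   for i in range(n - 1, 0, -1):
#       poses = torch.cat((poses[:, :i], poses[:, [i - 1]] @ poses[:, i:]), 1)
#   cum = adj[:, 0]
#   for j in range(1, n):
#       cum = cum @ adj[:, j]
#   assert torch.allclose(poses[:, -1], cum, atol=1e-5)
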
190 | # Implicit function which performs pixel-aligned nerf
191 | class MLPImplicit(nn.Module):
192 |
193 | def __init__(self,latent_in=512,inner_latent=128,white_back=True,num_first_res_block=2,add_view_dir=False):
194 | super().__init__()
195 | self.white_back=white_back
196 |
197 | self.pos_encoder = PositionalEncodingNoFreqFactor(3,5)
198 |
199 | self.phi1= mlp_modules.ResFAMLP(3,latent_in,inner_latent,inner_latent,num_first_res_block,last_res=True) # 2 res blocks, outputs deep feature to be averaged over ctxts
200 | self.phi1.apply(mlp_modules.init_weights_normal)
201 | self.phi2= mlp_modules.ResMLP(inner_latent,inner_latent,4)
202 | self.phi2.apply(mlp_modules.init_weights_normal)
203 |
204 | # Note this method does not use cam2world and assumes `pos` is transformed to each context view's coordinate system and duplicated for each context view
205 | def forward(self,ctxt_feats,pos,intrinsics,samp_dists,flow_feats=None,identity=False):
206 |
207 | if identity: # 1:1 correspondence of ctxt img to trgt render, used for rendering out identity camera
208 | b_org,n_trgt_org = ctxt_feats.shape[:2]
209 | ctxt_feats,pos,intrinsics = [x.flatten(0,1)[:,None] for x in [ctxt_feats,pos,intrinsics]]
210 |
211 | if len(pos.shape)==5: pos=pos.unsqueeze(1) # unsqueeze ctxt if no ctxt dim (means only 1 ctxt supplied)
212 | (b,n_ctxt),n_trgt=(ctxt_feats.shape[:2],pos.size(2))
213 |
214 | # Projection onto identity camera
215 | img_feat, pos_,_ = geometry.pixel_aligned_features(
216 | repeat(pos,"b ctxt trgt xy d c -> (b ctxt trgt) (xy d) c"),
217 | repeat(torch.eye(4)[None,None].to(ctxt_feats),"1 1 x y -> (b ctxt trgt) x y",b=b,ctxt=n_ctxt,trgt=n_trgt),
218 | repeat(intrinsics,"b trgt x y -> (b ctxt trgt) x y",ctxt=n_ctxt),
219 | repeat(ctxt_feats,"b ctxt c x y -> (b ctxt trgt) c x y",trgt=n_trgt),
220 | )
221 | img_feat,pos= [rearrange(x,"(b ctxt trgt) (xy d) c -> (b trgt) xy d ctxt c",b=b,ctxt=n_ctxt,d=samp_dists.size(-1)) for x in [img_feat,pos_]]
222 |
223 | # Map (3d crd,projecting img feature) to (rgb,sigma)
224 | out = self.phi2(self.phi1(pos,img_feat).mean(dim=-2))
225 | rgb,sigma= out[...,:3].sigmoid(),out[...,3:4].relu()
226 |
227 | # Alphacomposite
228 | rgb, depth, weights = renderer.volume_integral(z_vals=samp_dists, sigmas=sigma, radiances=rgb, white_back=self.white_back)
229 |
230 | out = [x.unflatten(0,(b,n_trgt)) for x in [rgb,depth, weights]]
231 | return [x.squeeze(1).unflatten(0,(b_org,n_trgt_org)) for x in out] if identity else out
232 |
233 | # Adds RAFT flow and Midas depth to the model input
234 | class RaftAndMidas(nn.Module):
235 | def __init__(self,raft=True,midas=True):
236 | super().__init__()
237 | self.run_raft,self.run_midas=raft,midas
238 |
239 | if self.run_raft:
240 | from torchvision.models.optical_flow import Raft_Small_Weights
241 | from torchvision.models.optical_flow import raft_small
242 | print("making raft")
243 | self.raft_transforms = Raft_Small_Weights.DEFAULT.transforms()
244 | self.raft = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False)
245 | print("done making raft")
246 | if self.run_midas:
247 | print("making midas")
248 | self.midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform
249 | self.midas_large = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid")
250 | print("done making midas")
251 | else: print("no midas")
252 |
253 | def forward(self,model_input):
254 |
255 | # Estimate raft flow and midas depth if no flow/depth in dataset
256 | imsize=model_input["ctxt_rgb"].shape[-2:]
257 | (b,n_ctxt),n_trgt=model_input["ctxt_rgb"].shape[:2],model_input["trgt_rgb"].size(1)
258 | # Inputs should be in range [0,255]; TODO change to [-1,1] to stay consistent with other RGB range
259 |
260 | with torch.no_grad():
261 | # Compute RAFT flow from each frame to next frame forward and backward - note image shape must be > 128 for raft to work
262 | if "bwd_flow" not in model_input and self.run_raft:
263 | raft = lambda x,y: F.interpolate(self.raft(x,y,num_flow_updates=12)[-1]/(torch.tensor(x.shape[-2:][::-1])-1).to(x)[None,:,None,None],imsize)
264 | #raft_rgbs = torch.cat([self.midas_transforms(raft_rgb.permute(1,2,0).cpu().numpy()) for raft_rgb in raft_rgbs]).cuda()
265 | raft_inputs = self.raft_transforms(model_input["trgt_rgb_large"].flatten(0,1).to(torch.uint8),model_input["ctxt_rgb_large"].flatten(0,1).to(torch.uint8))
266 | model_input["bwd_flow"] = raft(*raft_inputs).unflatten(0,(b,n_trgt))
267 |
268 | # Compute midas depth if not on a datastet with depth
269 | if "trgt_depth" not in model_input and self.run_midas:
270 | # Compute midas depth for sup.
271 | # TODO normalize this input correctly based on torch transform. I think it's just imagenet + [0,1] mapping but check
272 | midas = lambda x: F.interpolate(1/(1e-3+self.midas_large(x)).unsqueeze(1),imsize)
273 | midas_rgbs = torch.cat((model_input["ctxt_rgb_med"],model_input["trgt_rgb_med"][:,-1:]),1).flatten(0,1)
274 | midas_rgbs = (midas_rgbs/255)*2-1
275 | all_depth = midas(midas_rgbs).unflatten(0,(b,n_trgt+1))
276 | model_input["trgt_depth"] = all_depth[:,1:]
277 | model_input["ctxt_depth"] = all_depth[:,:-1]
278 | return model_input
279 |
280 | class PositionalEncodingNoFreqFactor(nn.Module):
281 | """PositionalEncoding module
282 |
283 | Maps v to positional encoding representation phi(v)
284 |
285 | Arguments:
286 | i_dim (int): input dimension for v
287 | N_freqs (int): number of frequencies to sample (default: 10)
288 | """
289 | def __init__(
290 | self,
291 | i_dim: int,
292 | N_freqs: int = 10,
293 | ) -> None:
294 |
295 | super().__init__()
296 |
297 | self.i_dim = i_dim
298 | self.out_dim = i_dim + (2 * N_freqs) * i_dim
299 | self.N_freqs = N_freqs
300 |
301 | a, b = 1, self.N_freqs - 1
302 | self.freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
303 |
304 | def forward(self, v):
305 | pe = [v]
306 | for freq in self.freq_bands:
307 | fv = freq * v
308 | pe += [torch.sin(fv), torch.cos(fv)]
309 | return torch.cat(pe, dim=-1)
310 |
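# Quick shape sketch for the positional encoding above (toy input; the (3, 5)
# configuration matches how FlowCam instantiates it, giving 3 + 2*5*3 = 33 channels):
#
#   import torch
#   enc = PositionalEncodingNoFreqFactor(3, 5)
#   v = torch.rand(2, 100, 3)                  # e.g. ray directions
#   print(enc(v).shape, enc.out_dim)           # torch.Size([2, 100, 33]) 33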
311 |
--------------------------------------------------------------------------------
/renderer.py:
--------------------------------------------------------------------------------
1 | """Volume rendering code."""
2 | from geometry import *
3 | from typing import Callable, List, Optional, Tuple, Generator, Dict
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | import numpy as np
8 | from torch import Tensor,device
9 |
10 |
11 | def pdf_z_values(
12 | bins: Tensor,
13 | weights: Tensor,
14 | samples: int,
15 | d: device,
16 | perturb: bool,
17 | ) -> Tensor:
18 | """Generate z-values from pdf
19 | Arguments:
20 | bins (Tensor): z-value bins (B, N - 2)
21 | weights (Tensor): bin weights gathered from first pass (B, N - 1)
22 | samples (int): number of samples N
23 | d (device): torch device
24 | perturb (bool): perturb ray query segment
25 | Returns:
26 | t (Tensor): z-values sampled from pdf (B, N)
27 | """
28 | EPS = 1e-5
29 | B, N = weights.size()
30 |
31 | weights = weights + EPS
32 | pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
33 | cdf = torch.cumsum(pdf, dim=-1)
34 | cdf = torch.cat((torch.zeros_like(cdf[:, :1]), cdf), dim=-1)
35 |
36 | if perturb:
37 | u = torch.rand((B, samples), device=d)
38 | u = u.contiguous()
39 | else:
40 | u = torch.linspace(0, 1, samples, device=d)
41 | u = u.expand(B, samples)
42 | u = u.contiguous()
43 |
44 | idxs = torch.searchsorted(cdf, u, right=True)
45 | idxs_below = torch.clamp_min(idxs - 1, 0)
46 | idxs_above = torch.clamp_max(idxs, N)
47 | idxs = torch.stack((idxs_below, idxs_above), dim=-1).view(B, 2 * samples)
48 |
49 | cdf = torch.gather(cdf, dim=1, index=idxs).view(B, samples, 2)
50 | bins = torch.gather(bins, dim=1, index=idxs).view(B, samples, 2)
51 |
52 | den = cdf[:, :, 1] - cdf[:, :, 0]
53 | den[den < EPS] = 1.0
54 |
55 | t = (u - cdf[:, :, 0]) / den
56 | t = bins[:, :, 0] + t * (bins[:, :, 1] - bins[:, :, 0])
57 |
58 | return t
59 |
60 |
61 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
62 | """
63 | Sample @N_importance samples from @bins with distribution defined by @weights.
64 | Inputs:
65 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
66 | weights: (N_rays, N_samples_)
67 | N_importance: the number of samples to draw from the distribution
68 | det: deterministic or not
69 | eps: a small number to prevent division by zero
70 | Outputs:
71 | samples: the sampled samples
72 | Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py
73 | """
74 | N_rays, N_samples_ = weights.shape
75 | weights = weights + eps # prevent division by zero (don't do inplace op!)
76 | pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
77 | cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
78 | cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1)
79 | # padded to 0~1 inclusive
80 |
81 | if det:
82 | u = torch.linspace(0, 1, N_importance, device=bins.device)
83 | u = u.expand(N_rays, N_importance)
84 | else:
85 | u = torch.rand(N_rays, N_importance, device=bins.device)
86 | u = u.contiguous()
87 |
88 | inds = torch.searchsorted(cdf, u)
89 | below = torch.clamp_min(inds - 1, 0)
90 | above = torch.clamp_max(inds, N_samples_)
91 |
92 | inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
93 | cdf_g = torch.gather(cdf, 1, inds_sampled)
94 | cdf_g = cdf_g.view(N_rays, N_importance, 2)
95 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
96 |
97 | denom = cdf_g[..., 1] - cdf_g[..., 0]
98 | denom[
99 | denom < eps
100 | ] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
101 | # anyway, therefore any value for it is fine (set to 1 here)
102 |
103 | samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (
104 | bins_g[..., 1] - bins_g[..., 0]
105 | )
106 | return samples
107 |
108 |
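# Toy sketch for sample_pdf above (hierarchical/importance sampling; shapes follow
# the docstring, the values are arbitrary):
#
#   import torch
#   bins = torch.linspace(1.0, 8.0, steps=9)[None].repeat(4, 1)   # (N_rays=4, N_samples_+1=9) bin edges
#   weights = torch.rand(4, 8)                                     # (4, N_samples_=8) coarse weights
#   fine_t = sample_pdf(bins, weights, N_importance=32)            # (4, 32) depths, denser where weights are large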
109 | def pdf_rays(
110 | ro: Tensor,
111 | rd: Tensor,
112 | t: Tensor,
113 | weights: Tensor,
114 | samples: int,
115 | perturb: bool,
116 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
117 | """Sample pdf along rays given computed weights
118 | Arguments:
119 | ro (Tensor): rays origin (B, 3)
120 | rd (Tensor): rays direction (B, 3)
121 | t (Tensor): coarse z-value (B, N)
122 | weights (Tensor): weights gathered from first pass (B, N)
123 | samples (int): number of samples along the ray
124 | perturb (bool): perturb ray query segment
125 | Returns:
126 | rx (Tensor): rays position queries (B, Nc + Nf, 3)
127 | rd (Tensor): rays direction (B, Nc + Nf, 3)
128 | t (Tensor): z-values from near to far (B, Nc + Nf)
129 | delta (Tensor): rays segment lengths (B, Nc + Nf)
130 | """
131 | B, S, N_coarse, _ = weights.shape
132 | weights = rearrange(weights, "b n s 1 -> (b n) s")
133 | t = rearrange(t, "b n s 1 -> (b n) s")
134 |
135 | Nf = samples
136 | tm = 0.5 * (t[:, :-1] + t[:, 1:])
137 | t_pdf = sample_pdf(tm, weights[..., 1:-1], Nf, det=False).detach().view(B, S, Nf)
138 | rx = ro[..., None, :] + rd[..., None, :] * t_pdf[..., None]
139 |
140 | return rx, t_pdf
141 |
142 |
143 | def sample_points_along_rays(
144 | near_depth: float,
145 | far_depth: float,
146 | num_samples: int,
147 | ray_origins: torch.Tensor,
148 | ray_directions: torch.Tensor,
149 | device: torch.device,
150 | logspace=False,
151 | perturb=False,
152 | ):
153 | # Compute a linspace of num_samples depth values between near_depth and far_depth.
154 | if logspace:
155 | z_vals = torch.logspace(np.log10(near_depth), np.log10(far_depth), num_samples, device=device)
156 | else:
157 | z_vals = torch.linspace(near_depth, far_depth, num_samples, device=device)
158 |
159 | if perturb:
160 | z_vals = z_vals + .5*(torch.rand_like(z_vals)-.5)*torch.cat([(z_vals[:-1]-z_vals[1:]).abs(),torch.tensor([0]).cuda()])
161 |
162 | # Using the ray_origins, ray_directions, generate 3D points along
163 | # the camera rays according to the z_vals.
164 | pts = (
165 | ray_origins[..., None, :] + ray_directions[..., None, :] * z_vals[..., :, None]
166 | )
167 |
168 | return pts, z_vals
169 |
170 |
171 | def volume_integral(
172 | z_vals: torch.tensor, sigmas: torch.tensor, radiances: torch.tensor, white_back=False,dist_scale=True,
173 | ) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
174 | # Compute the deltas in depth between the points.
175 | dists = torch.cat(
176 | [
177 | z_vals[..., 1:] - z_vals[..., :-1],
178 | (z_vals[..., 1:] - z_vals[..., :-1])[..., -1:],
179 | ],
180 | -1,
181 | )
182 |
183 | # Compute the alpha values from the densities and the dists.
184 | # Tip: use torch.einsum for a convenient way of multiplying the correct
185 | # dimensions of the sigmas and the dists.
186 | # TODO switch to just expanding shape of dists for less code
187 | dist_scaling=dists if dist_scale else torch.ones_like(dists)
188 | if len(dists.shape)==1: alpha = 1.0 - torch.exp(-torch.einsum("brzs, z -> brzs", F.relu(sigmas), dist_scaling))
189 | else: alpha = 1.0 - torch.exp(-torch.einsum("brzs, brz -> brzs", F.relu(sigmas), dist_scaling.flatten(0,1)))
190 |
191 | alpha_shifted = torch.cat(
192 | [torch.ones_like(alpha[:, :, :1]), 1.0 - alpha + 1e-10], -2
193 | )
194 |
195 | # Compute the Ts from the alpha values. Use torch.cumprod.
196 | Ts = torch.cumprod(alpha_shifted, -2)
197 |
198 | # Compute the weights from the Ts and the alphas.
199 | weights = alpha * Ts[..., :-1, :]
200 |
201 | # Compute the pixel color as the weighted sum of the radiance values.
202 | rgb = torch.einsum("brzs, brzs -> brs", weights, radiances)
203 |
204 | # Compute the depths as the weighted sum of z_vals.
205 | # Tip: use torch.einsum for a convenient way of computing the weighted sum,
206 | # without the need to reshape the z_vals.
207 | if len(dists.shape)==1:
208 | depth_map = torch.einsum("brzs, z -> brs", weights, z_vals)
209 | else:
210 | depth_map = torch.einsum("brzs, brz -> brs", weights, z_vals.flatten(0,1))
211 |
212 | if white_back:
213 | accum = weights.sum(dim=-2)
214 | backgrd_color = torch.tensor([1,1,1]+[0]*(rgb.size(-1)-3)).broadcast_to(rgb.shape).to(rgb)
215 | #backgrd_color = torch.ones(rgb.size(-1)).broadcast_to(rgb.shape).to(rgb)
216 | rgb = rgb + (backgrd_color - accum)
217 |
218 | return rgb, depth_map, weights
219 |
220 |
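# Minimal end-to-end sketch of the two helpers above (toy rays and random
# densities/colors standing in for a radiance field; sizes are arbitrary):
#
#   import torch
#   ros = torch.zeros(2, 16, 3)                                    # ray origins (b, rays, 3)
#   rds = torch.nn.functional.normalize(torch.randn(2, 16, 3), dim=-1)
#   pts, z_vals = sample_points_along_rays(1.0, 8.0, 64, ros, rds, device=ros.device)
#   sigmas = torch.rand(2, 16, 64, 1)                              # densities per sample
#   radiances = torch.rand(2, 16, 64, 3)                           # colors per sample
#   rgb, depth, weights = volume_integral(z_vals, sigmas, radiances)
#   print(rgb.shape, depth.shape, weights.shape)                   # (2,16,3) (2,16,1) (2,16,64,1)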
221 | class VolumeRenderer(nn.Module):
222 | def __init__(self, near, far, n_samples=32, backgrd_color=None):
223 | super().__init__()
224 | self.near = near
225 | self.far = far
226 | self.n_samples = n_samples
227 |
228 | if backgrd_color is not None:
229 | self.register_buffer('backgrd_color', backgrd_color)
230 | else:
231 | self.backgrd_color = None
232 |
233 | def forward(
234 | self, cam2world, intrinsics, x_pix, radiance_field: nn.Module
235 | ) -> Tuple[torch.tensor, torch.tensor]:
236 | """
237 | Takes as inputs ray origins and directions - samples points along the
238 | rays and then calculates the volume rendering integral.
239 |
240 | Params:
241 | cam2world, intrinsics, x_pix: camera poses, camera intrinsics, and pixel coordinates of the rays to render.
242 | radiance_field: nn.Module instance of the radiance field we want to render.
243 |
244 | Returns:
245 | Tuple of rgb, depth_map
246 | rgb: for each pixel coordinate x_pix, the color of the respective ray.
247 | depth_map: for each pixel coordinate x_pix, the depth of the respective ray.
248 |
249 | """
250 | batch_size, num_rays = x_pix.shape[0], x_pix.shape[1]
251 |
252 | # Compute the ray directions in world coordinates.
253 | # Use the function get_world_rays.
254 | ros, rds = get_world_rays(x_pix, intrinsics, cam2world)
255 |
256 | # Generate the points along rays and their depth values
257 | # Use the function sample_points_along_rays.
258 | pts, z_vals = sample_points_along_rays(
259 | self.near, self.far, self.n_samples, ros, rds, device=x_pix.device
260 | )
261 |
262 | # Reshape pts to (batch_size, -1, 3).
263 | pts = pts.reshape(batch_size, -1, 3)
264 |
265 | # Sample the radiance field with the points along the rays.
266 | sigma, feats, misc = radiance_field(pts)
267 |
268 | # Reshape sigma and feats back to (batch_size, num_rays, self.n_samples, -1)
269 | sigma = sigma.view(batch_size, num_rays, self.n_samples, 1)
270 | feats = feats.view(batch_size, num_rays, self.n_samples, -1)
271 |
272 | # Compute pixel colors, depths, and weights via the volume integral.
273 | rendering, depth_map, weights = volume_integral(z_vals, sigma, feats)
274 |
275 | if self.backgrd_color is not None:
276 | accum = weights.sum(dim=-2)
277 | backgrd_color = self.backgrd_color.broadcast_to(rendering.shape)
278 | rendering = rendering + (backgrd_color - accum)
279 |
280 | return rendering, depth_map, misc
281 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os,random,time,datetime,glob
2 |
3 | import numpy as np
4 | from functools import partial
5 | from tqdm import tqdm,trange
6 | from einops import rearrange
7 |
8 | import wandb
9 | import torch
10 |
11 | import models
12 | import vis_scripts
13 |
14 | from data.KITTI import KittiDataset
15 | from data.co3d import Co3DNoCams
16 | from data.realestate10k_dataio import RealEstate10k
17 |
18 | def to_gpu(ob): return {k: to_gpu(v) for k, v in ob.items()} if isinstance(ob, dict) else ob.cuda()
19 |
20 | import argparse
21 | parser = argparse.ArgumentParser(description='simple training job')
22 | # logging parameters
23 | parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
24 | parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If a folder is specified, the latest .pt file in it is used")
25 | parser.add_argument('-d','--dataset', type=str,default="hydrant")
26 | parser.add_argument('-o','--online', default=False, action=argparse.BooleanOptionalAction)
27 | # data/training parameters
28 | parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
29 | parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
30 | parser.add_argument('--midas_sup', default=False, action=argparse.BooleanOptionalAction,help="Whether to use midas depth supervision or not")
31 | parser.add_argument('--category', type=str,default=None,help="if want to use a specific co3d category, such as 'bicycle', specify here")
32 | # model parameters
33 | parser.add_argument('--n_skip', nargs="+",type=int,default=0,help="Number of frames to skip between adjacent frames in dataloader. If list, dataset randomly chooses between skips. Only used for co3d")
34 | parser.add_argument('--n_ctxt', type=int,default=2,help="Number of context views to use. 1 is just first frame, 2 is second and last, 3 is also middle, etc")
35 | # eval/vis
36 | parser.add_argument('--eval', default=False, action=argparse.BooleanOptionalAction,help="whether to train or run evaluation")
37 | parser.add_argument('--n_eval', type=int,default=int(1e8),help="Number of eval samples to run")
38 | parser.add_argument('--save_ind', default=False, action=argparse.BooleanOptionalAction,help="whether to save out each individual image (in rendering images) or just save the all-trajectory image")
39 | parser.add_argument('--save_imgs', default=False, action=argparse.BooleanOptionalAction,help="whether to save out the all-trajectory images")
40 | # demo args
41 | parser.add_argument('--demo_rgb', default="", type=str,required=False,help="The image folder path for demo inference.")
42 | parser.add_argument('--render_imgs', default=False, action=argparse.BooleanOptionalAction,help="whether to rerender out images (video reconstructions) during the demo inference (slower than just estimating poses)")
43 | parser.add_argument('--intrinsics', default=None, type=str,required=False,help="The intrinsics corresponding to the image path for demo inference as fx,fy,cx,cy. To use predicted intrinsics, leave as None")
44 | parser.add_argument('--low_res', nargs="+",type=int,default=[128,128],help="Low resolution to perform renderings at. Default (128,128)")
45 |
46 | args = parser.parse_args()
47 | if args.n_skip==0 and args.dataset=="realestate": args.n_skip=9 # realestate is the only dataset where 0 frameskip isn't meaningful
48 |
49 | # Wandb init: install wandb and initialize via wandb.login() before running
50 | run = wandb.init(project="flowcam",mode="online" if args.online else "disabled",name=args.name)
51 | wandb.run.log_code(".")
52 | save_dir = os.path.join(os.environ.get('LOGDIR', "") , run.name)
53 | print(save_dir)
54 | os.makedirs(save_dir,exist_ok=True)
55 | wandb.save(os.path.join(save_dir, "checkpoint*"))
56 | wandb.save(os.path.join(save_dir, "video*"))
57 |
58 | # Make dataset
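   | # Each dataset yields clips of vid_len+1 frames at a dataset-specific low resolution
   | # (CO3D 156x102, KITTI 76x250, RealEstate10K 128x128); val=True returns the validation split.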
59 | get_dataset = lambda val=False: ( Co3DNoCams(num_trgt=args.vid_len+1,low_res=(156,102),num_cat=1 if args.dataset=="hydrant" else 10 if args.dataset=="10cat" else 30,
60 | n_skip=args.n_skip,val=val,category=args.category) if args.dataset in ["hydrant","10cat","allcat"]
61 | else RealEstate10k(imsl=128, num_ctxt_views=2, num_query_views=args.vid_len+1, val=val, n_skip = args.n_skip) if args.dataset == "realestate"
62 | else KittiDataset(num_context=1,num_trgt=args.vid_len+1,low_res=(76,250),val=val,n_skip=args.n_skip)
63 | )
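   | # The dataloader batch size scales with the number of visible GPUs; train.py wraps the model
   | # in DataParallel, which splits each batch back across devices.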
64 | get_dataloader = lambda dataset: iter(torch.utils.data.DataLoader(dataset, batch_size=args.batch_size*torch.cuda.device_count(),num_workers=args.batch_size,shuffle=True,pin_memory=True))
65 |
66 | # Make model + optimizer
67 | model = models.FlowCam(near=1,far=8,num_view=args.n_ctxt,use_midas=args.midas_sup).cuda()
68 | if args.init_ckpt is not None:
69 | ckpt_file = args.init_ckpt if os.path.isfile(os.path.expanduser(args.init_ckpt)) else max(glob.glob(os.path.join(args.init_ckpt,"*.pt")), key=os.path.getctime)
70 | model.load_state_dict(torch.load(ckpt_file)["model_state_dict"],strict=False)
71 | optim = torch.optim.Adam(lr=1e-4, params=model.parameters())
72 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # TODO: get rid of the star import; just import run and access run.train_dataset, etc.
2 | from run import *
3 |
4 | train_dataset,val_dataset = get_dataset(),get_dataset(val=True,)
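   | # Load one sample up front (result discarded); this fails fast if the dataset paths/environment variables are misconfigured.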
5 | train_dataset[0]
6 | train_dataloader,val_dataloader = get_dataloader(train_dataset),get_dataloader(val_dataset)
7 |
8 | def loss_fn(model_out, gt, model_input):
9 |
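   |     # rays(x, y) selects, per batch element, the rays/pixels of x indexed by y.
   |     # ch_sec flattens images to (pixels, channels); ch_fst is the inverse (assumes square images unless x is given).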
10 | rays = lambda x,y: torch.stack([x[i,:,y[i].long()] for i in range(len(x))])
11 | ch_sec = lambda x: rearrange(x,"... c x y -> ... (x y) c")
12 | ch_fst = lambda src,x=None:rearrange(src,"... (x y) c -> ... c x y",x=int(src.size(-2)**(.5)) if x is None else x)
13 | losses = { "metrics/rgb_loss": (model_out["rgb"] - rays(gt["trgt_rgb"],model_out["trgt_rays"])).square().mean() * (1e1 if args.dataset=="shapenet" else 2e2) }
14 |
15 | losses["metrics/ctxt_rgb_loss"]= (model_out["ctxt_rgb"] - rays(gt["ctxt_rgb"],model_out["ctxt_rays"])).square().mean() * 1e2
16 | gt_bwd_flow = rays(gt["bwd_flow"] if "bwd_flow" in gt else ch_sec(model_out["flow_inp"]),model_out["ctxt_rays"])
17 | losses["metrics/flow_from_pose"] = ((model_out["flow_from_pose"].clip(-.2,.2) - gt_bwd_flow.clip(-.2,.2)).square().mean() * 6e3).clip(0,10)
18 |
19 | gt_bwd_flow_trgt = rays(gt["bwd_flow"] if "bwd_flow" in gt else ch_sec(model_out["flow_inp"]),model_out["trgt_rays"])
20 |
21 | # monodepth loss (not used in paper but may be useful later)
22 | if args.midas_sup:
23 | def depth_lstsq_fit(depthgt,depthest):
24 | depthgt,depthest=1/(1e-8+depthgt),1/(1e-8+depthest)
25 | lstsq=torch.linalg.lstsq(depthgt,depthest)
26 | return ((depthgt@lstsq.solution)-depthest).square().mean() * 1e2
27 |
28 | losses["metrics/ctxt_depth_loss_lstsq"] = (depth_lstsq_fit(rays(ch_sec(model_out["ctxt_depth_inp"]),model_out["ctxt_rays"]).flatten(0,1),model_out["ctxt_depth"].flatten(0,1))*2e0).clip(0,2)/2
29 | losses["metrics/depth_loss_lstsq"] = (depth_lstsq_fit(rays(ch_sec(model_out["trgt_depth_inp"]),model_out["trgt_rays"]).flatten(0,1),model_out["depth"].flatten(0,1))*2e0).clip(0,2)/2
30 |
31 | return losses
32 |
33 | model = torch.nn.DataParallel(model)
34 |
35 | # Train loop
36 | for step in trange(0 if args.eval else int(1e8), desc="Fitting"): # train until user interruption
37 |
38 |     # Every 150 steps, run ~10 validation iterations (gradients disabled, logs prefixed with "val")
39 | val_step = step>10 and step%150<10
40 | prefix = "val" if val_step else ""
41 | torch.set_grad_enabled(not val_step)
42 | if val_step: print("\n\n\nval step\n\n\n")
43 |
44 | # Get data
45 | try: model_input, ground_truth = next(train_dataloader if not val_step else val_dataloader)
46 | except StopIteration:
47 | if val_step: val_dataloader = get_dataloader(val_dataset)
48 | else: train_dataloader = get_dataloader(train_dataset)
49 | continue
50 |
51 | model_input, ground_truth = to_gpu(model_input), to_gpu(ground_truth)
52 |
53 | # Run model and calculate losses
54 | total_loss = 0.
55 | for loss_name, loss in loss_fn(model(model_input), ground_truth, model_input).items():
56 | wandb.log({prefix+loss_name: loss.item()}, step=step)
57 | total_loss += loss
58 |
59 | wandb.log({prefix+"loss": total_loss.item()}, step=step)
60 | wandb.log({"epoch": (step*args.batch_size)/len(train_dataset)}, step=step)
61 |
62 | if not val_step:
63 | optim.zero_grad(); total_loss.backward(); optim.step();
64 |
65 | # Image summaries and checkpoint
66 | if step%50==0 : # write image summaries
67 | print("writing summary")
68 | with torch.no_grad(): model_output = model.module.render_full_img(model_input=model_input)
69 | vis_scripts.wandb_summary( total_loss, model_output, model_input, ground_truth, None,prefix=prefix)
70 | if step%100==0: #write video summaries
71 | print("writing video summary")
72 | try: vis_scripts.write_video(save_dir, vis_scripts.render_time_interp(model_input,model.module,None,16), prefix+"time_interp",step)
73 | except Exception as e: print("error in writing video",e)
74 | if step%500 == 0 and step: # save model
75 | print(f"Saving to {save_dir}"); torch.save({ 'step': step, 'model_state_dict': model.module.state_dict(), }, os.path.join(save_dir, f"checkpoint_{step}.pt"))
76 |
77 |
78 |
--------------------------------------------------------------------------------
/vis_scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import geometry
3 | import wandb
4 | from matplotlib import cm
5 | from torchvision.utils import make_grid
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import torch
9 | import flow_vis
10 | import flow_vis_torch
11 | import matplotlib.pyplot as plt
12 | from einops import rearrange, repeat
13 | import piqa
14 | import imageio
15 | import splines.quaternion
16 | from torchcubicspline import (natural_cubic_spline_coeffs, NaturalCubicSpline)
17 |
18 | def write_video(save_dir,frames,vid_name,step,write_frames=False):
19 | frames = [(255*x).astype(np.uint8) for x in frames]
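   |     # "time" videos are played forward then in reverse so the exported clip loops seamlessly.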
20 | if "time" in vid_name: frames = frames + frames[::-1]
21 | f = os.path.join(save_dir, f'{vid_name}.mp4')
22 | imageio.mimwrite(f, frames, fps=8, quality=7)
23 | wandb.log({f'vid/{vid_name}':wandb.Video(f, format='mp4', fps=8)})
24 | print("writing video at",f)
25 | if write_frames:
26 | for i,img in enumerate(frames):
27 |             frame_dir = os.path.join(save_dir, f'{vid_name}_{step}')
28 |             os.makedirs(frame_dir, exist_ok=True)
29 |             f=os.path.join(frame_dir, f'{i}.png');plt.imsave(f,img);print(f)
30 |
31 | def normalize(a):
32 | return (a - a.min()) / (a.max() - a.min())
33 |
34 | def cvt(a):
35 | a = a.permute(1, 2, 0).detach().cpu()
36 | a = (a - a.min()) / (a.max() - a.min())
37 | a = a.numpy()
38 | return a
39 |
40 | ch_fst = lambda src,x=None:rearrange(src,"... (x y) c -> ... c x y",x=int(src.size(-2)**(.5)) if x is None else x)
41 |
42 | # Renders out query frame with interpolated motion field
43 | def render_time_interp(model_input,model,resolution,n):
44 |
45 | b=model_input["ctxt_rgb"].size(0)
46 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
47 | frames=[]
48 | thetas = np.linspace(0, 1, n)
49 | with torch.no_grad(): sample_out = model(model_input)
50 | if "flow_inp" in sample_out:model_input["bwd_flow"]=sample_out["flow_inp"]
51 |
52 | # TODO add wobble flag back in here from satori code
53 |
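   |     # Fit a natural cubic spline to the camera positions and a piecewise-slerp spline to the
   |     # rotations (as quaternions) so intermediate timesteps can be rendered from interpolated poses.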
54 | all_poses=sample_out["poses"]
55 | pos_spline_idxs=torch.linspace(0,all_poses.size(1)-1,all_poses.size(1)) # no compression
56 | rot_spline_idxs=torch.linspace(0,all_poses.size(1)-1,all_poses.size(1)) # no compression
57 | all_pos_spline=[]
58 | all_quat_spline=[]
59 | for b_i in range(b):
60 | all_pos_spline.append(NaturalCubicSpline(natural_cubic_spline_coeffs(pos_spline_idxs, all_poses[b_i,pos_spline_idxs.long(),:3,-1].cpu())))
61 | quats=geometry.matrix_to_quaternion(all_poses[b_i,:,:3,:3])
62 | all_quat_spline.append(splines.quaternion.PiecewiseSlerp([splines.quaternion.UnitQuaternion.from_unit_xyzw(quat_)
63 | for quat_ in quats[rot_spline_idxs.long()].cpu().numpy()],grid=rot_spline_idxs.tolist()))
64 |
65 | for t in torch.linspace(0,all_poses.size(1)-1,n):
66 | print(t)
67 |
68 | custom_poses=[]
69 | for b_i,(pos_spline,quat_spline_) in enumerate(zip(all_pos_spline,all_quat_spline)):
70 | custom_pose=torch.eye(4).cuda()
71 | custom_pose[:3,-1]=pos_spline.evaluate(t)
72 | closest_t = (custom_pose[:3,-1]-all_poses[b_i,:,:3,-1]).square().sum(-1).argmin()
73 | quat_eval=quat_spline_.evaluate(t.item())
74 | curr_quats = torch.tensor(list(quat_eval.vector)+[quat_eval.scalar])
75 | custom_pose[:3,:3] = geometry.quaternion_to_matrix(curr_quats)
76 | custom_poses.append(custom_pose)
77 | custom_pose=torch.stack(custom_poses)
78 | with torch.no_grad(): model_out = model.render_full_img(model_input,query_pose=custom_pose,sample_out=sample_out)
79 |
80 | rgb_pred = model_out["rgb"]
81 | resolution = list(model_input["ctxt_rgb"][:,:1].flatten(0,1).permute(0,2,3,1).shape)
82 | rgb_pred=rgb_pred[:,:1].view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
83 | magma_depth=model_out["depth"][:,:1].view(resolution).permute(1,0,2,3).flatten(1,2).cpu()
84 | rgbd_im=torch.cat((torch.from_numpy(rgb_pred),magma_depth),0).numpy()
85 | frames.append(rgbd_im)
86 | return frames
87 |
88 |
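   |     # NOTE: the loop below is unreachable (the function returns above); it appears to be an
   |     # older linear pose-interpolation variant of the same visualization.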
89 | for i in range(n):
90 | print(i,n)
91 | query_pose = geometry.time_interp_poses(sample_out["poses"],i/(n-1), model_input["trgt_rgb"].size(1),None)[:,0]
92 | # todo fix this interpolation -- is it incorrect to interpolate here?
93 | with torch.no_grad(): model_out = model.render_full_img(model_input,query_pose=query_pose,sample_out=sample_out)
94 | rgb_pred = model_out["rgb"]
95 | resolution = list(model_input["ctxt_rgb"][:,:1].flatten(0,1).permute(0,2,3,1).shape)
96 | rgb_pred=rgb_pred[:,:1].view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
97 | magma_depth=model_out["depth"][:,:1].view(resolution).permute(1,0,2,3).flatten(1,2).cpu()
98 | rgbd_im=torch.cat((torch.from_numpy(rgb_pred),magma_depth),0).numpy()
99 | frames.append(rgbd_im)
100 | return frames
101 |
102 | def look_at(eye, at=torch.Tensor([0, 0, 0]).cuda(), up=torch.Tensor([0, 1, 0]).cuda(), eps=1e-5):
103 | #at = at.unsqueeze(0).unsqueeze(0)
104 | #up = up.unsqueeze(0).unsqueeze(0)
105 |
106 | z_axis = eye - at
107 | #z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps
108 | z_axis = z_axis/(z_axis.norm(dim=-1, keepdim=True) + eps)
109 |
110 | up = up.expand(z_axis.shape)
111 |     x_axis = torch.cross(up, z_axis, dim=-1)
112 | #x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps
113 | x_axis = x_axis/(x_axis.norm(dim=-1, keepdim=True) + eps)
114 |
115 |     y_axis = torch.cross(z_axis, x_axis, dim=-1)
116 | #y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps
117 | y_axis = y_axis/(y_axis.norm(dim=-1, keepdim=True) + eps)
118 |
119 | r_mat = torch.stack((x_axis, y_axis, z_axis), axis=-1)
120 | return r_mat
121 |
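   | # Renders frames while moving the camera on a small circular "wobble" path (with a z oscillation),
   | # stacking RGB over depth in each output frame.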
122 | def render_cam_traj_wobble(model_input,model,resolution,n):
123 |
124 | c2w = torch.eye(4, device='cuda')[None]
125 | tmp = torch.eye(4).cuda()
126 | circ_scale = .1
127 | thetas = np.linspace(0, 2 * np.pi, n)
128 | frames = []
129 |
130 | if "ctxt_c2w" not in model_input:
131 | model_input["ctxt_c2w"] = torch.tensor([[-2.5882e-01, -4.8296e-01, 8.3652e-01, -2.2075e+00],
132 | [ 2.1187e-08, -8.6603e-01, -5.0000e-01, 2.3660e+00],
133 | [-9.6593e-01, 1.2941e-01, -2.2414e-01, 5.9150e-01],
134 | [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]]
135 | )[None,None].expand(model_input["trgt_rgb"].size(0),model_input["trgt_rgb"].size(1),-1,-1).cuda()
136 |
137 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
138 |
139 | c2w = model_input["ctxt_c2w"]
140 | circ_scale = c2w[0,0,[0,2],-1].norm()
141 | #circ_scale = c2w[0,[0,1],-1].norm()
142 | thetas=np.linspace(0,np.pi*2,n)
143 | rgb_imgs=[]
144 | depth_imgs=[]
145 | start_theta=0#(model_input["ctxt_c2w"][0,0,0,-1]/circ_scale).arccos()
146 | with torch.no_grad(): sample_out = model(model_input)
147 | if "flow_inp" in sample_out: model_input["bwd_flow"]=sample_out["flow_inp"]
148 | step=2 if n==8 else 4
149 | zs = torch.cat((torch.linspace(0,-n//4,n//step),torch.linspace(-n//4,0,n//step),torch.linspace(0,n//4,n//step),torch.linspace(n//4,0,n//step)))
150 | for i in range(n):
151 | print(i,n)
152 | theta=float(thetas[i] + start_theta)
153 | x=np.cos(theta) * circ_scale * .075
154 | y=np.sin(theta) * circ_scale * .075
155 | tmp=torch.eye(4).cuda()
156 | newpos=torch.tensor([x,y,zs[i]*1e-1]).cuda().float()
157 | tmp[:3,-1] = newpos
158 | custom_c2w = tmp[None].expand(c2w.size(0),c2w.size(1),-1,-1)
159 | with torch.no_grad(): model_out = model(model_input,custom_transf=custom_c2w,full_img=True)
160 |
161 | resolution = [model_input["trgt_rgb"].size(0)]+list(resolution[1:])
162 | b = model_out["rgb"].size(0)
163 | rgb_pred = model_out["rgb"][:,0].view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
164 | magma_depth = model_out["depth"][:,0].view(resolution).permute(1,0,2,3).flatten(1,2).cpu()
165 | rgbd_im=torch.cat((torch.from_numpy(rgb_pred),magma_depth),0).numpy()
166 | frames.append(rgbd_im)
167 | return frames
168 |
169 |
170 |
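   | # Like render_cam_traj_wobble, but also sweeps time_i from 0 to 1 so the scene animates over time
   | # while the camera wobbles.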
171 | def render_cam_traj_time_wobble(model_input,model,resolution,n):
172 |
173 | c2w = torch.eye(4, device='cuda')[None]
174 | if "ctxt_c2w" not in model_input:
175 | model_input["ctxt_c2w"] = torch.tensor([[-2.5882e-01, -4.8296e-01, 8.3652e-01, -2.2075e+00],
176 | [ 2.1187e-08, -8.6603e-01, -5.0000e-01, 2.3660e+00],
177 | [-9.6593e-01, 1.2941e-01, -2.2414e-01, 5.9150e-01],
178 | [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]]
179 | )[None,None].expand(model_input["trgt_rgb"].size(0),model_input["trgt_rgb"].size(1),-1,-1).cuda()
180 |
181 | c2w = model_input["ctxt_c2w"]
182 | circ_scale = c2w[0,0,[0,2],-1].norm()
183 | thetas=np.linspace(0,np.pi*2,n)
184 | rgb_imgs=[]
185 | depth_imgs=[]
186 | start_theta=0#(model_input["ctxt_c2w"][0,0,0,-1]/circ_scale).arccos()
187 |
188 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
189 | frames=[]
190 | thetas = np.linspace(0, 1, n)
191 | with torch.no_grad(): sample_out = model(model_input)
192 | if "flow_inp" in sample_out: model_input["bwd_flow"]=sample_out["flow_inp"]
193 | step=2 if n==8 else 4
194 | zs = torch.cat((torch.linspace(0,-n//4,n//step),torch.linspace(-n//4,0,n//step),torch.linspace(0,n//4,n//step),torch.linspace(n//4,0,n//step)))
195 | for i in range(n):
196 |
197 | print(i,n)
198 | theta=float(thetas[i] + start_theta)
199 | x=np.cos(theta) * circ_scale * .005
200 | y=np.sin(theta) * circ_scale * .005
201 | tmp=torch.eye(4).cuda()
202 | newpos=torch.tensor([x,y,zs[i]*2e-1]).cuda().float()
203 | tmp[:3,-1] = newpos
204 | custom_c2w = tmp[None].expand(c2w.size(0),c2w.size(1),-1,-1)
205 |
206 | with torch.no_grad(): model_out = model(model_input,time_i=i/(n-1),full_img=True,custom_transf=custom_c2w)
207 | rgb_pred = model_out["rgb"]
208 | same_all=True
209 | if same_all:
210 | resolution = list(model_input["ctxt_rgb"][:,:1].flatten(0,1).permute(0,2,3,1).shape)
211 | rgb_pred=rgb_pred[:,:1].view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
212 | else:
213 | rgb_pred=rgb_pred.view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
214 | depth_pred = model_out["depth"].clone()
215 | mind,maxd=sample_out["depth"].cpu().min(),sample_out["depth"].cpu().max()
216 | depth_pred[0,0,0]=mind #normalize
217 | depth_pred[0,0,1]=maxd #normalize
218 | if same_all:
219 | depth_pred = (mind/(1e-3+depth_pred[:,:1]).view(resolution[:-1]).permute(1,0,2).flatten(1,2).cpu().numpy())
220 | else:
221 | depth_pred = (mind/(1e-3+depth_pred).view(resolution[:-1]).permute(1,0,2).flatten(1,2).cpu().numpy())
222 | magma = cm.get_cmap('magma')
223 | magma_depth = torch.from_numpy(magma(depth_pred))[...,:3]
224 | rgbd_im=torch.cat((torch.from_numpy(rgb_pred),magma_depth),0).numpy()
225 | frames.append(rgbd_im)
226 |
227 | return frames
228 |
229 | # Renders out context frame with novel camera pose
230 | def render_view_interp(model_input,model,resolution,n):
231 |
232 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
233 |
234 | c2w = torch.eye(4, device='cuda')[None]
235 | tmp = torch.eye(4).cuda()
236 | circ_scale = .1
237 | thetas = np.linspace(0, 2 * np.pi, n)
238 | frames = []
239 |
240 | if "ctxt_c2w" not in model_input:
241 | model_input["ctxt_c2w"] = torch.tensor([[-2.5882e-01, -4.8296e-01, 8.3652e-01, -2.2075e+00],
242 | [ 2.1187e-08, -8.6603e-01, -5.0000e-01, 2.3660e+00],
243 | [-9.6593e-01, 1.2941e-01, -2.2414e-01, 5.9150e-01],
244 | [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]]
245 | )[None,None].expand(model_input["trgt_rgb"].size(0),model_input["trgt_rgb"].size(1),-1,-1).cuda()
246 |
247 | c2w = model_input["ctxt_c2w"]
248 | circ_scale = c2w[0,0,[0,2],-1].norm()
249 | #circ_scale = c2w[0,[0,1],-1].norm()
250 | thetas=np.linspace(0,np.pi*2,n)
251 | rgb_imgs=[]
252 | depth_imgs=[]
253 | start_theta=0#(model_input["ctxt_c2w"][0,0,0,-1]/circ_scale).arccos()
254 | with torch.no_grad(): sample_out = model(model_input)
255 | if "flow_inp" in sample_out: model_input["bwd_flow"]=sample_out["flow_inp"]
256 | for i in range(n):
257 | print(i,n)
258 | theta=float(thetas[i] + start_theta)
259 | x=np.cos(theta) * circ_scale * 1
260 | y=np.sin(theta) * circ_scale * 1
261 | tmp=torch.eye(4).cuda()
262 | #newpos=torch.tensor([x,y,c2w[0,2,-1]]).cuda().float()
263 | newpos=torch.tensor([x,c2w[0,0,1,-1],y]).cuda().float()
264 | rot = look_at(newpos,torch.tensor([0,0,0]).cuda())
265 | rot[:,1:]*=-1
266 | tmp[:3,:3]=rot
267 | newpos=torch.tensor([x,c2w[0,0,1,-1],y]).cuda().float()
268 | tmp[:3,-1] = newpos
269 | #with torch.no_grad(): model_out = model(model_input,custom_transf=tmp[None].expand(c2w.size(0),-1,-1))
270 | custom_c2w = tmp[None].expand(c2w.size(0),c2w.size(1),-1,-1)
271 | #TODO make circle radius and only use first img
272 | #from pdb import set_trace as pdb_;pdb_()
273 | if 1:
274 | custom_c2w = custom_c2w.inverse()@model_input["ctxt_c2w"]
275 | #custom_c2w = model_input["ctxt_c2w"].inverse()@custom_c2w
276 | with torch.no_grad(): model_out = model(model_input,custom_transf=custom_c2w,full_img=True)
277 |
278 | resolution = [model_input["trgt_rgb"].size(0)]+list(resolution[1:])
279 |
280 | b = model_out["rgb"].size(0)
281 | rgb_pred = model_out["rgb"][:,0].view(resolution).permute(1,0,2,3).flatten(1,2).cpu().numpy()
282 | magma_depth = model_out["depth"][:,0].view(resolution).permute(1,0,2,3).flatten(1,2).cpu()
283 | rgbd_im=torch.cat((torch.from_numpy(rgb_pred),magma_depth),0).numpy()
284 | frames.append(rgbd_im)
285 | return frames
286 |
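   | # Logs RGB, depth, flow, and related image grids for the current batch to wandb.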
287 | def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
288 |
289 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
290 | resolution[0]=ground_truth["trgt_rgb"].size(1)*ground_truth["trgt_rgb"].size(0)
291 | nrow=model_input["trgt_rgb"].size(1)
292 | imsl=model_input["ctxt_rgb"].shape[-2:]
293 | inv = lambda x : 1/(x+1e-8)
294 |
295 | depth = make_grid(model_output["depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
296 |
297 | wandb_out = {
298 | "est/rgb_pred": make_grid(model_output["rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow),
299 | "ref/rgb_gt": make_grid(ground_truth["trgt_rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow),
300 | #"ref/rgb_gt": make_grid(ground_truth["trgt_rgb"].cpu().view(*resolution).detach().permute(0, -1, 1, 2),nrow=nrow),
301 | "ref/ctxt_img": make_grid(model_input["ctxt_rgb"][:,0].cpu().detach(),nrow=1)*.5+.5,
302 | "est/depth": depth,
303 | "est/depth_1ch":make_grid(model_output["depth_raw"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl).cpu(),normalize=True,nrow=nrow),
304 | }
305 |
306 | depthgt = (ground_truth["trgt_depth"] if "trgt_depth" in ground_truth else model_output["trgt_depth_inp"] if "trgt_depth_inp" in model_output
307 | else model_input["trgt_depth"] if "trgt_depth" in model_input else None)
308 |
309 | if "ctxt_rgb" in model_output:
310 | wandb_out["est/ctxt_depth"] =make_grid(model_output["ctxt_depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
311 | wandb_out["est/ctxt_rgb_pred"] = ctxt_rgb_pred = make_grid(model_output["ctxt_rgb"].cpu().view(*resolution).detach().permute(0, -1, 1, 2),nrow=nrow)
312 |
313 | if "corr_weights" in model_output:
314 | #corr_weights = make_grid(model_output["corr_weights"].flatten(0,1)[:,:1].cpu().detach(),normalize=False,nrow=nrow)
315 | corr_weights = make_grid(ch_fst(model_output["corr_weights"],resolution[1]).flatten(0,1)[:,:1].cpu().detach(),normalize=False,nrow=nrow)
316 | wandb_out["est/corr_weights"] = corr_weights
317 |
318 |     if "flow_from_pose" in model_output and not torch.isnan(model_output["flow_from_pose"]).any():
319 | #psnr = piqa.PSNR()(ch_fst(model_output["rgb"],imsl[0]).flatten(0,1).contiguous(),ch_fst(ground_truth["trgt_rgb"],imsl[0]).flatten(0,1).contiguous())
320 | #wandb.log({prefix+"metrics/psnr": psnr})
321 |
322 | gt_flow_bwd = flow_vis_torch.flow_to_color(make_grid(model_output["flow_inp"].flatten(0,1),nrow=nrow))/255
323 | wandb_out["ref/flow_gt_bwd"]=gt_flow_bwd
324 | if "flow_from_pose" in model_output:
325 | wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255
326 | if "flow_from_pose_render" in model_output:
327 | wandb_out["est/flow_est_pose_render"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose_render"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255
328 | else:
329 | print("skipping flow plotting")
330 |
331 | wandb_out = {prefix+k:wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()}
332 |
333 | wandb.log(wandb_out)
334 |
335 |
336 | def pose_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
337 | # Log points and boxes in W&B
338 | point_scene = wandb.Object3D({
339 | "type": "lidar/beta",
340 | "points": model_output["poses"][:,:3,-1].cpu().numpy(),
341 | })
342 | wandb.log({"camera positions": point_scene})
343 |
344 |
345 |
346 |
--------------------------------------------------------------------------------
/wandb_logging.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | from matplotlib import cm
4 | from torchvision.utils import make_grid
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import torch
8 | import flow_vis
9 | import flow_vis_torch
10 | import matplotlib.pyplot as plt; imsave = lambda x,y=0: plt.imsave("/nobackup/users/camsmith/img/tmp%s.png"%y,x.cpu().numpy());
11 | from einops import rearrange, repeat
12 | import piqa
13 | import imageio
14 |
15 | def write_video(save_dir,frames,vid_name,step,write_frames=False):
16 | frames = [(255*x).astype(np.uint8) for x in frames]
17 | if "time" in vid_name: frames = frames + frames[::-1]
18 | f = os.path.join(save_dir, f'{vid_name}_{step}.mp4')
19 | imageio.mimwrite(f, frames, fps=8, quality=7)
20 | wandb.log({f'vid/{vid_name}':wandb.Video(f, format='mp4', fps=8)})
21 | print("writing video at",f)
22 | if write_frames:
23 | for i,img in enumerate(frames):
24 | try: os.mkdir(os.path.join(save_dir, f'{vid_name}_{step}'))
25 | except:pass
26 | f=os.path.join(save_dir, f'{vid_name}_{step}/{i}.png');plt.imsave(f,img);print(f)
27 |
28 | def normalize(a):
29 | return (a - a.min()) / (a.max() - a.min())
30 |
31 | def cvt(a):
32 | a = a.permute(1, 2, 0).detach().cpu()
33 | a = (a - a.min()) / (a.max() - a.min())
34 | a = a.numpy()
35 | return a
36 |
37 | ch_fst = lambda src,x=None:rearrange(src,"... (x y) c -> ... c x y",x=int(src.size(-2)**(.5)) if x is None else x)
38 |
39 | def _wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
40 |
41 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).shape)
42 |
43 | nrow=model_input["trgt_rgb"].size(1)
44 | imsly,imslx=model_input["ctxt_rgb"].shape[-2:]
45 |
46 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
47 |
48 | rgb_gt= ground_truth["trgt_rgb"]
49 | rgb_pred,depth,=[model_output[x] for x in ["rgb","depth"]]
50 |
51 | inv = lambda x : 1/(x+1e-3)
52 | depth = make_grid(model_output["depth"].flatten(0,1).permute(0,2,1).unflatten(-1,(imsly,imslx)).cpu(),normalize=True,nrow=nrow)
53 |
54 | rgb_pred = make_grid(model_output["rgb"].flatten(0,1).permute(0,2,1).unflatten(-1,(imsly,imslx)),normalize=True,nrow=nrow)
55 | rgb_gt = make_grid(ground_truth["trgt_rgb"].flatten(0,1).permute(0,2,1).unflatten(-1,(imsly,imslx)),normalize=True,nrow=nrow)
56 | ctxt_img = make_grid(model_input["ctxt_rgb"].cpu().flatten(0,1),normalize=True,nrow=nrow)
57 |
58 |     # TODO: add a PSNR metric here (see wandb_summary below)
59 |
60 | wandb_out = {
61 | "est/rgb_pred": rgb_pred,
62 | "ref/rgb_gt": rgb_gt,
63 | "ref/ctxt_img": ctxt_img,
64 | "est/depth": depth,
65 | }
66 | if "trgt_depth" in ground_truth:
67 | wandb_out["depthgt"]=make_grid(ground_truth["trgt_depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,(imsly,imslx)).cpu(),normalize=True,nrow=nrow)
68 |
69 | for k,v in wandb_out.items(): print(k,v.max(),v.min(),v.shape)
70 | #for k,v in wandb_out.items():plt.imsave("/nobackup/users/camsmith/img/%s.png"%k,v.permute(1,2,0).detach().cpu().numpy().clip(0,1));
71 | wandb.log({"sanity/"+k+"_min":v.min() for k,v in wandb_out.items()})
72 | wandb.log({"sanity/"+k+"_max":v.max() for k,v in wandb_out.items()})
73 | wandb_out = {prefix+k:wandb.Image(v.permute(1, 2, 0).detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()}
74 |
75 | wandb.log(wandb_out)
76 |
77 | #def dyn_wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
78 | def wandb_summary(loss, model_output, model_input, ground_truth, resolution,prefix=""):
79 |
80 | resolution = list(model_input["ctxt_rgb"].flatten(0,1).permute(0,2,3,1).shape)
81 | nrow=model_input["trgt_rgb"].size(1)
82 | imsl=model_input["ctxt_rgb"].shape[-2:]
83 | inv = lambda x : 1/(x+1e-8)
84 |
85 | depth = make_grid(model_output["depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
86 |
87 | wandb_out = {
88 | "est/rgb_pred": make_grid(model_output["rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow),
89 | "ref/rgb_gt": make_grid(ground_truth["trgt_rgb"].cpu().view(*resolution).detach().permute(0, -1, 1, 2),nrow=nrow),
90 | "ref/ctxt_img": make_grid(model_input["ctxt_rgb"][:,0].cpu().detach(),nrow=1)*.5+.5,
91 | "est/depth": depth,
92 | "est/depth_1ch":make_grid(model_output["depth_raw"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl).cpu(),normalize=True,nrow=nrow),
93 | }
94 |
95 | depthgt = (ground_truth["trgt_depth"] if "trgt_depth" in ground_truth else model_output["trgt_depth_inp"] if "trgt_depth_inp" in model_output
96 | else model_input["trgt_depth"] if "trgt_depth" in model_input else None)
97 | if depthgt is not None:
98 | depthgt = make_grid(inv(depthgt).cpu().view(*resolution[:3]).detach().unsqueeze(1),normalize=True,nrow=nrow)
99 | wandb_out["ref/depthgt"]= depthgt
100 |
101 | if "fine_rgb" in model_output:
102 | wandb_out["est/fine_rgb_pred"] = make_grid(model_output["fine_rgb"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
103 | wandb_out["est/fine_depth_pred"] = make_grid(model_output["fine_depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow,normalize=True)
104 |
105 | if "ctxt_rgb" in model_output:
106 | wandb_out["est/ctxt_depth"] =make_grid(model_output["ctxt_depth"].cpu().flatten(0,1).permute(0,2,1).unflatten(-1,imsl).detach(),nrow=nrow)
107 | wandb_out["est/ctxt_rgb_pred"] = ctxt_rgb_pred = make_grid(model_output["ctxt_rgb"].cpu().view(*resolution).detach().permute(0, -1, 1, 2),nrow=nrow)
108 |
109 | if "corr_weights" in model_output:
110 | #corr_weights = make_grid(model_output["corr_weights"].flatten(0,1)[:,:1].cpu().detach(),normalize=False,nrow=nrow)
111 | corr_weights = make_grid(ch_fst(model_output["corr_weights"],resolution[1]).flatten(0,1)[:,:1].cpu().detach(),normalize=False,nrow=nrow)
112 | wandb_out["est/corr_weights"] = corr_weights
113 |
114 | if "flow" in model_output and not torch.isnan(model_output["flow"]).any() and not torch.isnan(model_output["flow_from_pose"]).any():
115 | psnr = piqa.PSNR()(ch_fst(model_output["rgb"],imsl[0]).flatten(0,1).contiguous(),ch_fst(ground_truth["trgt_rgb"],imsl[0]).flatten(0,1).contiguous())
116 | wandb.log({prefix+"metrics/psnr": psnr})
117 |
118 | gt_flow_bwd = flow_vis_torch.flow_to_color(make_grid(model_output["flow_inp"].flatten(0,1),nrow=nrow))/255
119 | est_flow = flow_vis_torch.flow_to_color(make_grid(model_output["flow"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255
120 | wandb_out["est/flow_est"]= est_flow
121 | wandb_out["ref/flow_gt_bwd"]=gt_flow_bwd
122 | if "flow_from_pose" in model_output:
123 | wandb_out["est/flow_est_pose"] = flow_vis_torch.flow_to_color(make_grid(model_output["flow_from_pose"].flatten(0,1).permute(0,2,1).unflatten(-1,imsl),nrow=nrow))/255
124 | elif "flow" in model_output:
125 | print("skipping nan flow")
126 | for k,v in wandb_out.items(): print(k,v.max(),v.min())
127 | #for k,v in wandb_out.items():plt.imsave("/nobackup/users/camsmith/img/%s.png"%k,v.permute(1,2,0).detach().cpu().numpy().clip(0,1));
128 | #zz
129 | #wandb.log({"sanity/"+k+"_min":v.min() for k,v in wandb_out.items()})
130 | #wandb.log({"sanity/"+k+"_max":v.max() for k,v in wandb_out.items()})
131 | #for k,v in wandb_out.items(): print(v.shape)
132 | wandb_out = {prefix+k:wandb.Image(v.permute(1, 2, 0).float().detach().clip(0,1).cpu().numpy()) for k,v in wandb_out.items()}
133 |
134 | wandb.log(wandb_out)
135 |
--------------------------------------------------------------------------------