├── CONTRIBUTING.md
├── images
│   └── tensorboard.png
├── ibrnet
│   ├── data_loaders
│   │   ├── __init__.py
│   │   ├── flow_utils.py
│   │   ├── create_training_dataset.py
│   │   ├── data_utils.py
│   │   ├── llff_data_utils.py
│   │   └── monocular.py
│   ├── criterion.py
│   ├── projection.py
│   ├── feature_network.py
│   ├── sample_ray.py
│   ├── render_image.py
│   ├── model.py
│   └── mlp_network.py
├── configs_nvidia
│   ├── eval_truck_long.txt
│   ├── eval_jumping_long.txt
│   ├── eval_skating_long.txt
│   ├── eval_balloon1_long.txt
│   ├── eval_balloon2_long.txt
│   ├── eval_umbrella_long.txt
│   ├── eval_dynamicFace_long.txt
│   └── eval_playground_long.txt
├── environment_dynibar.yml
├── configs
│   ├── test_kid-running.txt
│   └── train_kid-running.txt
├── utils.py
├── save_monocular_cameras.py
├── README.md
├── config.py
├── render_source_vv.py
├── LICENSE
├── render_monocular_bt.py
└── eval_nvidia.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynibar/HEAD/images/tensorboard.png
--------------------------------------------------------------------------------
/ibrnet/data_loaders/__init__.py:
--------------------------------------------------------------------------------
1 | """Defining a dictionary of dataset class."""
2 |
3 | from .monocular import MonocularDataset
4 |
5 | dataset_dict = {
6 | 'monocular': MonocularDataset,
7 | }
8 |
--------------------------------------------------------------------------------
/configs_nvidia/eval_truck_long.txt:
--------------------------------------------------------------------------------
1 | expname = truck
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/truck
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Truck
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_jumping_long.txt:
--------------------------------------------------------------------------------
1 | expname = jumping
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/jumping
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Jumping
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_skating_long.txt:
--------------------------------------------------------------------------------
1 | expname = skating
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/skating
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Skating
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_balloon1_long.txt:
--------------------------------------------------------------------------------
1 | expname = balloon1
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/balloon1
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Balloon1
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_balloon2_long.txt:
--------------------------------------------------------------------------------
1 | expname = balloon2
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/balloon2
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Balloon2
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_umbrella_long.txt:
--------------------------------------------------------------------------------
1 | expname = umbrella
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/umbrella
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Umbrella
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_dynamicFace_long.txt:
--------------------------------------------------------------------------------
1 | expname = dynamicFace
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/dynamicFace
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = dynamicFace
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/configs_nvidia/eval_playground_long.txt:
--------------------------------------------------------------------------------
1 | expname = playground
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/nvidia_long_release
6 |
7 | coarse_dir = checkpoints/coarse/playground
8 |
9 | distributed = False
10 |
11 | ## dataset
12 | eval_dataset = Nvidia
13 | eval_scenes = Playground
14 | ### TESTING
15 | chunk_size = 8192
16 |
17 | ### RENDERING
18 | N_importance = 64
19 | N_samples = 64
20 | inv_uniform = True
21 | anti_alias_pooling = 1
22 | mask_rgb = 0
23 |
24 | input_dir = True
25 | input_xyz = False
26 |
27 | mask_static = True
--------------------------------------------------------------------------------
/environment_dynibar.yml:
--------------------------------------------------------------------------------
1 | name: dynibar
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.8
8 | - pip=20.3
9 | - conda-forge::cudatoolkit=11.3
10 | - pytorch::pytorch=1.10.1
11 | - pytorch::torchvision==0.11.2
12 | - pip:
13 | - configargparse
14 | - scikit-image==0.19.3
15 | - matplotlib
16 | - opencv-python
17 | - torch_efficient_distloss
18 | - imageio==2.22.0
19 | - tensorboard==2.10.0
20 | - scipy==1.9.1
21 | - timm==0.6.7
22 | - kornia==0.6.7
23 | - ninja==1.11.1
24 | - setuptools==59.5.0
25 |
--------------------------------------------------------------------------------
/configs/test_kid-running.txt:
--------------------------------------------------------------------------------
1 | # make sure expname is the saved folder name in 'out' directory
2 | expname = kid-running-test_mr-42_w-disp-0.100_w-flow-0.010_anneal_cycle-0.1-0.1-w_mode-0
3 |
4 | rootdir = /home/zhengqili/dynibar
5 |
6 | folder_path = /home/zhengqili/release
7 |
8 | distributed = False
9 |
10 | ## dataset
11 | eval_dataset = dynamic-test
12 | eval_scenes = kid-running
13 | ### TESTING
14 | chunk_size = 8192
15 |
16 | ### RENDERING
17 | N_importance = 64
18 | N_samples = 64
19 | inv_uniform = True
20 | white_bkgd = False
21 |
22 | anti_alias_pooling = 0
23 | mask_rgb = 1
24 | input_dir = True
25 | input_xyz = False
26 |
27 | training_height = 288
28 |
29 | max_range = 40
30 | num_source_views = 7
31 |
32 | render_idx = 30
33 |
34 | mask_src_view = True
35 | num_vv = 3
36 |
--------------------------------------------------------------------------------
/configs/train_kid-running.txt:
--------------------------------------------------------------------------------
1 | expname = kid-running-test
2 |
3 | rootdir = /home/zhengqili/dynibar
4 |
5 | folder_path = /home/zhengqili/release
6 |
7 | no_reload = False
8 | render_stride = 1
9 | distributed = False
10 | no_load_opt = True
11 | no_load_scheduler = True
12 | n_iters = 400000
13 |
14 | ## dataset
15 | train_dataset = monocular
16 | train_scenes = kid-running
17 | eval_dataset = monocular
18 | eval_scenes = kid-running
19 |
20 | ### TRAINING
21 | N_rand = 3072
22 | lrate_feature = 8e-4
23 | lrate_mlp = 4e-4
24 | lrate_decay_factor = 0.5
25 | init_decay_epoch = 400 # modify this s.t. num_imgs * num_epoch ~= 30-40K
26 |
27 | ### TESTING
28 | chunk_size = 8192
29 |
30 | ### RENDERING
31 | N_importance = 0
32 | N_samples = 64
33 | inv_uniform = True
34 | white_bkgd = False
35 |
36 | ### CONSOLE AND TENSORBOARD
37 | i_img = 5000
38 | i_print = 5000
39 | i_weights = 10000
40 |
41 | anti_alias_pooling = 0
42 | mask_rgb = 1
43 | input_dir = True
44 | input_xyz = False
45 |
46 | training_height = 288
47 |
48 | w_cycle = 0.1
49 | cycle_factor = 0.1
50 |
51 | w_disp = 1e-1
52 | w_flow = 1e-2
53 | w_distortion = 1e-3
54 | w_reg = 0.05
55 |
56 | w_skew_entropy = 5e-4
57 | lr_multipler = 1.0
58 |
59 | decay_rate = 10
60 | anneal_cycle = True
61 |
62 | erosion_radius = 3
63 | occ_weights_mode = 0
64 |
65 | max_range = 42
66 | num_source_views = 7
67 |
68 | num_vv = 3
69 | mask_src_view = True
70 |
--------------------------------------------------------------------------------
/ibrnet/criterion.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | import torch
17 | from utils import img2charbonier
18 |
19 | EPSILON = 0.001
20 |
21 | class Criterion(nn.Module):
22 | def __init__(self):
23 | super().__init__()
24 |
25 | def forward(self, outputs, ray_batch, motion_mask=None):
26 | '''
27 | training criterion
28 | '''
29 | pred_rgb = outputs['rgb']
30 | pred_mask = outputs['mask'].float()
31 | gt_rgb = ray_batch['rgb']
32 |
33 | if motion_mask is not None:
34 | pred_mask = pred_mask * motion_mask.float()
35 |
36 | loss = img2charbonier(pred_rgb, gt_rgb, pred_mask, EPSILON)
37 |
38 | return loss
39 |
40 |
41 |
42 | def compute_temporal_rgb_loss(outputs, ray_batch, motion_mask=None):
43 | pred_rgb = outputs['rgb']
44 | gt_rgb = ray_batch['rgb']
45 |
46 | occ_weight_map = outputs['occ_weight_map']
47 | pred_mask = outputs['mask'].float()
48 |
49 | if motion_mask is not None:
50 | pred_mask = pred_mask * motion_mask
51 |
52 | final_w = pred_mask * occ_weight_map
53 | final_w = final_w.unsqueeze(-1).repeat(1, 3)
54 |
55 | loss = torch.sum(final_w * torch.sqrt((pred_rgb - gt_rgb)**2 + EPSILON**2) ) / (torch.sum(final_w) + 1e-8)
56 | return loss
57 |
58 | def compute_rgb_loss(pred_rgb, ray_batch, pred_mask):
59 | gt_rgb = ray_batch['rgb']
60 | loss = img2charbonier(pred_rgb, gt_rgb, pred_mask, EPSILON)
61 |
62 | return loss
63 |
64 | # def compute_mask_ssi_depth_loss(pred_depth, gt_depth, mask):
65 | # t_pred = torch.median(pred_depth)
66 | # s_pred = torch.mean(torch.abs(pred_depth - t_pred))
67 |
68 | # t_gt = torch.median(gt_depth)
69 | # s_gt = torch.mean(torch.abs(gt_depth - t_gt))
70 |
71 | # pred_depth_n = (pred_depth - t_pred) / s_pred
72 | # gt_depth_n = (gt_depth - t_gt) / s_gt
73 |
74 | # num_pixel = torch.sum(mask) + 1e-8
75 |
76 | # return torch.sum(torch.abs(pred_depth_n - gt_depth_n) * mask)/num_pixel
77 |
78 |
79 | def compute_entropy(x):
80 | return -torch.mean(x * torch.log(x + 1e-8))
81 |
82 |
83 | def compute_flow_loss(render_flow, gt_flow, gt_mask):
84 | gt_mask_rep = gt_mask.repeat(1, 1, 2)
85 | return torch.sum(torch.abs(render_flow - gt_flow) * gt_mask_rep) / (torch.sum(gt_mask_rep) + 1e-8)
86 |
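# --- Usage sketch (added for illustration; not part of the original file) ---
# Random tensors stand in for rendered outputs and a ray batch so the losses
# above can be exercised in isolation; per-ray RGB is [N_rays, 3] and the mask /
# occlusion weights are [N_rays], matching the indexing used above.
if __name__ == '__main__':
  n_rays = 1024
  outputs = {
      'rgb': torch.rand(n_rays, 3),
      'mask': torch.ones(n_rays),
      'occ_weight_map': torch.ones(n_rays),
  }
  ray_batch = {'rgb': torch.rand(n_rays, 3)}

  criterion = Criterion()
  print('charbonnier rgb loss:', criterion(outputs, ray_batch).item())
  print('temporal rgb loss:',
        compute_temporal_rgb_loss(outputs, ray_batch).item())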
--------------------------------------------------------------------------------
/ibrnet/data_loaders/flow_utils.py:
--------------------------------------------------------------------------------
1 | """Optical flow helper functions."""
2 |
3 | import numpy as np
4 | import cv2
5 |
6 | def warp_flow(img, flow):
7 | h, w = flow.shape[:2]
8 | flow_new = flow.copy()
9 | flow_new[:,:,0] += np.arange(w)
10 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
11 |
12 | res = cv2.remap(img, flow_new, None,
13 | cv2.INTER_LINEAR,
14 | borderMode=cv2.BORDER_CONSTANT)
15 | return res
16 |
17 | def make_color_wheel():
18 | """
19 | Generate a color wheel according to the Middlebury color code
20 | :return: Color wheel
21 | """
22 | RY = 15
23 | YG = 6
24 | GC = 4
25 | CB = 11
26 | BM = 13
27 | MR = 6
28 |
29 | ncols = RY + YG + GC + CB + BM + MR
30 |
31 | colorwheel = np.zeros([ncols, 3])
32 |
33 | col = 0
34 |
35 | # RY
36 | colorwheel[0:RY, 0] = 255
37 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
38 | col += RY
39 |
40 | # YG
41 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
42 | colorwheel[col:col+YG, 1] = 255
43 | col += YG
44 |
45 | # GC
46 | colorwheel[col:col+GC, 1] = 255
47 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
48 | col += GC
49 |
50 | # CB
51 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
52 | colorwheel[col:col+CB, 2] = 255
53 | col += CB
54 |
55 | # BM
56 | colorwheel[col:col+BM, 2] = 255
57 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
58 | col += BM
59 |
60 | # MR
61 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
62 | colorwheel[col:col+MR, 0] = 255
63 |
64 | return colorwheel
65 |
66 |
67 | def compute_color(u, v):
68 | """
69 | compute optical flow color map
70 | :param u: optical flow horizontal map
71 | :param v: optical flow vertical map
72 | :return: optical flow in color code
73 | """
74 | [h, w] = u.shape
75 | img = np.zeros([h, w, 3])
76 | nanIdx = np.isnan(u) | np.isnan(v)
77 | u[nanIdx] = 0
78 | v[nanIdx] = 0
79 |
80 | colorwheel = make_color_wheel()
81 | ncols = np.size(colorwheel, 0)
82 |
83 | rad = np.sqrt(u**2+v**2)
84 |
85 | a = np.arctan2(-v, -u) / np.pi
86 |
87 | fk = (a+1) / 2 * (ncols - 1) + 1
88 |
89 | k0 = np.floor(fk).astype(int)
90 |
91 | k1 = k0 + 1
92 | k1[k1 == ncols+1] = 1
93 | f = fk - k0
94 |
95 | for i in range(0, np.size(colorwheel,1)):
96 | tmp = colorwheel[:, i]
97 | col0 = tmp[k0-1] / 255
98 | col1 = tmp[k1-1] / 255
99 | col = (1-f) * col0 + f * col1
100 |
101 | idx = rad <= 1
102 | col[idx] = 1-rad[idx]*(1-col[idx])
103 | notidx = np.logical_not(idx)
104 |
105 | col[notidx] *= 0.75
106 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
107 |
108 | return img
109 |
110 |
111 |
112 | def flow_to_image(flow, display=False):
113 | """
114 | Convert flow into Middlebury color code image
115 | :param flow: optical flow map
116 | :return: optical flow image in middlebury color
117 | """
118 | UNKNOWN_FLOW_THRESH = 200
119 | u = flow[:, :, 0]
120 | v = flow[:, :, 1]
121 |
122 | maxu = -999.
123 | maxv = -999.
124 | minu = 999.
125 | minv = 999.
126 |
127 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
128 | u[idxUnknow] = 0
129 | v[idxUnknow] = 0
130 |
131 | maxu = max(maxu, np.max(u))
132 | minu = min(minu, np.min(u))
133 |
134 | maxv = max(maxv, np.max(v))
135 | minv = min(minv, np.min(v))
136 |
137 | # sqrt_rad = u**2 + v**2
138 | rad = np.sqrt(u**2 + v**2)
139 |
140 | maxrad = max(-1, np.max(rad))
141 |
142 | if display:
143 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv))
144 |
145 | u = u/(maxrad + np.finfo(float).eps)
146 | v = v/(maxrad + np.finfo(float).eps)
147 |
148 | img = compute_color(u, v)
149 |
150 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
151 | img[idx] = 0
152 |
153 | return np.uint8(img)
--------------------------------------------------------------------------------
/ibrnet/data_loaders/create_training_dataset.py:
--------------------------------------------------------------------------------
1 | """Class definition of data sampler."""
2 |
3 | from operator import itemgetter
4 | from typing import Optional
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torch.utils.data import DistributedSampler
10 | from torch.utils.data import Sampler
11 | from torch.utils.data import WeightedRandomSampler
12 |
13 | from . import dataset_dict
14 |
15 |
16 | class DatasetFromSampler(Dataset):
17 | """Dataset to create indexes from `Sampler`."""
18 |
19 | def __init__(self, sampler: Sampler):
20 | """Initialisation for DatasetFromSampler."""
21 | self.sampler = sampler
22 | self.sampler_list = None
23 |
24 | def __getitem__(self, index: int):
25 | """Gets element of the dataset.
26 |
27 | Args:
28 | index: index of the element in the dataset
29 |
30 | Returns:
31 | Single element by index
32 | """
33 | if self.sampler_list is None:
34 | self.sampler_list = list(self.sampler)
35 | return self.sampler_list[index]
36 |
37 | def __len__(self) -> int:
38 | return len(self.sampler)
39 |
40 |
41 | class DistributedSamplerWrapper(DistributedSampler):
42 | """Wrapper over `Sampler` for distributed training.
43 |
44 | Allows you to use any sampler in distributed mode. It is especially useful in
45 | conjunction with `torch.nn.parallel.DistributedDataParallel`. In such case,
46 | each process can pass a DistributedSamplerWrapper instance as a DataLoader
47 | sampler, and load a subset of subsampled data of the original dataset that is
48 | exclusive to it.
49 |
50 | .. note:: Sampler is assumed to be of constant size.
51 | """
52 |
53 | def __init__(
54 | self,
55 | sampler,
56 | num_replicas: Optional[int] = None,
57 | rank: Optional[int] = None,
58 | shuffle: bool = True,
59 | ):
60 | super(DistributedSamplerWrapper, self).__init__(
61 | DatasetFromSampler(sampler),
62 | num_replicas=num_replicas,
63 | rank=rank,
64 | shuffle=shuffle,
65 | )
66 | self.sampler = sampler
67 |
68 | def __iter__(self):
69 | self.dataset = DatasetFromSampler(self.sampler)
70 | indexes_of_indexes = super().__iter__()
71 | subsampler_indexes = self.dataset
72 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
73 |
74 |
75 | def create_training_dataset(args):
76 | """Creating training dataset.
77 |
78 | Args:
79 | args: input argument
80 |
81 | Returns:
82 | train_dataset: training dataset
83 | train_sampler: training sampler
84 | """
85 | # parse args.train_dataset, "+" indicates that multiple datasets are used,
86 | # for example "ibrnet_collect+llff+spaces"
87 | # otherwise only one dataset is used
88 |
89 | print('training dataset: {}'.format(args.train_dataset))
90 |
91 | mode = 'train'
92 | if '+' not in args.train_dataset:
93 | train_dataset = dataset_dict[args.train_dataset](
94 | args, mode, scenes=args.train_scenes
95 | )
96 | train_sampler = (
97 | torch.utils.data.distributed.DistributedSampler(train_dataset)
98 | if args.distributed
99 | else None
100 | )
101 | else:
102 | train_dataset_names = args.train_dataset.split('+')
103 | weights = args.dataset_weights
104 | assert len(train_dataset_names) == len(weights)
105 | assert np.abs(np.sum(weights) - 1.0) < 1e-6
106 | print('weights:{}'.format(weights))
107 | train_datasets = []
108 | train_weights_samples = []
109 | for training_dataset_name, weight in zip(train_dataset_names, weights):
110 | train_dataset = dataset_dict[training_dataset_name](
111 | args,
112 | mode,
113 | scenes=args.train_scenes,
114 | )
115 | train_datasets.append(train_dataset)
116 | num_samples = len(train_dataset)
117 | weight_each_sample = weight / num_samples
118 | train_weights_samples.extend([weight_each_sample] * num_samples)
119 |
120 | train_dataset = torch.utils.data.ConcatDataset(train_datasets)
121 | train_weights = torch.from_numpy(np.array(train_weights_samples))
122 | sampler = WeightedRandomSampler(train_weights, len(train_weights))
123 | train_sampler = (
124 | DistributedSamplerWrapper(sampler) if args.distributed else sampler
125 | )
126 |
127 | return train_dataset, train_sampler
128 |
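# --- Usage sketch (added for illustration; not part of the original file) ---
# Typical wiring of the dataset/sampler returned above into a DataLoader;
# `args` is the parsed config from config.py, e.g. run with
# --config configs/train_kid-running.txt from the repo root.
if __name__ == '__main__':
  from torch.utils.data import DataLoader
  from config import config_parser

  args = config_parser().parse_args()
  train_dataset, train_sampler = create_training_dataset(args)
  train_loader = DataLoader(
      train_dataset,
      batch_size=1,
      sampler=train_sampler,
      shuffle=(train_sampler is None),  # shuffle only when no custom sampler is used
      num_workers=4,
  )
  print('number of training samples:', len(train_dataset))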
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions."""
2 |
3 | import cv2
4 | import matplotlib as mpl
5 | from matplotlib import cm
6 | from matplotlib.backends.backend_agg import FigureCanvasAgg
7 | from matplotlib.figure import Figure
8 | import numpy as np
9 | import torch
10 |
11 | HUGE_NUMBER = 1e10
12 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
13 |
14 | img_HWC2CHW = lambda x: x.permute(2, 0, 1)
15 | gray2rgb = lambda x: x.unsqueeze(2).repeat(1, 1, 3)
16 |
17 |
18 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
19 | mse2psnr = lambda x: -10.0 * np.log(x + TINY_NUMBER) / np.log(10.0)
20 |
21 |
22 | def img2mse(x, y, mask=None):
23 | """MSE between two images."""
24 | if mask is None:
25 | return torch.mean((x - y) * (x - y))
26 | else:
27 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / (
28 | torch.sum(mask) * x.shape[-1] + TINY_NUMBER
29 | )
30 |
31 |
32 | def img2charbonier(x, y, mask=None, eps=0.001):
33 | """Charbonier loss between two images."""
34 | if mask is None:
35 | return torch.mean(torch.sqrt((x - y) ** 2 + eps**2))
36 | else:
37 | return torch.sum(
38 | torch.sqrt((x - y) ** 2 + eps**2) * mask.unsqueeze(-1)
39 | ) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER)
40 |
41 |
42 | def img2psnr(x, y, mask=None):
43 | return mse2psnr(img2mse(x, y, mask).item())
44 |
45 |
46 | def cycle(iterable):
47 | while True:
48 | for x in iterable:
49 | yield x
50 |
51 |
52 | def get_vertical_colorbar(
53 | h, vmin, vmax, cmap_name='jet', label=None, cbar_precision=2
54 | ):
55 | """Get colorbar."""
56 | fig = Figure(figsize=(2, 8), dpi=100)
57 | fig.subplots_adjust(right=1.5)
58 | canvas = FigureCanvasAgg(fig)
59 |
60 | # Do some plotting.
61 | ax = fig.add_subplot(111)
62 | cmap = cm.get_cmap(cmap_name)
63 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
64 |
65 | tick_cnt = 6
66 | tick_loc = np.linspace(vmin, vmax, tick_cnt)
67 | cb1 = mpl.colorbar.ColorbarBase(
68 | ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation='vertical'
69 | )
70 |
71 | tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
72 | if cbar_precision == 0:
73 | tick_label = [x[:-2] for x in tick_label]
74 |
75 | cb1.set_ticklabels(tick_label)
76 |
77 | cb1.ax.tick_params(labelsize=18, rotation=0)
78 |
79 | if label is not None:
80 | cb1.set_label(label)
81 |
82 | fig.tight_layout()
83 |
84 | canvas.draw()
85 | s, (width, height) = canvas.print_to_buffer()
86 |
87 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
88 |
89 | im = im[:, :, :3].astype(np.float32) / 255.0
90 | if h != im.shape[0]:
91 | w = int(im.shape[1] / im.shape[0] * h)
92 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
93 |
94 | return im
95 |
96 |
97 | def colorize_np(
98 | x,
99 | cmap_name='jet',
100 | mask=None,
101 | range=None,
102 | append_cbar=False,
103 | cbar_in_image=False,
104 | cbar_precision=2,
105 | ):
106 | """turn a grayscale image into a color image."""
107 | if range is not None:
108 | vmin, vmax = range
109 | elif mask is not None:
110 | # vmin, vmax = np.percentile(x[mask], (2, 100))
111 | vmin = np.min(x[mask][np.nonzero(x[mask])])
112 | vmax = np.max(x[mask])
113 | # vmin = vmin - np.abs(vmin) * 0.01
114 | x[np.logical_not(mask)] = vmin
115 | # print(vmin, vmax)
116 | else:
117 | vmin, vmax = np.percentile(x, (1, 99))
118 | vmax += TINY_NUMBER
119 |
120 | x = np.clip(x, vmin, vmax)
121 | x = (x - vmin) / (vmax - vmin)
122 | x = np.clip(x, 0.0, 1.0)
123 |
124 | cmap = cm.get_cmap(cmap_name)
125 | x_new = cmap(x)[:, :, :3]
126 |
127 | if mask is not None:
128 | mask = np.float32(mask[:, :, np.newaxis])
129 | x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
130 |
131 | cbar = get_vertical_colorbar(
132 | h=x.shape[0],
133 | vmin=vmin,
134 | vmax=vmax,
135 | cmap_name=cmap_name,
136 | cbar_precision=cbar_precision,
137 | )
138 |
139 | if append_cbar:
140 | if cbar_in_image:
141 | x_new[:, -cbar.shape[1] :, :] = cbar
142 | else:
143 | x_new = np.concatenate(
144 | (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
145 | )
146 | return x_new
147 | else:
148 | return x_new
149 |
150 |
151 | # tensor
152 | def colorize(
153 | x,
154 | cmap_name='jet',
155 | mask=None,
156 | range=None,
157 | append_cbar=False,
158 | cbar_in_image=False,
159 | ):
160 | """Convert gray scale image such as depth to RGB image."""
161 | device = x.device
162 | x = x.cpu().numpy()
163 | if mask is not None:
164 | mask = mask.cpu().numpy() > 0.99
165 | kernel = np.ones((3, 3), np.uint8)
166 | mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
167 |
168 | x = colorize_np(x, cmap_name, mask, range, append_cbar, cbar_in_image)
169 | x = torch.from_numpy(x).to(device)
170 | return x
171 |
--------------------------------------------------------------------------------
/save_monocular_cameras.py:
--------------------------------------------------------------------------------
1 | """Save images, depth, flow and mask data into dynibar input format."""
2 |
3 | '''
4 |
15 |
16 |
17 | '''
18 |
19 |
20 | import argparse
21 | import glob
22 | import os
23 | import cv2
24 | import imageio
25 | import numpy as np
26 |
27 |
28 | SAVE_IMG = True
29 | FINAL_H = 288
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--cvd_dir', type=str, help='depth directory')
34 | parser.add_argument('--data_dir', type=str, help='dataset directory')
35 | # parser.add_argument("--scene_name", type=str,
36 | # help='Scene name') # 'kid-running'
37 | args = parser.parse_args()
38 |
39 | pt_out_list = sorted(glob.glob(os.path.join(args.cvd_dir, '*.npz')))
40 | data_dir = os.path.join(args.data_dir, 'dense')
41 |
42 | try:
43 | original_img_path = os.path.join(data_dir, 'images', '00000.png')
44 | o_img = imageio.imread(original_img_path)
45 | except:
46 | original_img_path = os.path.join(data_dir, 'images', '00000.jpg')
47 | o_img = imageio.imread(original_img_path)
48 |
49 | o_ar = float(o_img.shape[1]) / float(o_img.shape[0])
50 |
51 | final_w, final_h = int(round(FINAL_H * o_ar)), int(FINAL_H)
52 |
53 | img_dir = os.path.join(data_dir, 'images_%dx%d' % (final_w, final_h))
54 | os.makedirs(img_dir, exist_ok=True)
55 | print('img_dir ', img_dir)
56 | disp_dir = os.path.join(data_dir, 'disp')
57 | os.makedirs(disp_dir, exist_ok=True)
58 |
59 | Ks = []
60 | mono_depths = []
61 | c2w_mats = []
62 | imgs = []
63 | bounds_mats = []
64 |
65 | for i, pt_out_path in enumerate(pt_out_list):
66 | print(i)
67 | out_name = pt_out_path.split('/')[-1]
68 | pt_data = np.load(pt_out_path)
69 |
70 | img = pt_data['img_1'][0].transpose(1, 2, 0)
71 | pred_depth = pt_data['depth'][0, 0, ...]
72 | pred_disp = 1.0 / pred_depth
73 | K = pt_data['K'][0, 0, 0, ...].transpose()
74 | img = pt_data['img_1'][0].transpose(1, 2, 0)
75 | cam_c2w = pt_data['cam_c2w'][0]
76 |
77 | K[0, :] *= final_w / img.shape[1]
78 | K[1, :] *= final_h / img.shape[0]
79 |
80 | print('K ', K, abs(K[0, 0] - K[1, 1]) / (K[1, 1] + K[0, 0]))
81 | assert (
82 | abs(K[0, 0] - K[1, 1]) / (K[1, 1] + K[0, 0]) < 0.005
83 | ) # we assume fx ~= fy
84 |
85 | original_img_path = os.path.join(
86 | data_dir, 'images', '%05d.png' % int(out_name[5:9])
87 | )
88 | o_img = imageio.imread(original_img_path)
89 | print(o_img.shape, final_w, final_h)
90 | img_resized = cv2.resize(
91 | o_img, (final_w, final_h), interpolation=cv2.INTER_AREA
92 | )
93 | pred_disp_resized = cv2.resize(
94 | pred_disp, (final_w, final_h), interpolation=cv2.INTER_LINEAR
95 | )
96 |
97 | if SAVE_IMG:
98 | imageio.imwrite(os.path.join(img_dir, '%05d.png' % i), img_resized)
99 | np.save(
100 | os.path.join(disp_dir, '%05d.npy' % i),
101 | pred_disp_resized.astype(np.float32),
102 | )
103 |
104 | mono_depths.append(pred_depth)
105 | c2w_mats.append(cam_c2w)
106 | imgs.append(img_resized)
107 |
108 | close_depth, inf_depth = np.percentile(pred_depth, 5), np.percentile(
109 | pred_depth, 95
110 | )
111 | # print(close_depth, inf_depth)
112 | bounds = np.array([close_depth, inf_depth])
113 | bounds_mats.append(bounds)
114 |
115 | c2w_mats = np.stack(c2w_mats, 0)
116 | bounds_mats = np.stack(bounds_mats, 0)
117 |
118 | h, w, fx, fy = imgs[0].shape[0], imgs[0].shape[1], K[0, 0], K[1, 1]
119 |
120 | print('h, w ', h, w, fx, fy)
121 | print('bounds_mats ', np.min(bounds_mats), np.max(bounds_mats))
122 |
123 | ff = (fx + fy) / 2.0
124 | # hwf = np.array([h, w, fx, fy]).reshape([1, 4])
125 | hwf = np.array([h, w, ff]).reshape([3, 1])
126 |
127 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0])
128 |
129 | poses = np.concatenate(
130 | [poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1
131 | )
132 |
133 | # must switch to [-y, x, z] from [x, -y, -z], NOT [r, u, -t]
134 | poses = np.concatenate(
135 | [
136 | poses[:, 1:2, :],
137 | poses[:, 0:1, :],
138 | -poses[:, 2:3, :],
139 | poses[:, 3:4, :],
140 | poses[:, 4:5, :],
141 | ],
142 | 1,
143 | )
144 |
145 | save_arr = []
146 | for i in range((poses.shape[2])):
147 | save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0))
148 |
149 | np.save(os.path.join(data_dir, 'poses_bounds_cvd.npy'), save_arr)
150 |
--------------------------------------------------------------------------------
/ibrnet/data_loaders/data_utils.py:
--------------------------------------------------------------------------------
1 | """utility function definition for data loader."""
2 |
3 | import math
4 | import numpy as np
5 |
6 | rng = np.random.RandomState(234)
7 | _EPS = np.finfo(float).eps * 4.0
8 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
9 |
10 |
11 | def vector_norm(data, axis=None, out=None):
12 | """Return length, i.e. eucledian norm, of ndarray along axis."""
13 | data = np.array(data, dtype=np.float64, copy=True)
14 | if out is None:
15 | if data.ndim == 1:
16 | return math.sqrt(np.dot(data, data))
17 | data *= data
18 | out = np.atleast_1d(np.sum(data, axis=axis))
19 | np.sqrt(out, out)
20 | return out
21 | else:
22 | data *= data
23 | np.sum(data, axis=axis, out=out)
24 | np.sqrt(out, out)
25 |
26 |
27 | def quaternion_about_axis(angle, axis):
28 | """Return quaternion for rotation about axis."""
29 | quaternion = np.zeros((4,), dtype=np.float64)
30 | quaternion[:3] = axis[:3]
31 | qlen = vector_norm(quaternion)
32 | if qlen > _EPS:
33 | quaternion *= math.sin(angle / 2.0) / qlen
34 | quaternion[3] = math.cos(angle / 2.0)
35 | return quaternion
36 |
37 |
38 | def quaternion_matrix(quaternion):
39 | """Return homogeneous rotation matrix from quaternion."""
40 | q = np.array(quaternion[:4], dtype=np.float64, copy=True)
41 | nq = np.dot(q, q)
42 | if nq < _EPS:
43 | return np.identity(4)
44 | q *= math.sqrt(2.0 / nq)
45 | q = np.outer(q, q)
46 | return np.array(
47 | (
48 | (1.0 - q[1, 1] - q[2, 2], q[0, 1] - q[2, 3], q[0, 2] + q[1, 3], 0.0),
49 | (q[0, 1] + q[2, 3], 1.0 - q[0, 0] - q[2, 2], q[1, 2] - q[0, 3], 0.0),
50 | (q[0, 2] - q[1, 3], q[1, 2] + q[0, 3], 1.0 - q[0, 0] - q[1, 1], 0.0),
51 | (0.0, 0.0, 0.0, 1.0),
52 | ),
53 | dtype=np.float64,
54 | )
55 |
56 |
57 | def angular_dist_between_2_vectors(vec1, vec2):
58 | vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER)
59 | vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER)
60 | angular_dists = np.arccos(
61 | np.clip(np.sum(vec1_unit * vec2_unit, axis=-1), -1.0, 1.0)
62 | )
63 | return angular_dists
64 |
65 |
66 | def batched_angular_dist_rot_matrix(r1, r2):
67 | """calculate the angular distance between two rotation matrices (batched)."""
68 |
69 | assert (
70 | r1.shape[-1] == 3
71 | and r2.shape[-1] == 3
72 | and r1.shape[-2] == 3
73 | and r2.shape[-2] == 3
74 | )
75 | return np.arccos(
76 | np.clip(
77 | (np.trace(np.matmul(r2.transpose(0, 2, 1), r1), axis1=1, axis2=2) - 1)
78 | / 2.0,
79 | a_min=-1 + TINY_NUMBER,
80 | a_max=1 - TINY_NUMBER,
81 | )
82 | )
83 |
84 |
85 | def get_nearest_pose_ids(
86 | tar_pose,
87 | ref_poses,
88 | tar_id=-1,
89 | angular_dist_method='vector',
90 | scene_center=(0, 0, 0),
91 | ):
92 | """Get poses id in nearest neighboorhood manner."""
93 | num_cams = len(ref_poses)
94 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)
95 |
96 | if angular_dist_method == 'matrix':
97 | dists = batched_angular_dist_rot_matrix(
98 | batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]
99 | )
100 | elif angular_dist_method == 'vector':
101 | tar_cam_locs = batched_tar_pose[:, :3, 3]
102 | ref_cam_locs = ref_poses[:, :3, 3]
103 | scene_center = np.array(scene_center)[None, ...]
104 | tar_vectors = tar_cam_locs - scene_center
105 | ref_vectors = ref_cam_locs - scene_center
106 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
107 | elif angular_dist_method == 'dist':
108 | tar_cam_locs = batched_tar_pose[:, :3, 3]
109 | ref_cam_locs = ref_poses[:, :3, 3]
110 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
111 | else:
112 | raise NotImplementedError
113 |
114 | if tar_id >= 0:
115 | assert tar_id < num_cams
116 | dists[tar_id] = 1e3
117 |
118 | sorted_ids = np.argsort(dists)
119 |
120 | return sorted_ids
121 |
122 |
123 | def get_interval_pose_ids(
124 | tar_pose,
125 | ref_poses,
126 | tar_id=-1,
127 | angular_dist_method='dist',
128 | interval=2,
129 | scene_center=(0, 0, 0)):
130 | """Get poses id in nearest neighboorhood manner from every 'interval' frames."""
131 |
132 | original_indices = np.array(range(0, len(ref_poses)))
133 |
134 | ref_poses = ref_poses[::interval]
135 | subsample_indices = original_indices[::interval]
136 |
137 | num_cams = len(ref_poses)
138 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)
139 |
140 | if angular_dist_method == 'matrix':
141 | dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3],
142 | ref_poses[:, :3, :3])
143 | elif angular_dist_method == 'vector':
144 | tar_cam_locs = batched_tar_pose[:, :3, 3]
145 | ref_cam_locs = ref_poses[:, :3, 3]
146 | scene_center = np.array(scene_center)[None, ...]
147 | tar_vectors = tar_cam_locs - scene_center
148 | ref_vectors = ref_cam_locs - scene_center
149 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
150 | elif angular_dist_method == 'dist':
151 | tar_cam_locs = batched_tar_pose[:, :3, 3]
152 | ref_cam_locs = ref_poses[:, :3, 3]
153 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
154 | else:
155 | raise NotImplementedError
156 |
157 | if tar_id >= 0:
158 | assert tar_id < num_cams
159 | dists[tar_id] = 1e3
160 |
161 | sorted_ids = np.argsort(dists)
162 |
163 | final_ids = subsample_indices[sorted_ids]
164 |
165 | return final_ids
166 |
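# --- Usage sketch (added for illustration; not part of the original file) ---
# Pick the closest source views for a target camera from a set of random
# camera-to-world poses, using the 'dist' (camera position) criterion above.
if __name__ == '__main__':
  num_cams, num_select = 20, 4
  ref_poses = np.tile(np.eye(4)[None], (num_cams, 1, 1))
  ref_poses[:, :3, 3] = rng.uniform(-1.0, 1.0, size=(num_cams, 3))
  tar_pose = ref_poses[10]

  # tar_id=10 excludes the target view itself from the candidates.
  ids = get_nearest_pose_ids(tar_pose, ref_poses, tar_id=10,
                             angular_dist_method='dist')
  print('nearest source view ids:', ids[:num_select])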
--------------------------------------------------------------------------------
/ibrnet/projection.py:
--------------------------------------------------------------------------------
1 | """Class definition for perspective projection."""
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | class Projector:
8 | """Class for performing perspective projection."""
9 |
10 | def __init__(self, device):
11 | self.device = device
12 |
13 | def inbound(self, pixel_locations, h, w):
14 | """Check if the pixel locations are in valid range."""
15 | return (
16 | (pixel_locations[..., 0] <= w - 1.0)
17 | & (pixel_locations[..., 0] >= 0)
18 | & (pixel_locations[..., 1] <= h - 1.0)
19 | & (pixel_locations[..., 1] >= 0)
20 | )
21 |
22 | def normalize(self, pixel_locations, h, w):
23 | """Normalize pixel locations for grid_sampler function."""
24 | resize_factor = torch.tensor([w - 1.0, h - 1.0]).to(self.device)[
25 | None, None, :
26 | ]
27 | normalized_pixel_locations = (
28 | 2 * pixel_locations / resize_factor - 1.0
29 | ) # [n_views, n_points, 2]
30 | return normalized_pixel_locations
31 |
32 | def compute_projections(self, xyz, train_cameras):
33 | """Project 3D points into views using training camera parameteres."""
34 | original_shape = xyz.shape[:-1]
35 | xyz = xyz.reshape(original_shape[0], -1, 3)
36 |
37 | num_views = len(train_cameras)
38 | train_intrinsics = train_cameras[:, 2:18].reshape(
39 | -1, 4, 4
40 | ) # [n_views, 4, 4]
41 | train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) # [n_views, 4, 4]
42 | xyz_h = torch.cat(
43 | [xyz, torch.ones_like(xyz[..., :1])], dim=-1
44 | ) # [n_points, 4]
45 |
46 | projections = train_intrinsics.bmm(torch.inverse(train_poses)).bmm(
47 | xyz_h.permute(0, 2, 1)
48 | ) # [n_views, 4, n_points]
49 |
50 | projections = projections.permute(0, 2, 1) # [n_views, n_points, 4]
51 | pixel_locations = projections[..., :2] / torch.clamp(
52 | projections[..., 2:3], min=1e-8
53 | ) # [n_views, n_points, 2]
54 | pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
55 |
56 | mask = projections[..., 2] > 0 # a point is invalid if behind the camera
57 | return pixel_locations.reshape(
58 | (num_views,) + original_shape[1:] + (2,)
59 | ), mask.reshape((num_views,) + original_shape[1:])
60 |
61 | def compute_angle(self, xyz_st, xyz, query_camera, train_cameras):
62 | """Compute difference of viewing angle between rays from source and ones from target view.
63 |
64 | Args:
65 |
66 | xyz_st: reference 3D point location without scene motion
67 | xyz: 3D positions displaced by scene motion at nearby times
68 | query_camera: target view camera parameters
69 | train_cameras: source view camera parameters
70 |
71 | Returns:
72 | Difference of viewing angle between rays from source and ones from target
73 | view.
74 | """
75 | original_shape = xyz.shape[:-1]
76 | xyz_st_ = xyz_st.reshape(xyz_st.shape[0], -1, 3)
77 | xyz_ = xyz.reshape(xyz.shape[0], -1, 3)
78 |
79 | train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) # [n_views, 4, 4]
80 | num_views = len(train_poses)
81 | query_pose = (
82 | query_camera[-16:].reshape(-1, 4, 4).repeat(num_views, 1, 1)
83 | ) # [n_views, 4, 4]
84 |
85 | ray2tar_pose = F.normalize(
86 | query_pose[:, :3, 3].unsqueeze(1) - xyz_st_, dim=-1
87 | )
88 | ray2train_pose = F.normalize(
89 | train_poses[:, :3, 3].unsqueeze(1) - xyz_, dim=-1
90 | )
91 | ray_diff = ray2tar_pose - ray2train_pose
92 |
93 | ray_diff_dot = torch.sum(
94 | ray2tar_pose * ray2train_pose, dim=-1, keepdim=True
95 | )
96 | ray_diff_direction = F.normalize(
97 | ray_diff, dim=-1
98 | ) # ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
99 |
100 | ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
101 | return ray_diff.reshape((num_views,) + original_shape[1:] + (4,))
102 |
103 | def compute_with_motions(
104 | self, xyz_st, xyz, query_camera, train_imgs, train_cameras, featmaps
105 | ):
106 | """Extract 2D feature by projecting 3D points displaced by scene motion.
107 |
108 | Args:
109 | xyz_st: reference point location without scene motion
110 | xyz: 3D point positions displaced by scene motion
111 | query_camera: target view camera parameters
112 | train_imgs: source view images
113 | train_cameras: source view camera parameters
114 | featmaps: source view 2D image feature maps.
115 |
116 | Returns:
117 | rgb_feat_sampled: extracted 2D feature
118 | ray_diff: viewing angle difference between target ray and source ray
119 | mask: valid masks
120 | """
121 |
122 | assert (
123 | (train_imgs.shape[0] == 1)
124 | and (train_cameras.shape[0] == 1)
125 | and (query_camera.shape[0] == 1)
126 | ), 'only support batch_size=1 for now'
127 |
128 | xyz_st = xyz_st[None, ...].expand(xyz.shape[0], -1, -1, -1)
129 |
130 | train_imgs = train_imgs.squeeze(0) # [n_views, h, w, 3]
131 | train_cameras = train_cameras.squeeze(0) # [n_views, 34]
132 | query_camera = query_camera.squeeze(0) # [34, ]
133 |
134 | train_imgs = train_imgs.permute(0, 3, 1, 2) # [n_views, 3, h, w]
135 |
136 | h, w = train_cameras[0][:2]
137 |
138 | # compute the projection of the query points to each reference image
139 | pixel_locations, mask_in_front = self.compute_projections(
140 | xyz, train_cameras
141 | )
142 |
143 | normalized_pixel_locations = self.normalize(
144 | pixel_locations, h, w
145 | ) # [n_views, n_rays, n_samples, 2]
146 |
147 | # rgb sampling
148 | rgbs_sampled = F.grid_sample(
149 | train_imgs, normalized_pixel_locations, align_corners=True
150 | )
151 | rgbs_sampled_ = rgbs_sampled.permute(
152 | 2, 3, 0, 1
153 | ) # [n_rays, n_samples, n_views, 3]
154 |
155 | # deep feature sampling
156 | feat_sampled = F.grid_sample(
157 | featmaps, normalized_pixel_locations, align_corners=True
158 | )
159 | feat_sampled = feat_sampled.permute(
160 | 2, 3, 0, 1
161 | ) # [n_rays, n_samples, n_views, d]
162 | rgb_feat_sampled = torch.cat(
163 | [rgbs_sampled_, feat_sampled], dim=-1
164 | ) # [n_rays, n_samples, n_views, d+3]
165 |
166 | inbound = self.inbound(pixel_locations, h, w)
167 | ray_diff = self.compute_angle(
168 | xyz_st, xyz, query_camera, train_cameras
169 | ).detach()
170 |
171 | ray_diff = ray_diff.permute(1, 2, 0, 3)
172 | mask = (
173 | (inbound * mask_in_front).float().permute(1, 2, 0)[..., None]
174 | ) # [n_rays, n_samples, n_views, 1]
175 |
176 | return rgb_feat_sampled, ray_diff, mask
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is not an officially supported Google product.
2 |
3 | # DynIBaR: Neural Dynamic Image-Based Rendering
4 |
5 | ### [Project Page](https://dynibar.github.io/)
6 |
7 | Implementation of our CVPR 2023 paper (Best Paper Honorable Mention)
8 |
9 | [DynIBaR: Neural Dynamic Image-Based Rendering, CVPR 2023](https://dynibar.github.io/)
10 |
11 | [Zhengqi Li](https://zhengqili.github.io/)<sup>1</sup>, [Qianqian Wang](https://www.cs.cornell.edu/~qqw/)<sup>1,2</sup>, [Forrester Cole](https://people.csail.mit.edu/fcole/)<sup>1</sup>, [Richard Tucker](https://research.google/people/RichardTucker/)<sup>1</sup>, [Noah Snavely](https://www.cs.cornell.edu/~snavely/)<sup>1</sup>
12 |
13 | <sup>1</sup>Google Research, <sup>2</sup>Cornell Tech, Cornell University
14 |
15 |
16 | ## Instructions for installing dependencies
17 |
18 | ### Python Environment
19 |
20 | This codebase was successfully run with Python 3.8 and CUDA 11.3. We
21 | suggest installing the dependencies in a virtual environment such as Anaconda.
22 |
23 | To install required libraries, run: \
24 | `conda env create -f environment_dynibar.yml`
25 |
26 | To install softmax splatting for preprocessing, clone and install the library
27 | from [here](https://github.com/hperrot/splatting).
28 |
29 | To measure LPIPS, copy the "models" folder from
30 | [NSFF](https://github.com/zhengqili/Neural-Scene-Flow-Fields/tree/main/nsff_exp/models),
31 | and put it in the code root directory.
32 |
33 | ## Evaluation on the Nvidia Dynamic Scene dataset
34 |
35 | ### Downloading data and pretrained checkpoint
36 |
37 | We include pretrained checkpoints that can be accessed by running:
38 |
39 | ```
40 | wget https://storage.googleapis.com/gresearch/dynibar/nvidia_checkpoints.zip
41 | unzip nvidia_checkpoints.zip
42 | ```
43 |
44 | Put the unzipped "checkpoints" folder in the code root directory.
45 |
46 | Each scene in the Nvidia dataset can be accessed
47 | [here](https://drive.google.com/drive/folders/1Gv6j_RvDG2WrpqEJWtx73u1tlCZKsPiM?usp=sharing)
48 |
49 | The input data directory should be similar to the following format:
50 | xxx/nvidia_long_release/Balloon1
51 |
52 | Run the following command for each scene to obtain reported quantitative results:
53 |
54 | ```bash
55 | # Usage: In the txt file, you need to change "rootdir" to your code root directory
56 | # and "folder_path" to your input data directory, and make sure "coarse_dir" points to
57 | # the "checkpoints" folder you unzipped.
58 | python eval_nvidia.py --config configs_nvidia/eval_balloon1_long.txt
59 | ```
60 |
61 | Note: It will take ~8 hours to evaluate each scene with 4x Nvidia A100 GPUs.
62 |
63 | ## Training/rendering on monocular videos.
64 |
65 | ### Required inputs and corresponding folders or files:
66 |
67 | We provide template input data for the NSFF example video, which can
68 | be downloaded
69 | [here](https://drive.google.com/file/d/1t6VLtcdxITFcdm9fi9SSFOiHqgHu9wdP/view?usp=sharing)
70 |
71 | The input data directory should be in the following format:
72 | xxx/release/kid-running/dense/***
73 |
74 | For your own video, you need to prepare the following folders and files to run training.
75 |
76 | * disp: disparity maps from
77 | [dynamic-cvd](https://github.com/google/dynamic-video-depth). Note that you
78 | need to run test.py to save the disparity and camera parameters to the disk.
79 | * images_wxh: resized images at resolution w x h.
80 | * poses_bounds_cvd.npy: camera parameters of input video in LLFF format.
81 |
82 | You can generate the above three items with the following script:
83 |
84 | ```bash
85 | # Usage: data_dir is input video directory path,
86 | # cvd_dir is saved depth directory resulting from running
87 | # "test.py" at https://github.com/google/dynamic-video-depth
88 | python save_monocular_cameras.py \
89 | --data_dir xxx/release/kid-running \
90 | --cvd_dir xxx/kid-running_scene_flow_motion_field_epoch_20/epoch0020_test
91 | ```
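
If you want to sanity-check the result, here is a minimal sketch that reads `poses_bounds_cvd.npy` back, following the layout written by `save_monocular_cameras.py` (one row per frame: a flattened 3x5 pose-plus-hwf matrix followed by two depth bounds):

```python
import numpy as np

arr = np.load('xxx/release/kid-running/dense/poses_bounds_cvd.npy')  # (num_frames, 17)

poses_hwf = arr[:, :15].reshape(-1, 3, 5)  # [R | t | (h, w, focal)] per frame
bounds = arr[:, 15:]                       # per-frame near/far depth bounds
c2w = poses_hwf[:, :, :4]                  # 3x4 camera-to-world matrices (LLFF convention)
h, w, focal = poses_hwf[0, :, 4]
print(c2w.shape, bounds.shape, (h, w, focal))
```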
92 |
93 | * source_virtual_views_wxh: virtual source views used to improve training
94 | stability and rendering quality (used in monocular video only). Run
95 | the following script to obtain them:
96 |
97 | ```bash
98 | # Usage: data_dir is input video directory path,
99 | # cvd_dir is saved depth directory resulting from running
100 | # "test.py" at https://github.com/google/dynamic-video-depth
101 | python render_source_vv.py \
102 | --data_dir xxx/release/kid-running \
103 | --cvd_dir xxx/kid-running_scene_flow_motion_field_epoch_20/epoch0020_test
104 | ```
105 |
106 | * flow_i1, flow_i2, flow_i3: estimated optical flows within temporal window of
107 | length 3. You can follow prior NSFF
108 | [script](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/nsff_scripts/run_flows_video.py)
109 | to run optical flows between the frame i and its nearby frames i+1, i+2,
110 | i+3, and save them in folders "flow_i1", "flow_i2", "flow_i3" respectively.
111 | For example, 00000_fwd.npz in folder "flow_i1" stores forward flow and valid
112 | mask from frame 0 to frame 1, and 00000_bwd.npz stores backward flow and
113 | valid mask from frame 1 to frame 0 (see the sketch after this list).
114 |
115 | * static_masks, dynamic_masks: motion masks indicating which regions are
116 | stationary or moving. You can perform morphological dilation and erosion operations respectively
117 | to ensure static_masks sufficiently cover the regions of moving objects, and that the regions from dynamic_masks
118 | are within the true regions of moving objects.
119 | (Note: due to dependency reasons, we don't release code to generate the masks. Instead, you could use this [script](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/nsff_scripts/run_flows_video.py#L87) from NSFF to generate coarse masks for your own use.)
120 |
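As a quick check of the flow folder layout described above, here is a minimal sketch that writes and reads one per-frame flow file. The npz key names (`flow`, `mask`) are assumptions for illustration only; use whichever keys your flow script actually saves.

```python
import os
import numpy as np

h, w = 288, 512  # example resolution; should match your images_wxh folder
os.makedirs('flow_i1', exist_ok=True)

flow_fwd = np.zeros((h, w, 2), dtype=np.float32)  # flow from frame 0 to frame 1
valid = np.ones((h, w), dtype=np.float32)         # 1 where the flow is reliable
np.savez('flow_i1/00000_fwd.npz', flow=flow_fwd, mask=valid)  # hypothetical keys

data = np.load('flow_i1/00000_fwd.npz')
print(data['flow'].shape, data['mask'].shape)  # (288, 512, 2) (288, 512)
```
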
121 | ### To train the model:
122 |
123 | ```bash
124 | # Usage: config is config txt file for training video
125 | # make sure "rootdir" is your code root directory,
126 | # "folder_path" is your input data directory path,
127 | # "train_scenes" is your folder name.
128 | # For example, if data is in xxx/release/kid-running/dense/, then "folder_path" is
129 | # "xxx/release/" and "train_scenes" is "kid-running".
130 | python train.py \
131 | --config configs/train_kid-running.txt
132 | ```
133 |
134 | Hyperparameters in the config txt file you might need to know for training a good model on in-the-wild videos:
135 | * rootdir: code root directory, which should be in the format YOUR_PATH/dynibar
136 | * folder_path: data root directory
137 | * N_rand: number of random samples at each iteration. Try to set it as large as possible; typically > 3000 gives good results
138 | * init_decay_epoch: number of epochs over which to linearly decay the data-driven depth and optical flow losses. Modify this such that num_video_frames * init_decay_epoch ≈ 30-40K (for example, roughly 300-400 epochs for a 100-frame video)
139 | * max_range, num_source_views: max_range indicates the maximum frame search range used to select source views for the static model, and num_source_views*2 is the number of source views used by the static model.
140 |
141 | The tensorboard log includes rendering visualizations, as shown below.
142 |
143 | ![Tensorboard rendering visualization](images/tensorboard.png)
144 |
145 | ### To render the model:
146 |
147 | ```bash
148 | # Usage: config is config txt file for training video,
149 | # please make sure expname in txt is the saved folder name in 'out' directory
150 | python render_monocular_bt.py \
151 | --config configs/test_kid-running.txt
152 | ```
153 |
154 | ### Contact
155 |
156 | For any questions related to our paper and implementation,
157 | please send email to zhengqili@google.com.
158 |
159 | ## Citation
160 |
161 | ```
162 | @InProceedings{Li_2023_CVPR,
163 | author = {Li, Zhengqi and Wang, Qianqian and Cole, Forrester and Tucker, Richard and Snavely, Noah},
164 | title = {DynIBaR: Neural Dynamic Image-Based Rendering},
165 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
166 | month = {June},
167 | year = {2023},
168 | pages = {4273-4284}
169 | }
170 | ```
171 |
--------------------------------------------------------------------------------
/ibrnet/feature_network.py:
--------------------------------------------------------------------------------
1 | """Class definition for 2D feature extractor."""
2 |
3 | import importlib
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def class_for_name(module_name, class_name):
10 | m = importlib.import_module(module_name)
11 | return getattr(m, class_name)
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
15 | """3x3 convolution with padding."""
16 | return nn.Conv2d(
17 | in_planes,
18 | out_planes,
19 | kernel_size=3,
20 | stride=stride,
21 | padding=dilation,
22 | groups=groups,
23 | bias=False,
24 | dilation=dilation,
25 | padding_mode='reflect',
26 | )
27 |
28 |
29 | def conv1x1(in_planes, out_planes, stride=1):
30 | """1x1 convolution layer."""
31 | return nn.Conv2d(
32 | in_planes,
33 | out_planes,
34 | kernel_size=1,
35 | stride=stride,
36 | bias=False,
37 | padding_mode='reflect',
38 | )
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | """Basic CNN block."""
43 | expansion = 1
44 |
45 | def __init__(
46 | self,
47 | inplanes,
48 | planes,
49 | stride=1,
50 | downsample=None,
51 | groups=1,
52 | base_width=64,
53 | dilation=1,
54 | norm_layer=None,
55 | ):
56 | super(BasicBlock, self).__init__()
57 | if norm_layer is None:
58 | norm_layer = nn.InstanceNorm2d
59 |
60 | self.conv1 = conv3x3(inplanes, planes, stride)
61 | self.bn1 = norm_layer(planes, track_running_stats=False, affine=True)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.conv2 = conv3x3(planes, planes)
64 | self.bn2 = norm_layer(planes, track_running_stats=False, affine=True)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | def forward(self, x):
69 | identity = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 |
78 | if self.downsample is not None:
79 | identity = self.downsample(x)
80 |
81 | out += identity
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class Bottleneck(nn.Module):
88 | """Bottleneck CNN block."""
89 |
90 | expansion = 4
91 |
92 | def __init__(
93 | self,
94 | inplanes,
95 | planes,
96 | stride=1,
97 | downsample=None,
98 | groups=1,
99 | base_width=64,
100 | dilation=1,
101 | norm_layer=None,
102 | ):
103 | super(Bottleneck, self).__init__()
104 | if norm_layer is None:
105 | norm_layer = nn.InstanceNorm2d
106 | width = int(planes * (base_width / 64.0)) * groups
107 | self.conv1 = conv1x1(inplanes, width)
108 | self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
109 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
110 | self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
111 | self.conv3 = conv1x1(width, planes * self.expansion)
112 | self.bn3 = norm_layer(
113 | planes * self.expansion, track_running_stats=False, affine=True
114 | )
115 | self.relu = nn.ReLU(inplace=True)
116 | self.downsample = downsample
117 | self.stride = stride
118 |
119 | def forward(self, x):
120 | identity = x
121 |
122 | out = self.conv1(x)
123 | out = self.bn1(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv2(out)
127 | out = self.bn2(out)
128 | out = self.relu(out)
129 |
130 | out = self.conv3(out)
131 | out = self.bn3(out)
132 |
133 | if self.downsample is not None:
134 | identity = self.downsample(x)
135 |
136 | out += identity
137 | out = self.relu(out)
138 |
139 | return out
140 |
141 |
142 | class conv(nn.Module):
143 | """Convolutional layer."""
144 |
145 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
146 | super(conv, self).__init__()
147 | self.kernel_size = kernel_size
148 | self.conv = nn.Conv2d(
149 | num_in_layers,
150 | num_out_layers,
151 | kernel_size=kernel_size,
152 | stride=stride,
153 | padding=(self.kernel_size - 1) // 2,
154 | padding_mode='reflect',
155 | )
156 | self.bn = nn.InstanceNorm2d(
157 | num_out_layers, track_running_stats=False, affine=True
158 | )
159 |
160 | def forward(self, x):
161 | return F.elu(self.bn(self.conv(x)), inplace=True)
162 |
163 |
164 | class upconv(nn.Module):
165 | """Convolutional layers followed by upsampling."""
166 |
167 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
168 | super(upconv, self).__init__()
169 | self.scale = scale
170 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)
171 |
172 | def forward(self, x):
173 | x = nn.functional.interpolate(
174 | x, scale_factor=self.scale, align_corners=True, mode='bilinear'
175 | )
176 | return self.conv(x)
177 |
178 |
179 | class ResNet(nn.Module):
180 | """Main ResNet based feature extractor."""
181 | def __init__(
182 | self,
183 | encoder='resnet34',
184 | coarse_out_ch=32,
185 | fine_out_ch=32,
186 | norm_layer=None,
187 | coarse_only=False,
188 | ):
189 | super(ResNet, self).__init__()
190 | assert encoder in [
191 | 'resnet18',
192 | 'resnet34',
193 | 'resnet50',
194 | 'resnet101',
195 | 'resnet152',
196 | ], 'Incorrect encoder type'
197 | if encoder in ['resnet18', 'resnet34']:
198 | filters = [64, 128, 256, 512]
199 | else:
200 | filters = [256, 512, 1024, 2048]
201 | self.coarse_only = coarse_only
202 | if self.coarse_only:
203 | fine_out_ch = 0
204 | self.coarse_out_ch = coarse_out_ch
205 | self.fine_out_ch = fine_out_ch
206 | out_ch = coarse_out_ch + fine_out_ch
207 |
208 | # original
209 | layers = [3, 4, 6, 3]
210 | if norm_layer is None:
211 | # norm_layer = nn.InstanceNorm2d
212 | norm_layer = nn.InstanceNorm2d
213 | self._norm_layer = norm_layer
214 | self.dilation = 1
215 | block = BasicBlock
216 | replace_stride_with_dilation = [False, False, False]
217 | self.inplanes = 64
218 | self.groups = 1
219 | self.base_width = 64
220 | self.conv1 = nn.Conv2d(
221 | 3,
222 | self.inplanes,
223 | kernel_size=7,
224 | stride=2,
225 | padding=3,
226 | bias=False,
227 | padding_mode='reflect',
228 | )
229 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
230 | self.relu = nn.ReLU(inplace=True)
231 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
232 | self.layer2 = self._make_layer(
233 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
234 | )
235 | self.layer3 = self._make_layer(
236 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
237 | )
238 |
239 | # decoder
240 | self.upconv3 = upconv(filters[2], 128, 3, 2)
241 | self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
242 | self.upconv2 = upconv(128, 64, 3, 2)
243 | self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)
244 |
245 | # fine-level conv
246 | self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)
247 |
248 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
249 | norm_layer = self._norm_layer
250 | downsample = None
251 | previous_dilation = self.dilation
252 | if dilate:
253 | self.dilation *= stride
254 | stride = 1
255 | if stride != 1 or self.inplanes != planes * block.expansion:
256 | downsample = nn.Sequential(
257 | conv1x1(self.inplanes, planes * block.expansion, stride),
258 | norm_layer(
259 | planes * block.expansion, track_running_stats=False, affine=True
260 | ),
261 | )
262 |
263 | layers = []
264 | layers.append(
265 | block(
266 | self.inplanes,
267 | planes,
268 | stride,
269 | downsample,
270 | self.groups,
271 | self.base_width,
272 | previous_dilation,
273 | norm_layer,
274 | )
275 | )
276 | self.inplanes = planes * block.expansion
277 | for _ in range(1, blocks):
278 | layers.append(
279 | block(
280 | self.inplanes,
281 | planes,
282 | groups=self.groups,
283 | base_width=self.base_width,
284 | dilation=self.dilation,
285 | norm_layer=norm_layer,
286 | )
287 | )
288 |
289 | return nn.Sequential(*layers)
290 |
291 | def skipconnect(self, x1, x2):
292 | diffY = x2.size()[2] - x1.size()[2]
293 | diffX = x2.size()[3] - x1.size()[3]
294 |
295 | x1 = F.pad(
296 | x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)
297 | )
298 |
299 | x = torch.cat([x2, x1], dim=1)
300 | return x
301 |
302 | def forward(self, x):
303 | x = self.relu(self.bn1(self.conv1(x)))
304 |
305 | x1 = self.layer1(x)
306 | x_out = self.out_conv(x1)
307 |
308 | x_coarse = x_out[:, : self.coarse_out_ch, :]
309 | x_fine = x_out[:, -self.fine_out_ch :, :]
310 |
311 | return x_coarse, x_fine
312 |
--------------------------------------------------------------------------------
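
A minimal usage sketch for the encoder above (not part of the repository; it assumes the classes defined earlier in ibrnet/feature_network.py, including a standard BasicBlock with expansion 1, are importable, and uses the default 32+32 output channels):

    import torch
    from ibrnet.feature_network import ResNet

    net = ResNet(encoder='resnet34', coarse_out_ch=32, fine_out_ch=32)

    # Source views enter as (num_views, 3, H, W); 4 hypothetical views at 288x512.
    src_rgbs = torch.rand(4, 3, 288, 512)
    coarse_feat, fine_feat = net(src_rgbs)

    # conv1 and layer1 both stride by 2, so the features live at 1/4 resolution.
    print(coarse_feat.shape)  # expected: torch.Size([4, 32, 72, 128])
    print(fine_feat.shape)    # expected: torch.Size([4, 32, 72, 128])
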
/config.py:
--------------------------------------------------------------------------------
1 | """function definition for config function."""
2 |
3 | import configargparse
4 |
5 |
6 | def config_parser():
7 | """Configuration function."""
8 | parser = configargparse.ArgumentParser()
9 | # general
10 | parser.add_argument('--config', is_config_file=True, help='Config file path')
11 | parser.add_argument(
12 | '--rootdir',
13 | type=str,
14 | help=(
15 | 'The path to the project root directory. Replace this path with'
16 | ' yours!'
17 | ),
18 | )
19 | parser.add_argument(
20 | '--folder_path',
21 | type=str,
22 | help=(
23 | 'The path to the input training data. Replace this path with yours.'
24 | ),
25 | )
26 |
27 | parser.add_argument(
28 | '--coarse_dir',
29 | type=str,
30 | help=(
31 | 'The directory of coarse model.'
32 | ),
33 | )
34 |
35 | parser.add_argument(
36 | '--mask_src_view',
37 | action='store_true',
38 | help=(
 39 |           'Use motion segmentation to mask source views when rendering the'
 40 |           ' static model'
41 | ),
42 | )
43 | parser.add_argument(
44 | '--training_height', type=int, default=288, help='Training image height'
45 | )
46 | parser.add_argument('--expname', type=str, help='Experiment name')
47 | parser.add_argument(
48 | '--distributed', action='store_true', help='Use distributed training'
49 | )
50 | parser.add_argument(
51 | '--local_rank', type=int, default=0, help='Rank for distributed training'
52 | )
53 | parser.add_argument(
54 | '-j',
55 | '--workers',
56 | default=16,
57 | type=int,
58 | help='Number of data loading workers (default: 16)',
59 | )
60 |
61 | parser.add_argument(
62 | '--mask_static',
63 | action='store_true',
 64 |       help='Use the motion mask to mask source views for the static model',
65 | )
66 |
67 | ########## model options ##########
68 | parser.add_argument(
69 | '--N_rand',
70 | type=int,
71 | default=32 * 16,
72 | help='Batch size (number of random rays per gradient step)',
73 | )
74 | parser.add_argument(
75 | '--sample_mode',
76 | type=str,
77 | default='uniform',
 78 |       help='How to sample pixels from images for training: uniform|center',
79 | )
80 | parser.add_argument(
81 | '--lr_multipler',
82 | type=float,
83 | default=1.0,
84 | help='Learning rate ratio for training static component',
85 | )
86 | parser.add_argument(
87 | '--num_vv',
88 | type=int,
89 | default=3,
90 | help='Number of virtual source views',
91 | )
92 | parser.add_argument(
93 | '--cycle_factor',
94 | type=float,
95 | default=0.1,
 96 |       help='Cycle consistency loss warm-up factor',
97 | )
98 | parser.add_argument(
99 | '--anneal_cycle',
100 | action='store_true',
101 | help='Bootstrap cycle consistency loss',
102 | )
103 | parser.add_argument(
104 | '--erosion_radius',
105 | type=int,
106 | default=1,
107 |       help='Morphological erosion radius for the mask',
108 | )
109 | parser.add_argument(
110 | '--decay_rate',
111 | type=float,
112 | default=10.0,
113 |       help='Decay rate for the data-driven losses',
114 | )
115 |
116 | ########## dataset options ##########
117 | parser.add_argument(
118 | '--eval_dataset',
119 | type=str,
120 | default='llff_test',
121 | help='The dataset to evaluate',
122 | )
123 | parser.add_argument(
124 | '--eval_scenes',
125 | nargs='+',
126 | default=[],
127 | help='Optional, specify a subset of scenes from eval_dataset to evaluate',
128 | )
129 | parser.add_argument(
130 | '--render_idx', type=int, default=-1, help='Frame index for rendering'
131 | )
132 | parser.add_argument(
133 | '--train_dataset',
134 | type=str,
135 | default='ibrnet_collected',
136 | help=(
137 |           'The training dataset; either a single dataset or multiple'
138 | ' datasets connected with "+", for example,'
139 | ' ibrnet_collected+llff+spaces'
140 | ),
141 | )
142 | parser.add_argument(
143 | '--train_scenes',
144 | nargs='+',
145 | default=[],
146 | help=(
147 |           'Optional, specify a subset of training scenes from the training dataset'
148 | ),
149 | )
150 |
151 | ## others
152 | parser.add_argument(
153 | '--init_decay_epoch',
154 | type=int,
155 | default=150,
156 |       help='Number of epochs over which to decay the data-driven losses',
157 | )
158 | parser.add_argument(
159 | '--max_range',
160 | type=int,
161 | default=35,
162 | help='Max frame range to sample source views for static model',
163 | )
164 |
165 | ########## model options ##########
166 | ## ray sampling options
167 | parser.add_argument(
168 | '--chunk_size',
169 | type=int,
170 | default=1024 * 4,
171 | help=(
172 | 'Number of rays processed in parallel, decrease if running out of'
173 | ' memory'
174 | ),
175 | )
176 | ## model options
177 | parser.add_argument(
178 | '--coarse_feat_dim',
179 | type=int,
180 | default=32,
181 | help='2D feature dimension for coarse level',
182 | )
183 | parser.add_argument(
184 | '--fine_feat_dim',
185 | type=int,
186 | default=32,
187 | help='2D feature dimension for fine level',
188 | )
189 | parser.add_argument(
190 | '--num_source_views',
191 | type=int,
192 | default=7,
193 | help=(
194 |           'The number of input source views for each target view used in'
195 |           ' the static dynibar model'
196 | ),
197 | )
198 | parser.add_argument(
199 | '--num_basis',
200 | type=int,
201 | default=6,
202 |       help='Number of basis functions for the motion trajectory',
203 | )
204 | parser.add_argument(
205 | '--anti_alias_pooling',
206 | type=int,
207 | default=1,
208 | help='Use anti-alias pooling',
209 | )
210 | parser.add_argument(
211 | '--mask_rgb',
212 | type=int,
213 | default=1,
214 | help=(
215 |           'Mask RGB features corresponding to black pixels when rendering'
216 |           ' from the static model'
217 | ),
218 | )
219 |
220 | ########## checkpoints ##########
221 | parser.add_argument(
222 | '--no_reload',
223 | action='store_true',
224 | help='do not reload weights from saved ckpt',
225 | )
226 | parser.add_argument(
227 | '--ckpt_path',
228 | type=str,
229 | default='',
230 | help='specific weights npy file to reload for coarse network',
231 | )
232 | parser.add_argument(
233 | '--no_load_opt',
234 | action='store_true',
235 | help='do not load optimizer when reloading',
236 | )
237 | parser.add_argument(
238 | '--no_load_scheduler',
239 | action='store_true',
240 | help='do not load scheduler when reloading',
241 | )
242 | ########### iterations & learning rate options ##########
243 | parser.add_argument(
244 | '--n_iters', type=int, default=300000, help='Num of iterations'
245 | )
246 | parser.add_argument(
247 | '--lrate_feature',
248 | type=float,
249 | default=1e-3,
250 | help='Learning rate for feature extractor',
251 | )
252 | parser.add_argument(
253 | '--lrate_mlp', type=float, default=5e-4, help='Learning rate for mlp'
254 | )
255 | parser.add_argument(
256 | '--lrate_decay_factor',
257 | type=float,
258 | default=0.5,
259 | help='Decay learning rate by a factor every specified number of steps',
260 | )
261 | parser.add_argument(
262 | '--lrate_decay_steps',
263 | type=int,
264 | default=50000,
265 |       help='Number of steps between successive learning rate decays',
266 | )
267 | parser.add_argument(
268 | '--w_cycle',
269 | type=float,
270 | default=0.1,
271 | help='Weight of cycle consistency loss',
272 | )
273 | parser.add_argument(
274 | '--w_distortion',
275 | type=float,
276 | default=1e-3,
277 | help='Weight of distortion loss',
278 | )
279 | parser.add_argument(
280 | '--w_entropy', type=float, default=0.0, help='Weight of entropy loss'
281 | )
282 | parser.add_argument(
283 | '--w_disp', type=float, default=5e-2, help='Weight of disparty loss'
284 | )
285 | parser.add_argument(
286 | '--w_flow', type=float, default=5e-3, help='Weight of flow loss'
287 | )
288 | parser.add_argument(
289 | '--w_skew_entropy',
290 | type=float,
291 | default=1e-3,
292 | help='Weight of entropy loss, assuming there is no skewness.',
293 | )
294 | parser.add_argument(
295 | '--w_reg', type=float, default=0.05, help='Weight of regularization loss'
296 | )
297 | parser.add_argument(
298 | '--pretrain_path', type=str, default='', help='Pretrained model path'
299 | )
300 | parser.add_argument(
301 | '--occ_weights_mode',
302 | type=int,
303 | default=0,
304 | help=(
305 |           'Occlusion weight mode during cross-time rendering. 0: mix the'
306 |           ' weights of the two models. 1: use weights from the dynamic model'
307 |           ' only. 2: use weights composited from the static and dynamic models.'
308 | ),
309 | )
310 |
311 | ########## rendering options ##########
312 | parser.add_argument(
313 | '--N_samples',
314 | type=int,
315 | default=64,
316 | help='Number of coarse samples per ray',
317 | )
318 | parser.add_argument(
319 | '--N_importance',
320 | type=int,
321 | default=64,
322 | help=(
323 |           'Number of fine samples per ray. The total number of samples is'
324 |           ' the sum of the coarse and fine samples'
325 | ),
326 | )
327 | parser.add_argument(
328 | '--inv_uniform',
329 | action='store_true',
330 | help='If True, uniformly sample in inverse depth space',
331 | )
332 | parser.add_argument(
333 | '--input_dir',
334 | action='store_true',
335 |       help='If True, input global view direction with positional encoding',
336 | )
337 | parser.add_argument(
338 | '--input_xyz',
339 | action='store_true',
340 | help='If True, input global xyz with positional encoding',
341 | )
342 | parser.add_argument(
343 | '--det',
344 | action='store_true',
345 | help='Deterministic sampling for coarse and fine samples',
346 | )
347 | parser.add_argument(
348 | '--white_bkgd',
349 | action='store_true',
350 | help='Apply the trick to avoid fitting to white background',
351 | )
352 | parser.add_argument(
353 | '--render_stride',
354 | type=int,
355 | default=1,
356 | help='Render with large stride for validation to save time',
357 | )
358 | ########## logging/saving options ##########
359 | parser.add_argument(
360 | '--i_print', type=int, default=100, help='Frequency of terminal printout'
361 | )
362 | parser.add_argument(
363 | '--i_img',
364 | type=int,
365 | default=1000,
366 | help='Frequency of tensorboard image logging',
367 | )
368 | parser.add_argument(
369 | '--i_weights',
370 | type=int,
371 | default=10000,
372 | help='Frequency of weight ckpt saving',
373 | )
374 |
375 | return parser
376 |
--------------------------------------------------------------------------------
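
The parser above is built with configargparse, so every flag can be supplied either on the command line or through a text config file such as the ones in configs/ and configs_nvidia/. A short sketch of how it is typically consumed (the flag values here are hypothetical):

    from config import config_parser

    parser = config_parser()
    # `--config` points at a `key = value` file; explicit command-line flags override it.
    args = parser.parse_args([
        '--config', 'configs/train_kid-running.txt',
        '--chunk_size', '2048',
    ])
    print(args.expname, args.N_samples, args.chunk_size)
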
/render_source_vv.py:
--------------------------------------------------------------------------------
1 | """Rendering virutal source views from video depth, used for monocular video."""
2 |
3 | import argparse
4 | import glob
5 | import os
6 |
7 | import cv2
8 | import imageio.v2 as imageio
9 | import kornia
10 | import numpy as np
11 | import skimage.morphology
12 | from splatting import splatting_function
13 | import torch
14 |
15 | def render_forward_splat(src_imgs, src_depths, r_cam, t_cam, k_src, k_dst):
16 |   """Forward-splat source RGBD images into the target viewpoint."""
17 | batch_size = src_imgs.shape[0]
18 |
19 | rot = r_cam
20 | t = t_cam
21 | k_src_inv = k_src.inverse()
22 |
23 | x = np.arange(src_imgs[0].shape[1])
24 | y = np.arange(src_imgs[0].shape[0])
25 | coord = np.stack(np.meshgrid(x, y), -1)
26 | coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1)
27 | coord = coord.astype(np.float32)
28 | coord = torch.as_tensor(coord, dtype=k_src.dtype, device=k_src.device)
29 | coord = coord[None, ..., None].repeat(batch_size, 1, 1, 1, 1)
30 |
31 | depth = src_depths[:, :, :, None, None]
32 |
33 | # from reference to target viewpoint
34 | pts_3d_ref = depth * k_src_inv[:, None, None, ...] @ coord
35 | pts_3d_tgt = rot[:, None, None, ...] @ pts_3d_ref + t[:, None, None, :, None]
36 | points = k_dst[:, None, None, ...] @ pts_3d_tgt
37 | points = points.squeeze(-1)
38 |
39 | new_z = points[:, :, :, [2]].clone().permute(0, 3, 1, 2) # b,1,h,w
40 | points = points / torch.clamp(points[:, :, :, [2]], 1e-8, None)
41 |
42 | src_ims_ = src_imgs.permute(0, 3, 1, 2)
43 | num_channels = src_ims_.shape[1]
44 |
45 | flow = points - coord.squeeze(-1)
46 | flow = flow.permute(0, 3, 1, 2)[:, :2, ...]
47 |
48 | importance = 1.0 / (new_z)
49 | importance_min = importance.amin((1, 2, 3), keepdim=True)
50 | importance_max = importance.amax((1, 2, 3), keepdim=True)
51 | weights = (importance - importance_min) / (
52 | importance_max - importance_min + 1e-6
53 | ) * 20 - 10
54 | src_mask_ = torch.ones_like(new_z)
55 |
56 | input_data = torch.cat([src_ims_, (1.0 / (new_z)), src_mask_], 1)
57 |
58 | output_data = splatting_function(
59 | 'softmax', input_data.cuda(), flow.cuda(), weights.detach().cuda()
60 | )
61 |
62 | warp_feature = output_data[:, 0:num_channels, ...]
63 | warp_disp = output_data[:, num_channels : num_channels + 1, ...]
64 | # warp_mask = output_data[:, num_channels + 1 : num_channels + 2, ...]
65 |
66 |   return warp_feature, warp_disp  # , warp_mask
67 |
68 | def render_wander_path(c2w, hwf, bd_scale, max_disp_=50, xyz=[1, 0, 1]):
69 | """Render nearby virtual source views with displacement in x and z direciton."""
70 | num_frames = 60
71 | max_disp = max_disp_ * bd_scale
72 | max_trans = (
73 | max_disp / hwf[2][0]
74 | )
75 | output_poses = []
76 |
77 | for i in range(num_frames):
78 |
79 | x_trans = max_trans * np.cos(
80 | 2.0 * np.pi * float(i) / float(num_frames)
81 | ) * xyz[0]
82 | y_trans = max_trans * np.sin(
83 | 2.0 * np.pi * float(i) / float(num_frames)
84 | ) * xyz[1]
85 | z_trans = max_trans * np.cos(
86 | 2.0 * np.pi * float(i) / float(num_frames)
87 | ) * xyz[2]
88 |
89 | i_pose = np.concatenate(
90 | [
91 | np.concatenate(
92 | [
93 | np.eye(3),
94 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
95 | ],
96 | axis=1,
97 | ),
98 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
99 | ],
100 | axis=0,
101 | )
102 |
103 | i_pose = np.linalg.inv(
104 | i_pose
105 | ) # torch.tensor(np.linalg.inv(i_pose)).float()
106 |
107 | ref_pose = np.concatenate(
108 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
109 | )
110 |
111 | render_pose = np.dot(ref_pose, i_pose)
112 |
113 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
114 |
115 | return np.array(output_poses + output_poses), num_frames
116 |
117 |
118 | def sobel_fg_alpha(disp, mode='sobel', beta=10.0):
119 | """Create depth boundary mask."""
120 | sobel_grad = kornia.filters.spatial_gradient(
121 | disp, mode=mode, normalized=False
122 | )
123 | sobel_mag = torch.sqrt(
124 | sobel_grad[:, :, 0, ...] ** 2 + sobel_grad[:, :, 1, ...] ** 2
125 | )
126 | alpha = torch.exp(-1.0 * beta * sobel_mag).detach()
127 |
128 | return alpha
129 |
130 |
131 | FINAL_H = 288
132 | USE_DPT = True
133 |
134 | if __name__ == '__main__':
135 | parser = argparse.ArgumentParser()
136 | # parser.add_argument("--scene_name", type=str,
137 | # help='Scene name') # 'kid-running'
138 | parser.add_argument("--data_dir", type=str,
139 | help='data directory') # '/home/zhengqili/filestore/NSFF/nerf_data/release'
140 | parser.add_argument("--cvd_dir", type=str,
141 | help='video depth directory') # '/home/zhengqili/filestore/dynamic-video-DPT/monocular-results/kid-runningscene_flow_motion_field_shutterstock_epoch_15/epoch0015_test'
142 |
143 | args = parser.parse_args()
144 |
145 | data_path = os.path.join(
146 | args.data_dir, 'dense'
147 | )
148 |
149 | pt_out_list = sorted(
150 | glob.glob(
151 | os.path.join(
152 | args.cvd_dir,
153 | '*.npz',
154 | )
155 | )
156 | )
157 |
158 | try:
159 | original_img_path = os.path.join(data_path, 'images', '00000.png')
160 | o_img = imageio.imread(original_img_path)
161 |   except FileNotFoundError:
162 | original_img_path = os.path.join(data_path, 'images', '00000.jpg')
163 | o_img = imageio.imread(original_img_path)
164 |
165 | o_ar = float(o_img.shape[1]) / float(o_img.shape[0])
166 |
167 | final_w, final_h = int(round(FINAL_H * o_ar)), int(FINAL_H)
168 |
169 | save_dir = os.path.join(
170 | data_path, 'source_virtual_views_%dx%d' % (final_w, final_h)
171 | )
172 | os.makedirs(save_dir, exist_ok=True)
173 |
174 | Ks = []
175 | mono_depths = []
176 | c2w_mats = []
177 | imgs = []
178 | bounds_mats = []
179 | points_cloud = []
180 |
181 | for i in range(0, len(pt_out_list)):
182 | pt_out_path = pt_out_list[i]
183 | out_name = pt_out_path.split('/')[-1]
184 | pt_data = np.load(pt_out_path)
185 | pred_depth = pt_data['depth'][0, 0, ...]
186 | cam_c2w = pt_data['cam_c2w'][0]
187 | img = pt_data['img_1'][0].transpose(1, 2, 0)
188 |
189 | c2w_mats.append(cam_c2w)
190 | bounds_mats.append(np.percentile(pred_depth, 5))
191 | K = pt_data['K'][0, 0, 0, ...].transpose()
192 | K[0, :] *= final_w / img.shape[1]
193 | K[1, :] *= final_h / img.shape[0]
194 |
195 | h, w, fx, fy = final_h, final_w, K[0, 0], K[1, 1]
196 | ff = (fx + fy) / 2.0
197 | # hwf = np.array([h, w, fx, fy]).reshape([1, 4])
198 | hwf = np.array([h, w, ff]).reshape([3, 1])
199 |
200 | c2w_mats = np.stack(c2w_mats, 0)
201 | bounds_mats = np.stack(bounds_mats, 0)
202 |
203 | bd_scale = bounds_mats.min() * 0.75
204 |
205 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0])
206 |
207 | # must switch to [-y, x, z] from [x, -y, -z], NOT [r, u, -t]
208 | poses = np.concatenate(
209 | [poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], poses[:, 3:4, :]],
210 | 1,
211 | )
212 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
213 |
214 | num_samples = 4
215 | vv_poses_final = np.zeros((poses.shape[0], num_samples * 2, 3, 4))
216 |
217 | for ii in range(poses.shape[0]):
218 | print(ii)
219 |     virtual_poses_0, num_render_0 = render_wander_path(
220 | poses[ii], hwf, bd_scale, 56 * 1.5,
221 | xyz=[0., 1., 1.] # y, x, z
222 | )
223 |     virtual_poses_1, num_render_1 = render_wander_path(
224 | poses[ii], hwf, bd_scale, 48 * 1.5,
225 | xyz=[0.5, 1., 0.]
226 | )
227 | # this is for fixed viewpoint!
228 | start_idx = np.random.randint(0, num_render_0 // num_samples)
229 |
230 |     vv_poses_final[ii, :num_samples, ...] = virtual_poses_0[
231 | 5 : -1 : int(num_render_0 // num_samples)
232 | ][:num_samples, :3, :4]
233 |     vv_poses_final[ii, num_samples:, ...] = virtual_poses_1[
234 | 15 : -1 : int(num_render_1 // num_samples)
235 | ][:num_samples, :3, :4]
236 |
237 | np.save(
238 | os.path.join(data_path, 'source_vv_poses.npy'),
239 | np.moveaxis(vv_poses_final, 0, -1).astype(np.float32),
240 | )
241 |
242 | # switch back
243 | c2w_mats_vsv = np.concatenate(
244 | [
245 | vv_poses_final[..., 1:2],
246 | vv_poses_final[..., 0:1],
247 | -vv_poses_final[..., 2:3],
248 | vv_poses_final[..., 3:4],
249 | ],
250 | -1,
251 | )
252 |
253 | for i in range(0, len(pt_out_list)):
254 | save_sub_dir = os.path.join(save_dir, '%05d' % i)
255 | print(save_sub_dir)
256 | os.makedirs(save_sub_dir, exist_ok=True)
257 | pt_out_path = pt_out_list[i]
258 |
259 | out_name = pt_out_path.split('/')[-1]
260 | pt_data = np.load(pt_out_path)
261 |
262 | K = pt_data['K'][0, 0, 0, ...].transpose()
263 | img = pt_data['img_1'][0].transpose(1, 2, 0)
264 | cam_ref2w = pt_data['cam_c2w'][0]
265 | pred_depth = pt_data['depth'][0, 0, ...]
266 | pred_disp = 1.0 / pred_depth
267 |
268 | K[0, :] *= final_w / img.shape[1]
269 | K[1, :] *= final_h / img.shape[0]
270 |
271 | print('K ', K)
272 | assert abs(K[0, 0] - K[1, 1]) / abs(K[0, 0] + K[1, 1]) < 0.005
273 |
274 | pred_depth_ = cv2.resize(
275 | pred_depth, (final_w, final_h), interpolation=cv2.INTER_NEAREST
276 | )
277 |
278 | img = cv2.resize(img, (final_w, final_h), interpolation=cv2.INTER_AREA)
279 | pred_disp = cv2.resize(
280 | pred_disp, (final_w, final_h), interpolation=cv2.INTER_LINEAR
281 | )
282 |
283 | mode = 'sobel'
284 | beta = 0.5
285 | pred_depth = 1.0 / torch.from_numpy(pred_disp[None, None, ...])
286 | pred_depth = pred_depth / 10.0
287 | cur_alpha = sobel_fg_alpha(pred_depth, mode, beta=beta)[
288 | 0, 0, ..., None
289 | ].numpy()
290 |
291 | for k in range(num_samples * 2):
292 | # render source view into target viewpoint
293 | rgba_pt = torch.from_numpy(
294 | np.concatenate(
295 | [np.array(img * 255.0), cur_alpha], axis=-1
296 | )
297 | )[None].float()
298 | disp_pt = torch.from_numpy(np.array(pred_disp))[
299 | None
300 | ].float()
301 | cam_tgt2w = np.eye(4)
302 | cam_tgt2w[:3, :4] = c2w_mats_vsv[i, k]
303 | T_ref2tgt = np.dot(np.linalg.inv(cam_tgt2w), cam_ref2w)
304 |
305 | fwd_rot = torch.from_numpy(T_ref2tgt[:3, :3])[None].float()
306 | fwd_t = torch.from_numpy(T_ref2tgt[:3, 3])[None].float() # * metric_scale
307 | k_ref = torch.from_numpy(np.array(K))[None].float()
308 |
309 | render_rgba, render_depth = render_forward_splat(
310 | rgba_pt, 1.0 / disp_pt, fwd_rot, fwd_t, k_src=k_ref, k_dst=k_ref
311 | )
312 |
313 | render_rgb = np.clip(
314 | render_rgba[0, :3, ...].cpu().numpy().transpose(1, 2, 0) / 255.0,
315 | 0.0,
316 | 1.0,
317 | )
318 | mask = np.clip(
319 | render_rgba[0, 3:4, ...].cpu().numpy().transpose(1, 2, 0), 0.0, 1.0
320 | )
321 | mask = skimage.morphology.erosion(
322 | mask[..., 0] > 0.5, skimage.morphology.disk(1)
323 | )
324 |
325 | render_rgb_masked = render_rgb * mask[..., None]
326 | h, w = render_rgb_masked.shape[:2]
327 | imageio.imsave(
328 | os.path.join(save_sub_dir, '%02d.png' % k),
329 | np.uint8(255 * np.clip(render_rgb_masked, 0.0, 1.0)),
330 | )
331 |
--------------------------------------------------------------------------------
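
A small sketch of what render_wander_path above produces (hypothetical camera values; it assumes the function is importable, which requires the module's dependencies such as splatting to be installed):

    import numpy as np
    from render_source_vv import render_wander_path

    # Reference camera-to-world pose (3x4) in the loader's permuted axis convention,
    # and hwf as a (3, 1) column of [height, width, focal], as built in __main__ above.
    c2w = np.concatenate([np.eye(3), np.zeros((3, 1))], axis=1)
    hwf = np.array([288.0, 512.0, 500.0]).reshape(3, 1)

    poses, num_frames = render_wander_path(c2w, hwf, bd_scale=1.0, max_disp_=50,
                                           xyz=[1, 0, 1])
    print(num_frames)   # 60
    print(poses.shape)  # (120, 3, 5): the 60 poses repeated twice, each a 3x5 block [R | t | hwf]
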
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/ibrnet/sample_ray.py:
--------------------------------------------------------------------------------
1 | """Utility class for sampling data corresponding to rays from images."""
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from kornia import create_meshgrid
7 |
8 | rng = np.random.RandomState(234)
9 |
10 |
11 | def parse_camera(params):
12 | H = params[:, 0]
13 | W = params[:, 1]
14 | intrinsics = params[:, 2:18].reshape((-1, 4, 4))
15 | c2w = params[:, 18:34].reshape((-1, 4, 4))
16 | return W, H, intrinsics, c2w
17 |
18 |
19 | class RaySamplerSingleImage(object):
20 | """Sampling data corresponding to the rays from a target view.
21 |
22 | This class stores and returns following items at sampled pixel locations
23 | for training Dynibar:
24 | ray_o: ray origin at target view
25 | ray_d: ray direction at target view
26 | depth_range: scene depth bounds at target view
27 | camera: reference time camera parameters
28 | render_camera: rendered target view camera parameters
29 | anchor_camera: camera parameters for input view at nearby cross time
30 | rgb: image at reference time
31 | src_rgbs: source view images w.r.t reference time for dynamic model
32 | src_cameras: source view camera parameters w.r.t reference time for
33 | dynamic model
34 | anchor_src_rgbs: source view images w.r.t nearby cross time for dynamic
35 | model.
36 | anchor_src_cameras: source view camera parameters w.r.t
37 | nearby cross time for dynamic model.
38 | static_src_rgbs: source view images for static model
39 | static_src_cameras: source view camera parameters for static model
40 | static_src_masks: dynamic masks of source views for static model
41 | disp: disparity map
42 | motion_mask: dynamic mask
43 | static_mask: static masks
44 |     uv_grid: 2D pixel coordinates in image space
45 |     flows: observed 2D optical flows
46 |     masks: optical flow valid masks
47 | """
48 |
49 | def __init__(self, data, device, resize_factor=1, render_stride=1):
50 | super().__init__()
51 | self.render_stride = render_stride
52 | self.rgb = data['rgb'] if 'rgb' in data.keys() else None
53 | self.disp = data['disp'] if 'disp' in data.keys() else None
54 |
55 | self.motion_mask = (
56 | data['motion_mask'] if 'motion_mask' in data.keys() else None
57 | )
58 |
59 | self.static_mask = (
60 | data['static_mask'] if 'static_mask' in data.keys() else None
61 | )
62 |
63 | self.flows = data['flows'].squeeze(0) if 'flows' in data.keys() else None
64 | self.masks = data['masks'].squeeze(0) if 'masks' in data.keys() else None
65 |
66 | self.camera = data['camera']
67 | self.render_camera = (
68 | data['render_camera'] if 'render_camera' in data.keys() else None
69 | )
70 |
71 | self.anchor_camera = (
72 | data['anchor_camera'] if 'anchor_camera' in data.keys() else None
73 | )
74 | self.rgb_path = data['rgb_path']
75 | self.depth_range = data['depth_range']
76 | self.device = device
77 | W, H, self.intrinsics, self.c2w_mat = parse_camera(self.camera)
78 |
79 | self.batch_size = len(self.camera)
80 |
81 | self.H = int(H[0])
82 | self.W = int(W[0])
83 | self.uv_grid = create_meshgrid(
84 | self.H, self.W, normalized_coordinates=False
85 | )[0].to(
86 | self.device
87 | ) # (H, W, 2)
88 |
89 | self.rays_o, self.rays_d = self.get_rays_single_image(
90 | self.H, self.W, self.intrinsics, self.c2w_mat
91 | )
92 |
93 | if self.rgb is not None:
94 | self.rgb = self.rgb.reshape(-1, 3)
95 |
96 | if self.disp is not None:
97 | self.disp = self.disp.reshape(-1, 1)
98 |
99 | if self.motion_mask is not None:
100 | self.motion_mask = self.motion_mask.reshape(-1, 1)
101 |
102 | if self.static_mask is not None:
103 | self.static_mask = self.static_mask.reshape(-1, 1)
104 |
105 | if self.flows is not None:
106 | self.flows = self.flows.reshape(self.flows.shape[0], -1, 2)
107 | self.masks = self.masks.reshape(self.masks.shape[0], -1, 1)
108 |
109 | self.uv_grid = self.uv_grid.reshape(-1, 2)
110 |
111 | if 'src_rgbs' in data.keys():
112 | self.src_rgbs = data['src_rgbs']
113 | else:
114 | self.src_rgbs = None
115 |
116 | if 'src_cameras' in data.keys():
117 | self.src_cameras = data['src_cameras']
118 | else:
119 | self.src_cameras = None
120 |
121 | self.anchor_src_rgbs = (
122 | data['anchor_src_rgbs'] if 'anchor_src_rgbs' in data.keys() else None
123 | )
124 | self.anchor_src_cameras = (
125 | data['anchor_src_cameras']
126 | if 'anchor_src_cameras' in data.keys()
127 | else None
128 | )
129 |
130 | self.static_src_rgbs = (
131 | data['static_src_rgbs'] if 'static_src_rgbs' in data.keys() else None
132 | )
133 | self.static_src_cameras = (
134 | data['static_src_cameras']
135 | if 'static_src_cameras' in data.keys()
136 | else None
137 | )
138 | self.static_src_masks = (
139 | data['static_src_masks'] if 'static_src_masks' in data.keys() else None
140 | )
141 |
142 |
143 | def get_rays_single_image(self, H, W, intrinsics, c2w):
144 | """Return ray parameters (origin, direction) from a target view."""
145 | u, v = np.meshgrid(
146 | np.arange(W)[:: self.render_stride], np.arange(H)[:: self.render_stride]
147 | )
148 | u = u.reshape(-1).astype(dtype=np.float32)
149 | v = v.reshape(-1).astype(dtype=np.float32)
150 | pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W)
151 | pixels = torch.from_numpy(pixels)
152 | batched_pixels = pixels.unsqueeze(0).repeat(self.batch_size, 1, 1)
153 |
154 | rays_d = (
155 | c2w[:, :3, :3]
156 | .bmm(torch.inverse(intrinsics[:, :3, :3]))
157 | .bmm(batched_pixels)
158 | ).transpose(1, 2)
159 | rays_d = rays_d.reshape(-1, 3)
160 | rays_o = (
161 | c2w[:, :3, 3].unsqueeze(1).repeat(1, rays_d.shape[0], 1).reshape(-1, 3)
162 | ) # B x HW x 3
163 | return rays_o, rays_d
164 |
165 | def get_all(self):
166 | """Return all camera and ray information from a target view."""
167 | ret = {
168 | 'ray_o': self.rays_o.to(self.device),
169 | 'ray_d': self.rays_d.to(self.device),
170 | 'depth_range': self.depth_range.to(self.device),
171 | 'camera': self.camera.to(self.device),
172 | 'render_camera': (
173 | self.render_camera.to(self.device)
174 | if self.render_camera is not None
175 | else None
176 | ),
177 | 'anchor_camera': (
178 | self.anchor_camera.to(self.device)
179 | if self.anchor_camera is not None
180 | else None
181 | ),
182 | 'rgb': self.rgb.to(self.device) if self.rgb is not None else None,
183 | 'src_rgbs': (
184 | self.src_rgbs.to(self.device) if self.src_rgbs is not None else None
185 | ),
186 | 'src_cameras': (
187 | self.src_cameras.to(self.device)
188 | if self.src_cameras is not None
189 | else None
190 | ),
191 | 'anchor_src_rgbs': (
192 | self.anchor_src_rgbs.to(self.device)
193 | if self.anchor_src_rgbs is not None
194 | else None
195 | ),
196 | 'anchor_src_cameras': (
197 | self.anchor_src_cameras.to(self.device)
198 | if self.anchor_src_cameras is not None
199 | else None
200 | ),
201 | 'static_src_rgbs': (
202 | self.static_src_rgbs.to(self.device)
203 | if self.static_src_rgbs is not None
204 | else None
205 | ),
206 | 'static_src_cameras': (
207 | self.static_src_cameras.to(self.device)
208 | if self.static_src_cameras is not None
209 | else None
210 | ),
211 | 'static_src_masks': (
212 | self.static_src_masks.to(self.device)
213 | if self.static_src_masks is not None
214 | else None
215 | ),
216 | 'disp': (
217 | self.disp.to(self.device).squeeze()
218 | if self.disp is not None
219 | else None
220 | ),
221 | 'motion_mask': (
222 | self.motion_mask.to(self.device).squeeze()
223 | if self.motion_mask is not None
224 | else None
225 | ),
226 | 'static_mask': (
227 | self.static_mask.to(self.device).squeeze()
228 | if self.static_mask is not None
229 | else None
230 | ),
231 | 'uv_grid': self.uv_grid.to(self.device),
232 | 'flows': self.flows.to(self.device) if self.flows is not None else None,
233 | 'masks': self.masks.to(self.device) if self.masks is not None else None,
234 | }
235 | return ret
236 |
237 | def sample_random_pixel(self, N_rand, sample_mode, center_ratio=0.8):
238 | """Sample pixel randomly from the target view."""
239 | if sample_mode == 'center':
240 | border_H = int(self.H * (1 - center_ratio) / 2.0)
241 | border_W = int(self.W * (1 - center_ratio) / 2.0)
242 |
243 | # pixel coordinates
244 | u, v = np.meshgrid(
245 | np.arange(border_H, self.H - border_H),
246 | np.arange(border_W, self.W - border_W),
247 | )
248 | u = u.reshape(-1)
249 | v = v.reshape(-1)
250 |
251 | select_inds = rng.choice(u.shape[0], size=(N_rand,), replace=False)
252 | select_inds = v[select_inds] + self.W * u[select_inds]
253 |
254 | elif sample_mode == 'uniform':
255 | # Random from one image
256 | select_inds = rng.choice(self.H*self.W, size=(N_rand,), replace=False)
257 | else:
258 | raise NotImplementedError
259 |
260 | return select_inds
261 |
262 | def random_sample(self, N_rand, sample_mode, center_ratio=0.8):
263 | """Randomly sample pixel and pixel data from the target view."""
264 | select_inds = self.sample_random_pixel(N_rand, sample_mode, center_ratio)
265 |
266 | rays_o = self.rays_o[select_inds]
267 | rays_d = self.rays_d[select_inds]
268 |
269 | if self.rgb is not None:
270 | rgb = self.rgb[select_inds]
271 | disp = self.disp[select_inds].squeeze()
272 | motion_mask = self.motion_mask[select_inds].squeeze()
273 | static_mask = self.static_mask[select_inds].squeeze()
274 |
275 | flows = self.flows[:, select_inds, :]
276 | masks = self.masks[:, select_inds, :]
277 |
278 | uv_grid = self.uv_grid[select_inds]
279 |
280 | else:
281 | raise NotImplementedError
282 |
283 | ret = {
284 | 'ray_o': rays_o.to(self.device),
285 | 'ray_d': rays_d.to(self.device),
286 | 'camera': self.camera.to(self.device),
287 | 'anchor_camera': self.anchor_camera.to(self.device),
288 | 'depth_range': self.depth_range.to(self.device),
289 | 'rgb': rgb.to(self.device) if rgb is not None else None,
290 | 'disp': disp.to(self.device),
291 | 'motion_mask': motion_mask.to(self.device),
292 | 'static_mask': static_mask.to(self.device),
293 | 'uv_grid': uv_grid.to(self.device),
294 | 'flows': flows.to(self.device),
295 | 'masks': masks.to(self.device),
296 | 'src_rgbs': (
297 | self.src_rgbs.to(self.device) if self.src_rgbs is not None else None
298 | ),
299 | 'src_cameras': (
300 | self.src_cameras.to(self.device)
301 | if self.src_cameras is not None
302 | else None
303 | ),
304 | 'static_src_rgbs': (
305 | self.static_src_rgbs.to(self.device)
306 | if self.static_src_rgbs is not None
307 | else None
308 | ),
309 | 'static_src_cameras': (
310 | self.static_src_cameras.to(self.device)
311 | if self.static_src_cameras is not None
312 | else None
313 | ),
314 | 'static_src_masks': (
315 | self.static_src_masks.to(self.device)
316 | if self.static_src_masks is not None
317 | else None
318 | ),
319 | 'anchor_src_rgbs': (
320 | self.anchor_src_rgbs.to(self.device)
321 | if self.anchor_src_rgbs is not None
322 | else None
323 | ),
324 | 'anchor_src_cameras': (
325 | self.anchor_src_cameras.to(self.device)
326 | if self.anchor_src_cameras is not None
327 | else None
328 | ),
329 | 'selected_inds': select_inds,
330 | }
331 | return ret
332 |
--------------------------------------------------------------------------------
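
A minimal sketch of driving RaySamplerSingleImage with a single hypothetical target view (only the required keys are provided; the optional source-view entries stay None):

    import numpy as np
    import torch
    from ibrnet.sample_ray import RaySamplerSingleImage

    h, w = 288, 512
    intrinsics = np.eye(4, dtype=np.float32)
    intrinsics[0, 0] = intrinsics[1, 1] = 500.0             # focal length (pixels)
    intrinsics[0, 2], intrinsics[1, 2] = w / 2.0, h / 2.0   # principal point
    c2w = np.eye(4, dtype=np.float32)                       # camera-to-world pose

    # `camera` packs [h, w, K.flatten(), c2w.flatten()] into 34 values, as parse_camera expects.
    camera = np.concatenate([[h, w], intrinsics.flatten(), c2w.flatten()])
    data = {
        'camera': torch.from_numpy(camera).float()[None],   # (1, 34)
        'rgb_path': '',
        'depth_range': torch.tensor([[1.0, 50.0]]),
    }

    sampler = RaySamplerSingleImage(data, device='cpu')
    batch = sampler.get_all()
    print(batch['ray_o'].shape, batch['ray_d'].shape)  # both (H*W, 3)
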
/render_monocular_bt.py:
--------------------------------------------------------------------------------
1 | """Script to render novel views from pretrained model."""
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 | from torch.utils.data import DataLoader
6 | import imageio.v2 as imageio
7 | from config import config_parser
8 | from ibrnet.sample_ray import RaySamplerSingleImage
9 | from ibrnet.render_image import render_single_image_mono
10 | from ibrnet.model import DynibarMono
11 | from ibrnet.projection import Projector
12 | from ibrnet.data_loaders.data_utils import get_nearest_pose_ids
13 | from ibrnet.data_loaders.data_utils import get_interval_pose_ids
14 | from ibrnet.data_loaders.llff_data_utils import load_mono_data
15 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses
16 | from ibrnet.data_loaders.llff_data_utils import batch_parse_vv_poses
17 | import time
18 | import os
19 | import numpy as np
20 | import cv2
21 |
22 |
23 | class DynamicVideoDataset(Dataset):
24 | """Class for defining monocular video data.
25 |
26 | Attributes:
27 | folder_path: root path
28 | num_source_views: number of source views to sample
29 | mask_src_view: using mask to mask moving objects
30 | render_idx: rendering frame index
31 | max_range: max sampling frame range
32 | render_rgb_files: rendering RGB file path
33 | render_intrinsics: rendering camera intrinsics
34 | render_poses: rendering camera poses
35 | render_depth_range: rendering depth bounds
36 | h: image height
37 | w: image width
38 |     train_intrinsics: training camera intrinsics
39 | train_poses: training camera poses
40 | train_rgb_files: training RGB path
41 | num_frames: number of video frames
42 | src_vv_c2w_mats: virtual views camera matrix
43 | """
44 |
45 | def __init__(self, args, scenes, **kwargs):
46 | self.folder_path = (
47 | args.folder_path
48 | )
49 | self.num_source_views = args.num_source_views
50 | self.mask_src_view = args.mask_src_view
51 | self.render_idx = args.render_idx
52 | self.max_range = args.max_range
53 | self.num_vv = args.num_vv
54 | print('num_source_views ', self.num_source_views)
55 | print('loading {} for rendering'.format(scenes))
56 | assert len(scenes) == 1
57 |
58 | scene = scenes[0]
59 | # for i, scene in enumerate(scenes):
60 | scene_path = os.path.join(self.folder_path, scene, 'dense')
61 | _, poses, src_vv_poses, bds, render_poses, _, rgb_files, _ = (
62 | load_mono_data(
63 | scene_path,
64 | height=args.training_height,
65 | render_idx=self.render_idx,
66 | load_imgs=False,
67 | )
68 | )
69 | near_depth = np.min(bds)
70 |
71 | if np.max(bds) < 10:
72 | far_depth = min(50, np.max(bds) + 15.0)
73 | else:
74 | far_depth = min(50, max(20, np.max(bds)))
75 |
76 | self.num_frames = len(rgb_files)
77 |
78 | intrinsics, c2w_mats = batch_parse_llff_poses(poses)
79 | h, w = poses[0][:2, -1]
80 | render_intrinsics, render_c2w_mats = batch_parse_llff_poses(render_poses)
81 | self.src_vv_c2w_mats = batch_parse_vv_poses(src_vv_poses)
82 |
83 | self.train_intrinsics = intrinsics
84 | self.train_poses = c2w_mats
85 | self.train_rgb_files = rgb_files
86 |
87 | self.render_intrinsics = render_intrinsics
88 | self.render_poses = render_c2w_mats
89 | self.render_depth_range = [[near_depth, far_depth]] * self.num_frames
90 | self.h = [int(h)] * self.num_frames
91 | self.w = [int(w)] * self.num_frames
92 |
93 | def __len__(self):
94 | return len(self.render_poses)
95 |
96 | def __getitem__(self, idx):
97 | render_pose = self.render_poses[idx]
98 | intrinsics = self.render_intrinsics[idx]
99 | depth_range = self.render_depth_range[idx]
100 |
101 | train_rgb_files = self.train_rgb_files
102 | train_poses = self.train_poses
103 | train_intrinsics = self.train_intrinsics
104 |
105 | rgb_file = train_rgb_files[idx]
106 | rgb = imageio.imread(rgb_file).astype(np.float32) / 255.0
107 |
108 | h, w = self.h[idx], self.w[idx]
109 | camera = np.concatenate(
110 | ([h, w], intrinsics.flatten(), render_pose.flatten())
111 | ).astype(np.float32)
112 |
113 | nearest_pose_ids = np.sort(
114 | [self.render_idx + offset for offset in [1, 2, 3, 0, -1, -2, -3]]
115 | )
116 | sp_pose_ids = get_nearest_pose_ids(
117 | render_pose, train_poses, tar_id=-1, angular_dist_method='dist'
118 | )
119 |
120 | static_pose_ids = []
121 |     frame_interval = self.max_range // self.num_source_views
122 | interval_pose_ids = get_interval_pose_ids(
123 | render_pose,
124 | train_poses,
125 | tar_id=-1,
126 | angular_dist_method='dist',
127 | interval=frame_interval,
128 | )
129 |
130 | for sp_pose_id in interval_pose_ids:
131 | if len(static_pose_ids) >= (self.num_source_views * 2 + 1):
132 | break
133 |
134 | if np.abs(sp_pose_id - self.render_idx) > (
135 | self.max_range + self.num_source_views * 0.5
136 | ):
137 | continue
138 |
139 | static_pose_ids.append(sp_pose_id)
140 |
141 | static_pose_set = set(static_pose_ids)
142 |
143 |     # if there are not enough source images, fall back to the closest ones
144 | for sp_pose_id in sp_pose_ids[::5]:
145 | if len(static_pose_ids) >= (self.num_source_views * 2 + 1):
146 | break
147 |
148 | if sp_pose_id in static_pose_set:
149 | continue
150 |
151 | static_pose_ids.append(sp_pose_id)
152 |
153 | static_pose_ids = np.sort(static_pose_ids)
154 |
155 | assert len(static_pose_ids) == (self.num_source_views * 2 + 1)
156 |
157 | src_rgbs = []
158 | src_cameras = []
159 | for src_idx in nearest_pose_ids:
160 | src_rgb = (
161 | imageio.imread(train_rgb_files[src_idx]).astype(np.float32) / 255.0
162 | )
163 | train_pose = train_poses[src_idx]
164 | train_intrinsics_ = train_intrinsics[src_idx]
165 | src_rgbs.append(src_rgb)
166 | img_size = src_rgb.shape[:2]
167 | src_camera = np.concatenate(
168 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten())
169 | ).astype(np.float32)
170 |
171 | src_cameras.append(src_camera)
172 |
173 | # load src virtual views
174 | vv_pose_ids = get_nearest_pose_ids(
175 | render_pose,
176 | self.src_vv_c2w_mats[self.render_idx],
177 | tar_id=-1,
178 | angular_dist_method='dist',
179 | )
180 |
181 | # load virtual source views
182 | num_vv = self.num_vv
183 | for virtual_idx in vv_pose_ids[:num_vv]:
184 | src_vv_path = os.path.join(
185 | '/'.join(
186 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1]
187 | ),
188 | '%05d' % self.render_idx,
189 | '%02d.png' % virtual_idx,
190 | )
191 | src_rgb = imageio.imread(src_vv_path).astype(np.float32) / 255.0
192 | src_rgbs.append(src_rgb)
193 | img_size = src_rgb.shape[:2]
194 |
195 | src_camera = np.concatenate((
196 | list(img_size),
197 | intrinsics.flatten(),
198 | self.src_vv_c2w_mats[self.render_idx, virtual_idx].flatten(),
199 | )).astype(np.float32)
200 |
201 | src_cameras.append(src_camera)
202 |
203 | src_rgbs = np.stack(src_rgbs, axis=0)
204 | src_cameras = np.stack(src_cameras, axis=0)
205 |
206 | static_src_rgbs = []
207 | static_src_cameras = []
208 | # load src rgb for static view
209 | for st_near_id in static_pose_ids:
210 | src_rgb = (
211 | imageio.imread(train_rgb_files[st_near_id]).astype(np.float32) / 255.0
212 | )
213 | train_pose = train_poses[st_near_id]
214 | train_intrinsics_ = train_intrinsics[st_near_id]
215 |
216 | if self.mask_src_view:
217 | st_mask_path = os.path.join(
218 | '/'.join(rgb_file.split('/')[:-2]),
219 | 'dynamic_masks',
220 | '%d.png' % st_near_id,
221 | )
222 | st_mask = imageio.imread(st_mask_path).astype(np.float32) / 255.0
223 | st_mask = cv2.resize(
224 | st_mask,
225 | (src_rgb.shape[1], src_rgb.shape[0]),
226 | interpolation=cv2.INTER_NEAREST,
227 | )
228 |
229 | if len(st_mask.shape) == 2:
230 | st_mask = st_mask[..., None]
231 |
232 | src_rgb = src_rgb * st_mask
233 |
234 | static_src_rgbs.append(src_rgb)
235 | img_size = src_rgb.shape[:2]
236 | src_camera = np.concatenate(
237 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten())
238 | ).astype(np.float32)
239 |
240 | static_src_cameras.append(src_camera)
241 |
242 | static_src_rgbs = np.stack(static_src_rgbs, axis=0)
243 | static_src_cameras = np.stack(static_src_cameras, axis=0)
244 |
245 | depth_range = torch.tensor([depth_range[0] * 0.9, depth_range[1] * 1.5])
246 |
247 | return {
248 | 'camera': torch.from_numpy(camera),
249 | 'rgb_path': '',
250 | 'rgb': torch.from_numpy(rgb),
251 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(),
252 | 'src_cameras': torch.from_numpy(src_cameras).float(),
253 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(),
254 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(),
255 | 'depth_range': depth_range,
256 | 'ref_time': float(self.render_idx / float(self.num_frames)),
257 | 'id': self.render_idx,
258 | 'nearest_pose_ids': nearest_pose_ids
259 | }
260 |
261 | if __name__ == '__main__':
262 | parser = config_parser()
263 | args = parser.parse_args()
264 | args.distributed = False
265 |
266 | test_dataset = DynamicVideoDataset(args, scenes=args.eval_scenes)
267 | args.num_frames = test_dataset.num_frames
268 |
269 | # Create ibrnet model
270 | model = DynibarMono(args)
271 | eval_dataset_name = args.eval_dataset
272 | extra_out_dir = '{}/{}/{}'.format(
273 | eval_dataset_name, args.expname, str(args.render_idx)
274 | )
275 | print('saving results to {}...'.format(extra_out_dir))
276 | os.makedirs(extra_out_dir, exist_ok=True)
277 |
278 | projector = Projector(device='cuda:0')
279 |
280 | assert len(args.eval_scenes) == 1, 'only accept single scene'
281 | scene_name = args.eval_scenes[0]
282 | out_scene_dir = os.path.join(
283 | extra_out_dir, '{}_{:06d}'.format(scene_name, model.start_step), 'videos'
284 | )
285 | print('saving results to {}'.format(out_scene_dir))
286 |
287 | os.makedirs(out_scene_dir, exist_ok=True)
288 | os.makedirs(os.path.join(out_scene_dir, 'rgb_out'), exist_ok=True)
289 |
290 | save_prefix = scene_name
291 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
292 | total_num = len(test_loader)
293 | out_frames = []
294 | full_frames = []
295 | crop_ratio = 0.03
296 |
297 | for i, data in enumerate(test_loader):
298 | idx = int(data['id'].item())
299 | start = time.time()
300 | ref_time_embedding = data['ref_time'].cuda()
301 | ref_frame_idx = int(data['id'].item())
302 | ref_time_offset = [
303 | int(near_idx - ref_frame_idx)
304 | for near_idx in data['nearest_pose_ids'].squeeze().tolist()
305 | ]
306 |
307 | model.switch_to_eval()
308 | with torch.no_grad():
309 | ray_sampler = RaySamplerSingleImage(data, device='cuda:0')
310 | ray_batch = ray_sampler.get_all()
311 |
312 | cb_featmaps_1, cb_featmaps_2 = model.feature_net(
313 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2)
314 | )
315 | ref_featmaps = cb_featmaps_1 # [0:NUM_DYNAMIC_SRC_VIEWS]
316 |
317 | static_src_rgbs = (
318 | ray_batch['static_src_rgbs'].squeeze(0).permute(0, 3, 1, 2)
319 | )
320 | static_featmaps, _ = model.feature_net_st(static_src_rgbs)
321 |
322 | ret = render_single_image_mono(
323 | frame_idx=(ref_frame_idx, None),
324 | time_embedding=(ref_time_embedding, None),
325 | time_offset=(ref_time_offset, None),
326 | ray_sampler=ray_sampler,
327 | ray_batch=ray_batch,
328 | model=model,
329 | projector=projector,
330 | chunk_size=args.chunk_size,
331 | det=True,
332 | N_samples=args.N_samples,
333 | args=args,
334 | inv_uniform=args.inv_uniform,
335 | N_importance=args.N_importance,
336 | white_bkgd=args.white_bkgd,
337 | featmaps=(ref_featmaps, None, static_featmaps),
338 | is_train=False,
339 | num_vv=args.num_vv
340 | )
341 |
342 | coarse_pred_rgb = ret['outputs_coarse_ref']['rgb'].detach().cpu()
343 | coarse_pred_rgb_st = ret['outputs_coarse_ref']['rgb_static'].detach().cpu()
344 |     coarse_pred_rgb_dy = ret['outputs_coarse_ref']['rgb_dy'].detach().cpu()
345 |
346 | coarse_pred_rgb = (
347 | 255 * np.clip(coarse_pred_rgb.numpy(), a_min=0, a_max=1.0)
348 | ).astype(np.uint8)
349 |
350 | h, w = coarse_pred_rgb.shape[:2]
351 | crop_h = int(h * crop_ratio)
352 | crop_w = int(w * crop_ratio)
353 |
354 | coarse_pred_rgb = coarse_pred_rgb[crop_h:h-crop_h, crop_w:w-crop_w, ...]
355 |
356 | gt_rgb = data['rgb'][0, crop_h:h-crop_h, crop_w:w-crop_w, ...]
357 | gt_rgb = (255 * np.clip(gt_rgb.numpy(), a_min=0, a_max=1.)).astype(np.uint8)
358 |
359 | full_rgb = np.concatenate([gt_rgb, coarse_pred_rgb], axis=1)
360 |
361 | full_frames.append(coarse_pred_rgb)
362 |
363 | imageio.imwrite(os.path.join(out_scene_dir, 'rgb_out', '{}.png'.format(i)),
364 | coarse_pred_rgb)
365 |
366 | print('frame {} completed, {}'.format(i, time.time() - start))
367 |
--------------------------------------------------------------------------------
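Editorial note on the rendering script above: the loop writes each cropped prediction to `rgb_out/*.png` and also accumulates the frames in `full_frames`. As a minimal sketch (not shown in this excerpt, and assuming the `imageio-ffmpeg` backend is installed; the output filename and fps are hypothetical), the collected frames could be assembled into a video like so:

```python
import os
import imageio

# full_frames: list of HxWx3 uint8 arrays collected in the loop above.
imageio.mimwrite(
    os.path.join(out_scene_dir, '{}_render.mp4'.format(save_prefix)),
    full_frames,
    fps=30,
    quality=8,
)
```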
/ibrnet/render_image.py:
--------------------------------------------------------------------------------
1 | """Functions for rendering a target view."""
2 |
3 | from collections import OrderedDict
4 | from ibrnet.render_ray import render_rays_mono
5 | from ibrnet.render_ray import render_rays_mv
6 | import torch
7 |
8 |
9 | def render_single_image_nvi(
10 | frame_idx,
11 | time_embedding,
12 | time_offset,
13 | ray_sampler,
14 | ray_batch,
15 | model,
16 | projector,
17 | chunk_size,
18 | N_samples,
19 | args,
20 | inv_uniform=False,
21 | N_importance=0,
22 | det=False,
23 | white_bkgd=False,
24 | render_stride=1,
25 | coarse_featmaps=None,
26 | fine_featmaps=None,
27 | is_train=True,
28 | ):
29 |   """Render a target view for the Nvidia dataset.
30 |
31 | Args:
32 | frame_idx: video frame index
33 | time_embedding: input time embedding
34 | time_offset: offset w.r.t reference time
35 | ray_sampler: target view ray sampler
36 | ray_batch: batch of ray information
37 | model: dynibar model
38 | projector: perspective projection module
39 | chunk_size: processing chunk size
40 | N_samples: number of coarse samples along the ray
41 | args: additional input arguments
42 | inv_uniform: use disparity-based sampling or not
43 | N_importance: number of fine samples along the ray
44 | det: deterministic sampling
45 |     white_bkgd: whether to render with a white background
46 | render_stride: pixel stride when rendering images
47 | coarse_featmaps: coarse-stage 2D feature map
48 | fine_featmaps: fine-stage 2D feature map
49 | is_train: is training or not
50 |
51 | Returns:
52 | outputs_fine_anchor: rendered fine images at target view from contents at
53 | nearby time
54 | outputs_fine_ref: rendered fine images at target view from contents at
55 | target time
56 | outputs_coarse_ref: rendered coarse images at target view from contents at
57 | target time
58 | """
59 |
60 | all_ret = OrderedDict([
61 | ('outputs_fine_anchor', OrderedDict()),
62 | ('outputs_fine_ref', OrderedDict()),
63 | ('outputs_coarse_ref', OrderedDict()),
64 | ])
65 |
66 | N_rays = ray_batch['ray_o'].shape[0]
67 |
68 | for i in range(0, N_rays, chunk_size):
69 | chunk = OrderedDict()
70 | for k in ray_batch:
71 | if ray_batch[k] is None:
72 | chunk[k] = None
73 | elif k in [
74 | 'camera',
75 | 'depth_range',
76 | 'src_rgbs',
77 | 'src_cameras',
78 | 'anchor_src_rgbs',
79 | 'anchor_src_cameras',
80 | 'static_src_rgbs',
81 | 'static_src_cameras',
82 | ]:
83 | chunk[k] = ray_batch[k]
84 | elif len(ray_batch[k].shape) == 3: # flow and mask
85 | chunk[k] = ray_batch[k][:, i : i + chunk_size, ...]
86 | elif ray_batch[k] is not None:
87 | chunk[k] = ray_batch[k][i : i + chunk_size]
88 | else:
89 | chunk[k] = None
90 |
91 | ret = render_rays_mv(
92 | frame_idx=frame_idx,
93 | time_embedding=time_embedding,
94 | time_offset=time_offset,
95 | ray_batch=chunk,
96 | model=model,
97 | coarse_featmaps=coarse_featmaps,
98 | fine_featmaps=fine_featmaps,
99 | projector=projector,
100 | N_samples=N_samples,
101 | args=args,
102 | inv_uniform=inv_uniform,
103 | N_importance=N_importance,
104 | raw_noise_std=0.0,
105 | det=det,
106 | white_bkgd=white_bkgd,
107 | is_train=is_train,
108 | )
109 |
110 | # handle both coarse and fine outputs
111 | # cache chunk results on cpu
112 | if i == 0:
113 | for k in ret['outputs_coarse_ref']:
114 | all_ret['outputs_coarse_ref'][k] = []
115 |
116 | for k in ret['outputs_fine_ref']:
117 | all_ret['outputs_fine_ref'][k] = []
118 |
119 | if is_train:
120 | for k in ret['outputs_fine_anchor']:
121 | all_ret['outputs_fine_anchor'][k] = []
122 |
123 | for k in ret['outputs_coarse_ref']:
124 | all_ret['outputs_coarse_ref'][k].append(
125 | ret['outputs_coarse_ref'][k].cpu()
126 | )
127 |
128 | for k in ret['outputs_fine_ref']:
129 | all_ret['outputs_fine_ref'][k].append(ret['outputs_fine_ref'][k].cpu())
130 |
131 | if is_train:
132 | for k in ret['outputs_fine_anchor']:
133 | all_ret['outputs_fine_anchor'][k].append(
134 | ret['outputs_fine_anchor'][k].cpu()
135 | )
136 |
137 | rgb_strided = torch.ones(ray_sampler.H, ray_sampler.W, 3)[
138 | ::render_stride, ::render_stride, :
139 | ]
140 | # merge chunk results and reshape
141 | for k in all_ret['outputs_coarse_ref']:
142 | if k == 'random_sigma':
143 | continue
144 |
145 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 4:
146 | continue
147 |
148 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 3:
149 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=1).reshape((
150 | all_ret['outputs_coarse_ref'][k][0].shape[0],
151 | rgb_strided.shape[0],
152 | rgb_strided.shape[1],
153 | -1,
154 | ))
155 | else:
156 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=0).reshape(
157 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
158 | )
159 | all_ret['outputs_coarse_ref'][k] = tmp.squeeze()
160 |
161 | all_ret['outputs_coarse_ref']['rgb'][
162 | all_ret['outputs_coarse_ref']['mask'] == 0
163 | ] = 0.0
164 |
165 | # merge chunk results and reshape
166 | for k in all_ret['outputs_fine_ref']:
167 | if k == 'random_sigma':
168 | continue
169 |
170 | if len(all_ret['outputs_fine_ref'][k][0].shape) == 4:
171 | continue
172 |
173 | if len(all_ret['outputs_fine_ref'][k][0].shape) == 3:
174 | tmp = torch.cat(all_ret['outputs_fine_ref'][k], dim=1).reshape((
175 | all_ret['outputs_fine_ref'][k][0].shape[0],
176 | rgb_strided.shape[0],
177 | rgb_strided.shape[1],
178 | -1,
179 | ))
180 | else:
181 | tmp = torch.cat(all_ret['outputs_fine_ref'][k], dim=0).reshape(
182 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
183 | )
184 | all_ret['outputs_fine_ref'][k] = tmp.squeeze()
185 |
186 | all_ret['outputs_fine_ref']['rgb'][
187 | all_ret['outputs_fine_ref']['mask'] == 0
188 | ] = 0.0
189 |
190 | # merge chunk results and reshape
191 | if is_train:
192 | for k in all_ret['outputs_fine_anchor']:
193 | if k == 'random_sigma':
194 | continue
195 |
196 | if len(all_ret['outputs_fine_anchor'][k][0].shape) == 4:
197 | continue
198 |
199 | if len(all_ret['outputs_fine_anchor'][k][0].shape) == 3:
200 | tmp = torch.cat(all_ret['outputs_fine_anchor'][k], dim=1).reshape((
201 | all_ret['outputs_fine_anchor'][k][0].shape[0],
202 | rgb_strided.shape[0],
203 | rgb_strided.shape[1],
204 | -1,
205 | ))
206 | else:
207 | tmp = torch.cat(all_ret['outputs_fine_anchor'][k], dim=0).reshape(
208 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
209 | )
210 | all_ret['outputs_fine_anchor'][k] = tmp.squeeze()
211 |
212 | all_ret['outputs_fine_anchor']['rgb'][
213 | all_ret['outputs_fine_anchor']['mask'] == 0
214 | ] = 0.0
215 |
216 | all_ret['outputs_fine'] = None
217 | return all_ret
218 |
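The two rendering functions in this file share the same chunking pattern: rays are processed `chunk_size` at a time, each chunk's outputs are moved to the CPU and collected in lists, and those lists are finally concatenated and reshaped back to image layout. A toy sketch of just the merge step (illustrative only; `merge_chunks` is a hypothetical helper, not part of the repo):

```python
import torch

def merge_chunks(chunks, H, W):
  # chunks: list of [rays_in_chunk, C] tensors collected chunk by chunk.
  return torch.cat(chunks, dim=0).reshape(H, W, -1)

H, W, C, chunk_size = 4, 6, 3, 5
rays = torch.arange(H * W * C, dtype=torch.float32).reshape(H * W, C)
chunks = [rays[i:i + chunk_size].cpu() for i in range(0, H * W, chunk_size)]
assert torch.equal(merge_chunks(chunks, H, W), rays.reshape(H, W, C))
```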
219 |
220 | def render_single_image_mono(
221 | frame_idx,
222 | time_embedding,
223 | time_offset,
224 | ray_sampler,
225 | ray_batch,
226 | model,
227 | projector,
228 | chunk_size,
229 | N_samples,
230 | args,
231 | inv_uniform=False,
232 | N_importance=0,
233 | det=False,
234 | white_bkgd=False,
235 | render_stride=1,
236 | featmaps=None,
237 | is_train=True,
238 | num_vv=2,
239 | ):
240 |   """Render a target view for a monocular video.
241 |
242 | Args:
243 | frame_idx: video frame index
244 | time_embedding: input time embedding
245 | time_offset: offset w.r.t reference time
246 | ray_sampler: target view ray sampler
247 | ray_batch: batch of ray information
248 | model: dynibar model
249 | projector: perspective projection module
250 | chunk_size: processing chunk size
251 | N_samples: number of coarse samples along the ray
252 | args: additional input arguments
253 | inv_uniform: use disparity-based sampling or not
254 | N_importance: number of fine samples along the ray
255 | det: deterministic sampling
256 |     white_bkgd: whether to render with a white background
257 | render_stride: pixel stride when rendering images
258 | featmaps: coarse-stage 2D feature map
259 | is_train: is training or not
260 | num_vv: number of virtual source views used
261 |
262 | Returns:
263 | outputs_coarse_ref: rendered images at target view from combined contents at
264 | target time, coarse model
265 | outputs_coarse_st: rendered images at target view from static
266 | contents at target time, coarse model
267 | outputs_coarse_anchor: cross-rendered images at target view
268 | from combined contents at nearby time, coarse model
269 |
270 | """
271 |
272 | all_ret = OrderedDict([
273 | ('outputs_coarse_ref', OrderedDict()),
274 | ('outputs_coarse_st', OrderedDict()),
275 | ('outputs_coarse_anchor', OrderedDict()),
276 | ])
277 |
278 | N_rays = ray_batch['ray_o'].shape[0]
279 |
280 | for i in range(0, N_rays, chunk_size):
281 | chunk = OrderedDict()
282 | for k in ray_batch:
283 | if ray_batch[k] is None:
284 | chunk[k] = None
285 | elif k in [
286 | 'camera',
287 | 'anchor_camera',
288 | 'depth_range',
289 | 'src_rgbs',
290 | 'src_cameras',
291 | 'anchor_src_rgbs',
292 | 'anchor_src_cameras',
293 | 'static_src_rgbs',
294 | 'static_src_cameras',
295 | ]:
296 | chunk[k] = ray_batch[k]
297 | elif len(ray_batch[k].shape) == 3: # flow and mask
298 | chunk[k] = ray_batch[k][:, i : i + chunk_size, ...]
299 | elif ray_batch[k] is not None:
300 | chunk[k] = ray_batch[k][i : i + chunk_size]
301 | else:
302 | chunk[k] = None
303 |
304 | ret = render_rays_mono(
305 | frame_idx=frame_idx,
306 | time_embedding=time_embedding,
307 | time_offset=time_offset,
308 | ray_batch=chunk,
309 | model=model,
310 | featmaps=featmaps,
311 | projector=projector,
312 | N_samples=N_samples,
313 | args=args,
314 | inv_uniform=inv_uniform,
315 | N_importance=N_importance,
316 | raw_noise_std=0.0,
317 | det=det,
318 | white_bkgd=white_bkgd,
319 | is_train=is_train,
320 | num_vv=num_vv,
321 | )
322 |
323 | # handle both coarse and fine outputs
324 | # cache chunk results on cpu
325 | if i == 0:
326 | for k in ret['outputs_coarse_ref']:
327 | all_ret['outputs_coarse_ref'][k] = []
328 |
329 | for k in ret['outputs_coarse_st']:
330 | all_ret['outputs_coarse_st'][k] = []
331 |
332 | if is_train:
333 | for k in ret['outputs_coarse_anchor']:
334 | all_ret['outputs_coarse_anchor'][k] = []
335 |
336 | if ret['outputs_fine'] is None:
337 | all_ret['outputs_fine'] = None
338 |       else:
339 |         # create the dict first; 'outputs_fine' is not pre-created above
340 |         all_ret['outputs_fine'] = {k: [] for k in ret['outputs_fine']}
341 |
342 | for k in ret['outputs_coarse_ref']:
343 | all_ret['outputs_coarse_ref'][k].append(
344 | ret['outputs_coarse_ref'][k].cpu()
345 | )
346 |
347 | for k in ret['outputs_coarse_st']:
348 | all_ret['outputs_coarse_st'][k].append(ret['outputs_coarse_st'][k].cpu())
349 |
350 | if is_train:
351 | for k in ret['outputs_coarse_anchor']:
352 | all_ret['outputs_coarse_anchor'][k].append(
353 | ret['outputs_coarse_anchor'][k].cpu()
354 | )
355 |
356 | if ret['outputs_fine'] is not None:
357 | for k in ret['outputs_fine']:
358 | all_ret['outputs_fine'][k].append(ret['outputs_fine'][k].cpu())
359 |
360 | rgb_strided = torch.ones(ray_sampler.H, ray_sampler.W, 3)[
361 | ::render_stride, ::render_stride, :
362 | ]
363 | # merge chunk results and reshape
364 | for k in all_ret['outputs_coarse_ref']:
365 | if k == 'random_sigma':
366 | continue
367 |
368 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 4:
369 | continue
370 |
371 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 3:
372 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=1).reshape((
373 | all_ret['outputs_coarse_ref'][k][0].shape[0],
374 | rgb_strided.shape[0],
375 | rgb_strided.shape[1],
376 | -1,
377 | ))
378 | else:
379 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=0).reshape(
380 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
381 | )
382 | all_ret['outputs_coarse_ref'][k] = tmp.squeeze()
383 |
384 | all_ret['outputs_coarse_ref']['rgb'][
385 | all_ret['outputs_coarse_ref']['mask'] == 0
386 | ] = 0.0
387 |
388 | # merge chunk results and reshape
389 | for k in all_ret['outputs_coarse_st']:
390 | if k == 'random_sigma':
391 | continue
392 |
393 | if len(all_ret['outputs_coarse_st'][k][0].shape) == 4:
394 | continue
395 |
396 | if len(all_ret['outputs_coarse_st'][k][0].shape) == 3:
397 | tmp = torch.cat(all_ret['outputs_coarse_st'][k], dim=1).reshape((
398 | all_ret['outputs_coarse_st'][k][0].shape[0],
399 | rgb_strided.shape[0],
400 | rgb_strided.shape[1],
401 | -1,
402 | ))
403 | else:
404 | tmp = torch.cat(all_ret['outputs_coarse_st'][k], dim=0).reshape(
405 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
406 | )
407 | all_ret['outputs_coarse_st'][k] = tmp.squeeze()
408 |
409 | all_ret['outputs_coarse_st']['rgb'][
410 | all_ret['outputs_coarse_st']['mask'] == 0
411 | ] = 0.0
412 |
413 | # merge chunk results and reshape
414 | if is_train:
415 | for k in all_ret['outputs_coarse_anchor']:
416 | if k == 'random_sigma':
417 | continue
418 |
419 | if len(all_ret['outputs_coarse_anchor'][k][0].shape) == 4:
420 | continue
421 |
422 | if len(all_ret['outputs_coarse_anchor'][k][0].shape) == 3:
423 | tmp = torch.cat(all_ret['outputs_coarse_anchor'][k], dim=1).reshape((
424 | all_ret['outputs_coarse_anchor'][k][0].shape[0],
425 | rgb_strided.shape[0],
426 | rgb_strided.shape[1],
427 | -1,
428 | ))
429 | else:
430 | tmp = torch.cat(all_ret['outputs_coarse_anchor'][k], dim=0).reshape(
431 | (rgb_strided.shape[0], rgb_strided.shape[1], -1)
432 | )
433 | all_ret['outputs_coarse_anchor'][k] = tmp.squeeze()
434 |
435 | all_ret['outputs_coarse_anchor']['rgb'][
436 | all_ret['outputs_coarse_anchor']['mask'] == 0
437 | ] = 0.0
438 |
439 | return all_ret
440 |
--------------------------------------------------------------------------------
/ibrnet/data_loaders/llff_data_utils.py:
--------------------------------------------------------------------------------
1 | """Forward-Facing data loading code.
2 |
3 | Modified from IBRNet:
4 | github.com/googleinterns/IBRNet/blob/master/ibrnet/data_loaders/llff_data_utils.py
5 | """
6 |
7 | import os
8 |
9 | import cv2
10 | import imageio
11 | import numpy as np
12 |
13 |
14 | def parse_llff_pose(pose):
15 |   """Convert an LLFF-format pose to 4x4 intrinsics and extrinsics matrices."""
16 |
17 | h, w, f = pose[:3, -1]
18 | c2w = pose[:3, :4]
19 | c2w_4x4 = np.eye(4)
20 | c2w_4x4[:3] = c2w
21 | c2w_4x4[:, 1:3] *= -1
22 | intrinsics = np.array(
23 | [[f, 0, w / 2.0, 0], [0, f, h / 2.0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
24 | )
25 | return intrinsics, c2w_4x4
26 |
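For reference (an editorial aside, not repo text): the 3x5 LLFF pose parsed above packs the camera-to-world rotation and translation together with the image height, width, and focal length, and the intrinsics are assembled from the latter:

$$
P_\text{llff} \;=\; \big[\, R \;\big|\; t \;\big|\; (h, w, f)^\top \,\big] \in \mathbb{R}^{3\times 5},
\qquad
K \;=\; \begin{bmatrix} f & 0 & w/2 & 0\\ 0 & f & h/2 & 0\\ 0 & 0 & 1 & 0\\ 0 & 0 & 0 & 1 \end{bmatrix}.
$$

Negating the second and third columns of the returned 4x4 pose converts from the LLFF/NeRF camera-axis convention (right, up, backwards) to the OpenCV/COLMAP convention (right, down, forwards), as the docstring of `batch_parse_llff_poses` below indicates.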
27 |
28 | def batch_parse_llff_poses(poses):
29 | """Parse LLFF data format to opencv/colmap format."""
30 | all_intrinsics = []
31 | all_c2w_mats = []
32 | for pose in poses:
33 | intrinsics, c2w_mat = parse_llff_pose(pose)
34 | all_intrinsics.append(intrinsics)
35 | all_c2w_mats.append(c2w_mat)
36 | all_intrinsics = np.stack(all_intrinsics)
37 | all_c2w_mats = np.stack(all_c2w_mats)
38 | return all_intrinsics, all_c2w_mats
39 |
40 |
41 | def batch_parse_vv_poses(poses):
42 |   """Parse virtual view poses used for monocular video training."""
43 | all_c2w_mats = []
44 | for pose in poses:
45 | t_c2w_mats = []
46 | for p in pose:
47 | intrinsics, c2w_mat = parse_llff_pose(p)
48 | t_c2w_mats.append(c2w_mat)
49 | t_c2w_mats = np.stack(t_c2w_mats)
50 | all_c2w_mats.append(t_c2w_mats)
51 |
52 | all_c2w_mats = np.stack(all_c2w_mats)
53 |
54 | return all_c2w_mats
55 |
56 |
57 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
58 | """Function for loading LLFF data."""
59 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds_cvd.npy'))
60 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
61 | bds = poses_arr[:, -2:].transpose([1, 0])
62 |
63 | img0 = [
64 | os.path.join(basedir, 'images', f)
65 | for f in sorted(os.listdir(os.path.join(basedir, 'images')))
66 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')
67 | ][0]
68 | sh = imageio.imread(img0).shape
69 |
70 | sfx = ''
71 |
72 | if factor is not None and factor != 1:
73 | sfx = '_{}'.format(factor)
74 | elif height is not None:
75 | factor = sh[0] / float(height)
76 | width = int(round(sh[1] / factor))
77 | sfx = '_{}x{}'.format(width, height)
78 | elif width is not None:
79 | factor = sh[1] / float(width)
80 | height = int(round(sh[0] / factor))
81 | sfx = '_{}x{}'.format(width, height)
82 | else:
83 | factor = 1
84 |
85 | imgdir = os.path.join(basedir, 'images' + sfx)
86 | print('imgdir ', imgdir, ' factor ', factor)
87 |
88 | if not os.path.exists(imgdir):
89 | print(imgdir, 'does not exist, returning')
90 | return
91 |
92 | imgfiles = [
93 | os.path.join(imgdir, f)
94 | for f in sorted(os.listdir(imgdir))
95 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')
96 | ]
97 |
98 | if poses.shape[-1] != len(imgfiles):
99 | print(
100 | '{}: Mismatch between imgs {} and poses {} !!!!'.format(
101 | basedir, len(imgfiles), poses.shape[-1]
102 | )
103 | )
104 | raise NotImplementedError
105 |
106 | sh = imageio.imread(imgfiles[0]).shape
107 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
108 | poses[2, 4, :] = poses[2, 4, :] # * 1. / factor
109 |
110 | def imread(f):
111 | if f.endswith('png'):
112 | return imageio.imread(f, ignoregamma=True)
113 | else:
114 | return imageio.imread(f)
115 |
116 | if not load_imgs:
117 | imgs = None
118 | else:
119 | imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
120 | imgs = np.stack(imgs, -1)
121 | print('Loaded image data', imgs.shape, poses[:, -1, 0])
122 |
123 | return poses, bds, imgs, imgfiles
124 |
125 |
126 | def normalize(x):
127 | return x / np.linalg.norm(x)
128 |
129 |
130 | def viewmatrix(z, up, pos):
131 | vec2 = normalize(z)
132 | vec1_avg = up
133 | vec0 = normalize(np.cross(vec1_avg, vec2))
134 | vec1 = normalize(np.cross(vec2, vec0))
135 | m = np.stack([vec0, vec1, vec2, pos], 1)
136 | return m
137 |
138 |
139 | def ptstocam(pts, c2w):
140 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
141 | return tt
142 |
143 |
144 | def poses_avg(poses):
145 | hwf = poses[0, :3, -1:]
146 |
147 | center = poses[:, :3, 3].mean(0)
148 | vec2 = normalize(poses[:, :3, 2].sum(0))
149 | up = poses[:, :3, 1].sum(0)
150 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
151 |
152 | return c2w
153 |
154 |
155 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
156 | """Render a spiral path."""
157 |
158 | render_poses = []
159 | rads = np.array(list(rads) + [1.0])
160 | hwf = c2w[:, 4:5]
161 |
162 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
163 | c = np.dot(
164 | c2w[:3, :4],
165 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
166 | * rads,
167 | )
168 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
169 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
170 | return render_poses
171 |
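In equation form (editorial note), each spiral camera center is generated in the average-camera frame, with the view direction pointing toward a fixed focus point a distance `focal` in front of that camera; writing rads = (r_x, r_y, r_z, 1):

$$
c(\theta) \;=\; \mathrm{c2w}_{3\times 4}\,\big(\mathrm{rads} \odot [\cos\theta,\; -\sin\theta,\; -\sin(\theta\, z_\text{rate}),\; 1]^\top\big),
\qquad
z(\theta) \;=\; \operatorname{normalize}\!\big(c(\theta) - \mathrm{c2w}_{3\times 4}\,[0,\, 0,\, -\text{focal},\, 1]^\top\big),
$$

with theta sweeping `rots` full turns over the N generated frames.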
172 |
173 | def recenter_poses(poses):
174 |   """Recenter camera poses around their centroid."""
175 | poses_ = poses + 0
176 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
177 | c2w = poses_avg(poses)
178 | c2w = np.concatenate([c2w[:3, :4], bottom], -2)
179 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
180 | poses = np.concatenate([poses[:, :3, :4], bottom], -2)
181 |
182 | poses = np.linalg.inv(c2w) @ poses
183 | poses_[:, :3, :4] = poses[:, :3, :4]
184 | poses = poses_
185 | return poses
186 |
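Equivalently (an editorial aside), recentering left-multiplies every camera-to-world matrix by the inverse of the 4x4 average pose, so the averaged camera becomes the world origin:

$$
\tilde{P}_i \;=\; \bar{P}^{-1} P_i, \qquad i = 1,\dots,N,
$$

where $\bar{P}$ is the average pose returned by `poses_avg`, extended with a homogeneous bottom row.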
187 |
188 | def recenter_poses_mono(poses, src_vv_poses):
189 |   """Recenter virtual view camera poses around their centroid."""
190 | hwf = poses[:, :, 4:5]
191 | poses_ = poses + 0
192 | bottom = np.reshape([0, 0, 0, 1.], [1, 4])
193 | c2w = poses_avg(poses)
194 | c2w = np.concatenate([c2w[:3, :4], bottom], -2)
195 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
196 | poses = np.concatenate([poses[:, :3, :4], bottom], -2)
197 |
198 | poses = np.linalg.inv(c2w) @ poses
199 | poses_[:, :3, :4] = poses[:, :3, :4]
200 | poses = poses_
201 |
202 | src_output_poses = np.zeros((
203 | src_vv_poses.shape[1],
204 | src_vv_poses.shape[0],
205 | src_vv_poses.shape[2],
206 | src_vv_poses.shape[3] + 1,
207 | ))
208 | for i in range(src_vv_poses.shape[1]):
209 | src_vv_poses_ = np.concatenate([src_vv_poses[:, i, :3, :4], bottom], -2)
210 | src_vv_poses_ = np.linalg.inv(c2w) @ src_vv_poses_
211 | src_output_poses[i, ...] = np.concatenate([src_vv_poses_[:, :3, :], hwf], 2)
212 |
213 | return poses, np.moveaxis(src_output_poses, 1, 0)
214 |
215 |
216 | def load_llff_data(
217 | basedir,
218 | height,
219 | num_avg_imgs,
220 | factor=8,
221 | render_idx=8,
222 | recenter=True,
223 | bd_factor=0.75,
224 | spherify=False,
225 | load_imgs=True,
226 | ):
227 | """Load LLFF forward-facing data.
228 |
229 | Args:
230 | basedir: base directory
231 | height: training image height
232 | factor: resize factor
233 | render_idx: rendering frame index from the video
234 |     recenter: whether to recenter camera poses
235 | bd_factor: scale factor for bounds
236 | spherify: spherify the camera poses
237 | load_imgs: load images from the disk
238 |
239 | Returns:
240 | images: video frames
241 | poses: corresponding camera parameters
242 | bds: bounds
243 | render_poses: rendering camera poses
244 | i_test: test index
245 |     imgfiles: list of image paths
246 | scale: scene scale
247 | """
248 | out = _load_data(
249 | basedir, factor=None, load_imgs=load_imgs, height=height
250 | )
251 |
252 | if out is None:
253 | return
254 | else:
255 | poses, bds, imgs, imgfiles = out
256 |
257 | # Correct rotation matrix ordering and move variable dim to axis 0
258 | poses = np.concatenate(
259 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1
260 | )
261 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
262 | if imgs is not None:
263 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
264 | images = imgs
265 | images = images.astype(np.float32)
266 | else:
267 | images = None
268 |
269 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
270 |
271 | # Rescale if bd_factor is provided
272 | scale = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
273 |
274 | poses[:, :3, 3] *= scale
275 | bds *= scale
276 |
277 | if recenter:
278 | poses = recenter_poses(poses)
279 |
280 | spiral = True
281 | if spiral:
282 | print('================= render_path_spiral ==========================')
283 | c2w = poses_avg(poses[0:num_avg_imgs])
284 | ## Get spiral
285 | # Get average pose
286 | up = normalize(poses[:, :3, 1].sum(0))
287 |
288 | # Find a reasonable "focus depth" for this dataset
289 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 2.0
290 | dt = 0.75
291 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
292 | focal = mean_dz * 1.5
293 |
294 | # Get radii for spiral path
295 | # shrink_factor = 0.8
296 | zdelta = close_depth * 0.2
297 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
298 | rads = np.percentile(np.abs(tt), 80, 0)
299 | c2w_path = c2w
300 | n_views = 120
301 | n_rots = 2
302 |
303 | # Generate poses for spiral path
304 | render_poses = render_path_spiral(
305 | c2w_path, up, rads, focal, zdelta, zrate=0.5, rots=n_rots, N=n_views
306 | )
307 | else:
308 | raise NotImplementedError
309 |
310 | render_poses = np.array(render_poses).astype(np.float32)
311 |
312 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1)
313 | i_test = np.argmin(dists)
314 | poses = poses.astype(np.float32)
315 |
316 | print('bds ', bds.min(), bds.max())
317 |
318 | return images, poses, bds, render_poses, i_test, imgfiles, scale
319 |
320 |
321 | def load_mono_data(
322 | basedir,
323 | height=288,
324 | factor=8,
325 | render_idx=-1,
326 | recenter=True,
327 | bd_factor=0.75,
328 | spherify=False,
329 | load_imgs=True,
330 | ):
331 | """Load monocular video data.
332 |
333 | Args:
334 | basedir: base directory
335 | height: training image height
336 | factor: resize factor
337 | render_idx: rendering frame index from the video
338 |     recenter: whether to recenter camera poses
339 | bd_factor: scale factor for bounds
340 | spherify: spherify the camera poses
341 | load_imgs: load images from the disk
342 |
343 | Returns:
344 | images: video frames
345 | poses: corresponding camera parameters
346 | src_vv_poses: virtual view camera poses
347 | bds: bounds
348 | render_poses: rendering camera poses
349 | i_test: test index
350 |     imgfiles: list of image paths
351 | scale: scene scale
352 | """
353 | out = _load_data(basedir, factor=None, load_imgs=load_imgs, height=height)
354 |
355 | src_vv_poses = np.load(os.path.join(basedir, 'source_vv_poses.npy'))
356 |
357 | if out is None:
358 | return
359 | else:
360 | poses, bds, imgs, imgfiles = out
361 |
362 | # Correct rotation matrix ordering and move variable dim to axis 0
363 | poses = np.concatenate(
364 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1
365 | )
366 | src_vv_poses = np.concatenate(
367 | [
368 | src_vv_poses[:, :, 1:2, :],
369 | -src_vv_poses[:, :, 0:1, :],
370 | src_vv_poses[:, :, 2:, :],
371 | ],
372 | 2,
373 | )
374 |
375 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
376 | src_vv_poses = np.moveaxis(src_vv_poses, -1, 0).astype(np.float32)
377 |
378 | if imgs is not None:
379 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
380 | images = imgs
381 | images = images.astype(np.float32)
382 | else:
383 | images = None
384 |
385 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
386 |
387 | # Rescale if bd_factor is provided
388 | scale = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
389 |
390 | poses[:, :3, 3] *= scale
391 | src_vv_poses[..., :3, 3] *= scale
392 |
393 | bds *= scale
394 |
395 | if recenter:
396 | poses, src_vv_poses = recenter_poses_mono(poses, src_vv_poses)
397 |
398 | if render_idx >= 0:
399 | render_poses = render_wander_path(poses[render_idx])
400 | else:
401 | render_poses = render_stabilization_path(poses, k_size=45)
402 |
403 | render_poses = np.array(render_poses).astype(np.float32)
404 |
405 | i_test = []
406 | poses = poses.astype(np.float32)
407 |
408 | print('bds ', bds.min(), bds.max())
409 |
410 | return images, poses, src_vv_poses, bds, render_poses, i_test, imgfiles, scale
411 |
412 |
413 | def render_wander_path(c2w):
414 |   """Render a circular wander path."""
415 | hwf = c2w[:, 4:5]
416 | num_frames = 50
417 | max_disp = 48.0
418 |
419 | max_trans = max_disp / hwf[2][0]
420 | output_poses = []
421 |
422 | for i in range(num_frames):
423 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
424 |     y_trans = 0.  # max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.
425 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.
426 |
427 | i_pose = np.concatenate(
428 | [
429 | np.concatenate(
430 | [
431 | np.eye(3),
432 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
433 | ],
434 | axis=1,
435 | ),
436 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
437 | ],
438 | axis=0,
439 | )
440 |
441 | i_pose = np.linalg.inv(i_pose)
442 |
443 | ref_pose = np.concatenate(
444 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
445 | )
446 |
447 | render_pose = np.dot(ref_pose, i_pose)
448 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
449 |
450 | return output_poses
451 |
452 |
453 | def render_stabilization_path(poses, k_size):
454 |   """Render a stabilized camera path."""
455 |
456 | hwf = poses[0, :, 4:5]
457 | num_frames = poses.shape[0]
458 | output_poses = []
459 |
460 | input_poses = []
461 |
462 | for i in range(num_frames):
463 | input_poses.append(
464 | np.concatenate(
465 | [poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], axis=-1
466 | )
467 | )
468 |
469 | input_poses = np.array(input_poses)
470 |
471 | gaussian_kernel = cv2.getGaussianKernel(
472 | ksize=k_size, sigma=-1
473 | )
474 | output_r1 = cv2.filter2D(input_poses[:, :, 0], -1, gaussian_kernel)
475 | output_r2 = cv2.filter2D(input_poses[:, :, 1], -1, gaussian_kernel)
476 |
477 | output_r1 = output_r1 / np.linalg.norm(output_r1, axis=-1, keepdims=True)
478 | output_r2 = output_r2 / np.linalg.norm(output_r2, axis=-1, keepdims=True)
479 |
480 | output_t = cv2.filter2D(input_poses[:, :, 2], -1, gaussian_kernel)
481 |
482 | for i in range(num_frames):
483 | output_r3 = np.cross(output_r1[i], output_r2[i])
484 |
485 | render_pose = np.concatenate(
486 | [
487 | output_r1[i, :, None],
488 | output_r2[i, :, None],
489 | output_r3[:, None],
490 | output_t[i, :, None],
491 | ],
492 | axis=-1,
493 | )
494 |
495 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
496 |
497 | return output_poses
498 |
--------------------------------------------------------------------------------
/ibrnet/data_loaders/monocular.py:
--------------------------------------------------------------------------------
1 | """Dataloader class for training monocular videos."""
2 |
3 |
4 | import os
5 | import cv2
6 | from ibrnet.data_loaders.data_utils import get_nearest_pose_ids
7 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses
8 | from ibrnet.data_loaders.llff_data_utils import batch_parse_vv_poses
9 | from ibrnet.data_loaders.llff_data_utils import load_mono_data
10 | import imageio
11 | import numpy as np
12 | import skimage.morphology
13 | import torch
14 | from torch.utils.data import Dataset
15 |
16 |
17 | class MonocularDataset(Dataset):
18 |   """This class loads data from a monocular video.
19 |
20 | Each returned item in the dataset has
21 | id: reference frame index
22 | anchor_id: nearby frame index for cross time rendering
23 | num_frames: number of video frames
24 | ref_time: normalized reference time index
25 | anchor_time: normalized nearby cross-time index
26 | nearest_pose_ids: source view index w.r.t reference time
27 | anchor_nearest_pose_ids: source view index w.r.t nearby time
28 | rgb: [H, W, 3], image at reference time
29 | disp: [H, W], disparity at reference time
30 | motion_mask: [H, W], dynamic mask at reference time
31 | static_mask: [H, W], static mask at reference time
32 | flows: [6, H, W, 2] optical flows from reference time
33 | masks: [6, H, W] optical flow valid masks from reference time
34 | camera: [34] camera parameters at reference time
35 | anchor_camera: [34] camera parameters at nearby cross-time
36 | rgb_path: RGB file path name
37 | src_rgbs: [..., H, W, 3] source views RGB images for dynamic model
38 | src_cameras: [..., 34] source view camera parameters for dynamic model
39 |   static_src_rgbs: [..., H, W, 3] source view images for static model
40 | static_src_cameras: [..., 34] source view camera parameters for static model
41 | anchor_src_rgbs: [..., H, W, 3] cross-time view images for dynamic model
42 | anchor_src_cameras: [..., 34] cross-time source view camera parameters for
43 | dynamic model
44 | depth_range: [2] scene near and far bounds
45 | """
46 |
47 | def __init__(self, args, mode, scenes=(), random_crop=True, **kwargs):
48 | assert len(scenes) == 1
49 | self.folder_path = args.folder_path
50 | self.num_vv = args.num_vv
51 | self.args = args
52 | self.mask_src_view = args.mask_src_view
53 | self.num_frames_sample = args.num_source_views
54 | self.erosion_radius = args.erosion_radius
55 | self.random_crop = random_crop
56 |
57 | self.max_range = args.max_range
58 |
59 | scene = scenes[0]
60 | self.scene_path = os.path.join(self.folder_path, scene, 'dense')
61 | _, poses, src_vv_poses, bds, _, _, rgb_files, scale = (
62 | load_mono_data(
63 | self.scene_path, height=args.training_height, load_imgs=False
64 | )
65 | )
66 | near_depth = np.min(bds)
67 |
68 |     # make sure the far bound is at least 15
69 |     # so that the static model is able to model view-dependent effects.
70 | if np.max(bds) < 10:
71 | far_depth = min(20, np.max(bds) + 15.0)
72 | else:
73 | far_depth = min(50, max(20, np.max(bds)))
74 |
75 | print('============= FINAL NEAR FAR', near_depth, far_depth)
76 |
77 | intrinsics, c2w_mats = batch_parse_llff_poses(poses)
78 | self.src_vv_c2w_mats = batch_parse_vv_poses(src_vv_poses)
79 | self.num_frames = len(rgb_files)
80 | assert self.num_frames == poses.shape[0]
81 | i_train = np.arange(self.num_frames)
82 | i_render = i_train
83 | self.scale = scale
84 |
85 | num_render = len(i_render)
86 | self.train_rgb_files = rgb_files
87 | self.train_intrinsics = intrinsics
88 | self.train_poses = c2w_mats
89 | self.train_depth_range = [[near_depth, far_depth]] * num_render
90 |
91 | def read_optical_flow(self, basedir, img_i, start_frame, fwd, interval):
92 | flow_dir = os.path.join(basedir, 'flow_i%d' % interval)
93 |
94 | if fwd:
95 | fwd_flow_path = os.path.join(
96 | flow_dir, '%05d_fwd.npz' % (start_frame + img_i)
97 | )
98 | fwd_data = np.load(fwd_flow_path) # , (w, h))
99 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
100 | fwd_mask = np.float32(fwd_mask)
101 |
102 | return fwd_flow, fwd_mask
103 | else:
104 | bwd_flow_path = os.path.join(
105 | flow_dir, '%05d_bwd.npz' % (start_frame + img_i)
106 | )
107 |
108 | bwd_data = np.load(bwd_flow_path) # , (w, h))
109 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
110 | bwd_mask = np.float32(bwd_mask)
111 |
112 | return bwd_flow, bwd_mask
113 |
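For clarity (an editorial note): the loader above expects per-interval flow directories of the form `flow_i<interval>/%05d_fwd.npz` and `%05d_bwd.npz`, each storing a `flow` array and a validity `mask`. A tiny sketch of writing a compatible file (shapes and filename are hypothetical):

```python
import numpy as np

# e.g. <scene>/dense/flow_i2/00012_fwd.npz
np.savez(
    '00012_fwd.npz',
    flow=np.zeros((288, 512, 2), dtype=np.float32),  # forward flow, [H, W, 2]
    mask=np.ones((288, 512), dtype=np.float32),      # valid-pixel mask, [H, W]
)
```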
114 | def __len__(self):
115 | return self.num_frames
116 |
117 | def set_epoch(self, epoch):
118 | self.current_epoch = epoch
119 |
120 | def load_src_view(
121 | self, rgb_file, pose, intrinsics, st_mask_path=None
122 | ):
123 | """Load RGB and camera data from each source views id."""
124 |
125 | src_rgb = imageio.imread(rgb_file).astype(np.float32) / 255.0
126 | img_size = src_rgb.shape[:2]
127 | src_camera = np.concatenate(
128 | (list(img_size), intrinsics.flatten(), pose.flatten())
129 | ).astype(np.float32)
130 |
131 | if st_mask_path:
132 | st_mask = imageio.imread(st_mask_path).astype(np.float32) / 255.0
133 | st_mask = cv2.resize(
134 | st_mask,
135 | (src_rgb.shape[1], src_rgb.shape[0]),
136 | interpolation=cv2.INTER_NEAREST,
137 | )
138 |
139 | if len(st_mask.shape) == 2:
140 | st_mask = st_mask[..., None]
141 |
142 | src_rgb = src_rgb * st_mask
143 |
144 | return src_rgb, src_camera
145 |
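The 34-dimensional camera vectors built here (the `camera: [34]` entries in the class docstring) concatenate the image size with the flattened 4x4 intrinsics and camera-to-world matrices. A minimal decoding sketch (an illustration, not repo code; `unpack_camera` is a hypothetical helper):

```python
import numpy as np

def unpack_camera(cam):
  """Inverse of the concatenation in load_src_view: cam is a [34] vector."""
  h, w = int(cam[0]), int(cam[1])        # image size
  intrinsics = cam[2:18].reshape(4, 4)   # 4x4 intrinsics
  c2w = cam[18:34].reshape(4, 4)         # 4x4 camera-to-world pose
  return h, w, intrinsics, c2w
```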
146 | def __getitem__(self, idx):
147 |     # sample a random frame, skipping the first and last 3 frames (the incoming idx is ignored)
148 | idx = np.random.randint(3, self.num_frames - 3)
149 | rgb_file = self.train_rgb_files[idx]
150 |
151 | render_pose = self.train_poses[idx]
152 | intrinsics = self.train_intrinsics[idx]
153 | depth_range = self.train_depth_range[idx]
154 |
155 | rgb, camera = self.load_src_view(rgb_file, render_pose, intrinsics)
156 | img_size = rgb.shape[:2]
157 |
158 | # load mono-depth
159 | disp_path = os.path.join(
160 | self.scene_path, 'disp', rgb_file.split('/')[-1][:-4] + '.npy'
161 | )
162 | disp = np.load(disp_path) / self.scale
163 |
164 | # load motion mask
165 | mask_path = os.path.join(
166 | '/'.join(rgb_file.split('/')[:-2]), 'dynamic_masks', '%d.png' % idx
167 | )
168 | motion_mask = 1.0 - imageio.imread(mask_path).astype(np.float32) / 255.0
169 |
170 | static_mask_path = os.path.join(
171 | '/'.join(rgb_file.split('/')[:-2]), 'static_masks', '%d.png' % idx
172 | )
173 | static_mask = (
174 | 1.0 - imageio.imread(static_mask_path).astype(np.float32) / 255.0
175 | )
176 |
177 | static_mask = cv2.resize(
178 | static_mask,
179 | (disp.shape[1], disp.shape[0]),
180 | interpolation=cv2.INTER_NEAREST,
181 | )
182 |     # ensure the input dynamic and static masks have the same height before
183 |     # running morphological erosion
184 | motion_mask = cv2.resize(
185 | motion_mask,
186 | (int(round(288.0 * disp.shape[1] / disp.shape[0])), 288),
187 | interpolation=cv2.INTER_NEAREST,
188 | )
189 |
190 | if len(motion_mask.shape) == 2:
191 | motion_mask = motion_mask[..., None]
192 |
193 | motion_mask = skimage.morphology.erosion(
194 | motion_mask[..., 0] > 1e-3, skimage.morphology.disk(self.erosion_radius)
195 | )
196 |
197 | motion_mask = cv2.resize(
198 | np.float32(motion_mask),
199 | (disp.shape[1], disp.shape[0]),
200 | interpolation=cv2.INTER_NEAREST,
201 | )
202 |
203 | motion_mask = np.float32(motion_mask)
204 | static_mask = np.float32(static_mask > 1e-3)
205 |
206 | assert disp.shape[0:2] == img_size
207 | assert motion_mask.shape[0:2] == img_size
208 | assert static_mask.shape[0:2] == img_size
209 |
210 | # train_set_id = self.render_train_set_ids[idx]
211 | train_rgb_files = self.train_rgb_files
212 | train_poses = self.train_poses
213 | train_intrinsics = self.train_intrinsics
214 |
215 | # view selection based on time interval
216 | nearest_pose_ids = [idx + offset for offset in [1, 2, 3, -1, -2, -3]]
217 | max_step = min(3, self.current_epoch // (self.args.init_decay_epoch) + 1)
218 | # select a nearby time index for cross time rendering
219 | anchor_pool = [i for i in range(1, max_step + 1)] + [
220 | -i for i in range(1, max_step + 1)
221 | ]
222 | anchor_idx = idx + anchor_pool[np.random.choice(len(anchor_pool))]
223 | anchor_nearest_pose_ids = []
224 |
225 | anchor_camera = np.concatenate((
226 | list(img_size),
227 | self.train_intrinsics[anchor_idx].flatten(),
228 | self.train_poses[anchor_idx].flatten(),
229 | )).astype(np.float32)
230 |
231 | for offset in [3, 2, 1, 0, -1, -2, -3]:
232 | if (
233 | (anchor_idx + offset) < 0
234 | or (anchor_idx + offset) >= len(train_rgb_files)
235 | or (anchor_idx + offset) == idx
236 | ):
237 | continue
238 | anchor_nearest_pose_ids.append((anchor_idx + offset))
239 |
240 |     # occasionally (p=0.005) include the reference frame among the anchor-time source views
241 | if np.random.choice([0, 1], p=[1.0 - 0.005, 0.005]):
242 | anchor_nearest_pose_ids.append(idx)
243 |
244 | anchor_nearest_pose_ids = np.sort(anchor_nearest_pose_ids)
245 |
246 | flows, masks = [], []
247 |
248 | # load optical flow
249 | for ii in range(len(nearest_pose_ids)):
250 | offset = nearest_pose_ids[ii] - idx
251 | flow, mask = self.read_optical_flow(
252 | self.scene_path,
253 | idx,
254 | start_frame=0,
255 | fwd=True if offset > 0 else False,
256 | interval=np.abs(offset),
257 | )
258 |
259 | flows.append(flow)
260 | masks.append(mask)
261 |
262 | flows = np.stack(flows)
263 | masks = np.stack(masks)
264 |
265 | assert flows.shape[1:3] == img_size
266 | assert masks.shape[1:3] == img_size
267 |
268 | # load src rgb for ref view
269 | sp_pose_ids = get_nearest_pose_ids(
270 | render_pose,
271 | train_poses,
272 | tar_id=idx,
273 | angular_dist_method='dist',
274 | )
275 |
276 | static_pose_ids = []
277 |
278 | max_interval = self.max_range // self.num_frames_sample
279 | interval = np.random.randint(max(2, max_interval - 2), max_interval + 1)
280 |
281 | for ii in range(-self.num_frames_sample, self.num_frames_sample):
282 | rand_j = np.random.randint(1, interval + 1)
283 | static_pose_id = idx + interval * ii + rand_j
284 |
285 | if 0 <= static_pose_id < self.num_frames and static_pose_id != idx:
286 | static_pose_ids.append(static_pose_id)
287 |
288 | static_pose_set = set(static_pose_ids)
289 |     # if there are not enough images, add the nearest images w.r.t. camera pose.
290 |     # use a stride of 5 so that the added views are not too close to each other.
291 | for sp_pose_id in sp_pose_ids[::5]:
292 | if len(static_pose_ids) >= (self.num_frames_sample * 2):
293 | break
294 |
295 | if sp_pose_id not in static_pose_set:
296 | static_pose_ids.append(sp_pose_id)
297 |
298 | static_pose_ids = np.sort(static_pose_ids)
299 |
300 | src_rgbs = []
301 | src_cameras = []
302 |
303 | for near_id in nearest_pose_ids:
304 | src_rgb, src_camera = self.load_src_view(
305 | train_rgb_files[near_id],
306 | train_poses[near_id],
307 | train_intrinsics[near_id],
308 | )
309 | src_rgbs.append(src_rgb)
310 | src_cameras.append(src_camera)
311 |
312 | # load src virtual views
313 | for virtual_idx in np.random.choice(
314 | list(range(0, 8)), size=self.num_vv, replace=False
315 | ):
316 | src_vv_path = os.path.join(
317 | '/'.join(
318 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1]
319 | ),
320 | '%05d' % idx,
321 | '%02d.png' % virtual_idx,
322 | )
323 | src_rgb, src_camera = self.load_src_view(
324 | src_vv_path,
325 | self.src_vv_c2w_mats[idx, virtual_idx],
326 | intrinsics,
327 | )
328 | src_rgbs.append(src_rgb)
329 | src_cameras.append(src_camera)
330 |
331 | src_rgbs = np.stack(src_rgbs, axis=0)
332 | src_cameras = np.stack(src_cameras, axis=0)
333 |
334 | static_src_rgbs = []
335 | static_src_cameras = []
336 |
337 | # load src rgb for static view
338 | for st_near_id in static_pose_ids:
339 | st_mask_path = None
340 |
341 | if self.mask_src_view:
342 | st_mask_path = os.path.join(
343 | '/'.join(rgb_file.split('/')[:-2]),
344 | 'dynamic_masks',
345 | '%d.png' % st_near_id,
346 | )
347 |
348 | src_rgb, src_camera = self.load_src_view(
349 | train_rgb_files[st_near_id],
350 | train_poses[st_near_id],
351 | train_intrinsics[st_near_id],
352 | st_mask_path=st_mask_path,
353 | )
354 |
355 | static_src_rgbs.append(src_rgb)
356 | static_src_cameras.append(src_camera)
357 |
358 | static_src_rgbs = np.stack(static_src_rgbs, axis=0)
359 | static_src_cameras = np.stack(static_src_cameras, axis=0)
360 |
361 | # load src rgb for anchor view
362 | anchor_src_rgbs = []
363 | anchor_src_cameras = []
364 |
365 | for near_id in anchor_nearest_pose_ids:
366 | src_rgb, src_camera = self.load_src_view(
367 | train_rgb_files[near_id],
368 | train_poses[near_id],
369 | train_intrinsics[near_id],
370 | )
371 | anchor_src_rgbs.append(src_rgb)
372 | anchor_src_cameras.append(src_camera)
373 |
374 | # load anchor src virtual views
375 | for virtual_idx in np.random.choice(
376 | list(range(0, 8)), size=self.num_vv, replace=False
377 | ):
378 | src_vv_path = os.path.join(
379 | '/'.join(
380 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1]
381 | ),
382 | '%05d' % anchor_idx,
383 | '%02d.png' % virtual_idx,
384 | )
385 | src_rgb, src_camera = self.load_src_view(
386 | src_vv_path,
387 | self.src_vv_c2w_mats[anchor_idx, virtual_idx],
388 | intrinsics,
389 | )
390 | anchor_src_rgbs.append(src_rgb)
391 | anchor_src_cameras.append(src_camera)
392 |
393 | anchor_src_rgbs = np.stack(anchor_src_rgbs, axis=0)
394 | anchor_src_cameras = np.stack(anchor_src_cameras, axis=0)
395 |
396 | depth_range = torch.tensor(
397 | [depth_range[0] * 0.9, depth_range[1] * 1.5]
398 | ).float()
399 |
400 | return {
401 | 'id': idx,
402 | 'anchor_id': anchor_idx,
403 | 'num_frames': self.num_frames,
404 | 'ref_time': float(idx / float(self.num_frames)),
405 | 'anchor_time': float(anchor_idx / float(self.num_frames)),
406 | 'nearest_pose_ids': torch.from_numpy(np.array(nearest_pose_ids)),
407 | 'anchor_nearest_pose_ids': torch.from_numpy(
408 | np.array(anchor_nearest_pose_ids)
409 | ),
410 | 'rgb': torch.from_numpy(rgb[..., 0:3]).float(),
411 | 'disp': torch.from_numpy(disp).float(),
412 | 'motion_mask': torch.from_numpy(motion_mask).float(),
413 | 'static_mask': torch.from_numpy(static_mask).float(),
414 | 'flows': torch.from_numpy(flows).float(),
415 | 'masks': torch.from_numpy(masks).float(),
416 | 'camera': torch.from_numpy(camera).float(),
417 | 'anchor_camera': torch.from_numpy(anchor_camera).float(),
418 | 'rgb_path': rgb_file,
419 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(),
420 | 'src_cameras': torch.from_numpy(src_cameras).float(),
421 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(),
422 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(),
423 | 'anchor_src_rgbs': torch.from_numpy(anchor_src_rgbs[..., :3]).float(),
424 | 'anchor_src_cameras': torch.from_numpy(anchor_src_cameras).float(),
425 | 'depth_range': depth_range,
426 | }
427 |
--------------------------------------------------------------------------------
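As a rough usage sketch for the dataset above (not from the repository; the `Namespace` fields simply mirror the attributes read in `__init__` and `__getitem__`, and the path, scene name, and values are hypothetical). Note that `set_epoch` must be called before iterating, since `__getitem__` reads `self.current_epoch` for cross-time view selection:

```python
from argparse import Namespace
from torch.utils.data import DataLoader
from ibrnet.data_loaders.monocular import MonocularDataset

args = Namespace(folder_path='/data/my_video', num_vv=2, mask_src_view=True,
                 num_source_views=8, erosion_radius=10, max_range=48,
                 training_height=288, init_decay_epoch=50)
dataset = MonocularDataset(args, mode='train', scenes=('my_scene',))
dataset.set_epoch(0)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
```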
/ibrnet/model.py:
--------------------------------------------------------------------------------
1 | """Main Dynibar model class definition."""
2 |
3 |
4 | import os
5 | from ibrnet.feature_network import ResNet
6 | from ibrnet.mlp_network import DynibarDynamic
7 | from ibrnet.mlp_network import DynibarStatic
8 | from ibrnet.mlp_network import MotionMLP
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def de_parallel(model):
14 |   """Convert a distributed/data-parallel model to a single model."""
15 | return model.module if hasattr(model, 'module') else model
16 |
17 |
18 | def init_dct_basis(num_basis, num_frames):
19 |   """Initialize the motion basis with DCT coefficients."""
20 | T = num_frames
21 | K = num_basis
22 | dct_basis = torch.zeros([T, K])
23 |
24 | for t in range(T):
25 | for k in range(1, K + 1):
26 | dct_basis[t, k - 1] = np.sqrt(2.0 / T) * np.cos(
27 | np.pi / (2.0 * T) * (2 * t + 1) * k
28 | )
29 |
30 | return dct_basis
31 |
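The nested loop above fills the T x K basis with standard DCT-II functions, one temporal frequency per column:

$$
B_{t,\,k-1} \;=\; \sqrt{\tfrac{2}{T}}\,\cos\!\Big(\tfrac{\pi}{2T}\,(2t+1)\,k\Big),
\qquad t = 0,\dots,T-1,\;\; k = 1,\dots,K.
$$

The basis is registered below as a trainable parameter (`trajectory_basis`), so it is only initialized with DCT coefficients and can be refined during training.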
32 |
33 | class DynibarFF(object):
34 | """Dynibar model for forward-facing benchmark."""
35 |
36 | def __init__(self, args, load_opt=True, load_scheduler=True):
37 | self.args = args
38 | self.device = torch.device('cuda:{}'.format(args.local_rank))
39 | # create coarse DynIBaR models
40 | self.net_coarse_st = DynibarStatic(
41 | args,
42 | in_feat_ch=self.args.coarse_feat_dim,
43 | n_samples=self.args.N_samples,
44 | ).to(self.device)
45 | self.net_coarse_dy = DynibarDynamic(
46 | args,
47 | in_feat_ch=self.args.coarse_feat_dim,
48 | n_samples=self.args.N_samples,
49 | ).to(self.device)
50 |
51 | # create fine DynIBaR models
52 | self.net_fine_st = DynibarStatic(
53 | args,
54 | in_feat_ch=self.args.fine_feat_dim,
55 | n_samples=self.args.N_samples + self.args.N_importance,
56 | ).to(self.device)
57 | self.net_fine_dy = DynibarDynamic(
58 | args,
59 | in_feat_ch=self.args.fine_feat_dim,
60 | n_samples=self.args.N_samples + self.args.N_importance,
61 | ).to(self.device)
62 |
63 | # create coarse feature extraction network
64 | self.feature_net = ResNet(
65 | coarse_out_ch=self.args.coarse_feat_dim,
66 | fine_out_ch=self.args.fine_feat_dim,
67 | coarse_only=False,
68 | ).to(self.device)
69 |
70 | # create fine feature extraction network
71 | self.feature_net_fine = ResNet(
72 | coarse_out_ch=self.args.coarse_feat_dim,
73 | fine_out_ch=self.args.fine_feat_dim,
74 | coarse_only=False,
75 | ).to(self.device)
76 |
77 | # Motion trajectory models with MLPs
78 | self.motion_mlp = (
79 | MotionMLP(num_basis=args.num_basis).float().to(self.device)
80 | )
81 | self.motion_mlp_fine = (
82 | MotionMLP(num_basis=args.num_basis).float().to(self.device)
83 | )
84 |
85 | # Motion basis
86 | dct_basis = init_dct_basis(args.num_basis, args.num_frames)
87 | self.trajectory_basis = (
88 | torch.nn.parameter.Parameter(dct_basis)
89 | .float()
90 | .to(self.device)
91 | .detach()
92 | .requires_grad_(True)
93 | )
94 | self.trajectory_basis_fine = (
95 | torch.nn.parameter.Parameter(dct_basis)
96 | .float()
97 | .to(self.device)
98 | .detach()
99 | .requires_grad_(True)
100 | )
101 |
102 | self.load_coarse_from_ckpt(args.coarse_dir)
103 |
104 | out_folder = os.path.join(args.rootdir, 'checkpoints/fine', args.expname)
105 |
106 | self.optimizer = torch.optim.Adam([
107 | {
108 | 'params': self.net_fine_st.parameters(),
109 | 'lr': args.lrate_mlp * args.lr_multipler,
110 | },
111 | {'params': self.net_fine_dy.parameters(), 'lr': args.lrate_mlp},
112 | {
113 | 'params': self.feature_net_fine.parameters(),
114 | 'lr': args.lrate_feature,
115 | },
116 | {'params': self.motion_mlp_fine.parameters(), 'lr': args.lrate_mlp},
117 | {'params': self.trajectory_basis_fine, 'lr': args.lrate_mlp * 0.25},
118 | ])
119 |
120 | self.scheduler = torch.optim.lr_scheduler.StepLR(
121 | self.optimizer,
122 | step_size=args.lrate_decay_steps,
123 | gamma=args.lrate_decay_factor,
124 | )
125 |
126 | self.start_step = self.load_fine_from_ckpt(
127 | out_folder, load_opt=True, load_scheduler=True
128 | )
129 |
130 | device_ids = list(range(torch.cuda.device_count()))
131 |
132 |     # wrap the coarse networks with DataParallel
133 |     # for multi-GPU execution
134 | self.net_coarse_st = torch.nn.DataParallel(
135 | self.net_coarse_st, device_ids=device_ids
136 | )
137 | self.net_coarse_dy = torch.nn.DataParallel(
138 | self.net_coarse_dy, device_ids=device_ids
139 | )
140 | self.feature_net = torch.nn.DataParallel(
141 | self.feature_net, device_ids=device_ids
142 | )
143 | self.motion_mlp = torch.nn.DataParallel(
144 | self.motion_mlp, device_ids=device_ids
145 | )
146 |     # wrap the fine networks with DataParallel
147 |     # for multi-GPU execution
148 | self.net_fine_st = torch.nn.DataParallel(
149 | self.net_fine_st, device_ids=device_ids
150 | )
151 | self.net_fine_dy = torch.nn.DataParallel(
152 | self.net_fine_dy, device_ids=device_ids
153 | )
154 | self.feature_net_fine = torch.nn.DataParallel(
155 | self.feature_net_fine, device_ids=device_ids
156 | )
157 | self.motion_mlp_fine = torch.nn.DataParallel(
158 | self.motion_mlp_fine, device_ids=device_ids
159 | )
160 |
161 | def switch_to_eval(self):
162 |     """Switch to evaluation mode."""
163 | self.net_fine_st.eval()
164 | self.net_fine_dy.eval()
165 |
166 | self.feature_net_fine.eval()
167 | self.motion_mlp_fine.eval()
168 |
169 | def switch_to_train(self):
170 |     """Switch to training mode."""
171 | self.net_fine_st.train()
172 | self.net_fine_dy.train()
173 |
174 | self.feature_net_fine.train()
175 | self.motion_mlp_fine.train()
176 |
177 | def save_model(self, filename, global_step):
178 | """De-parallel and save current model to local disk."""
179 | to_save = {
180 | 'optimizer': self.optimizer.state_dict(),
181 | 'scheduler': self.scheduler.state_dict(),
182 | 'net_fine_st': de_parallel(self.net_fine_st).state_dict(),
183 | 'net_fine_dy': de_parallel(self.net_fine_dy).state_dict(),
184 | 'feature_net_fine': de_parallel(self.feature_net_fine).state_dict(),
185 | 'motion_mlp_fine': de_parallel(self.motion_mlp_fine).state_dict(),
186 | 'traj_basis_fine': self.trajectory_basis_fine,
187 | 'global_step': int(global_step),
188 | }
189 |
190 | torch.save(to_save, filename)
191 |
192 | def load_coarse_model(self, filename):
193 | """Load coarse stage dynibar model."""
194 | if self.args.distributed:
195 | to_load = torch.load(
196 | filename, map_location='cuda:{}'.format(self.args.local_rank)
197 | )
198 | else:
199 | to_load = torch.load(filename)
200 |
201 | self.net_coarse_st.load_state_dict(to_load['net_coarse_st'])
202 | self.net_coarse_dy.load_state_dict(to_load['net_coarse_dy'])
203 |
204 | self.feature_net.load_state_dict(to_load['feature_net'])
205 |
206 | self.motion_mlp.load_state_dict(to_load['motion_mlp'])
207 | self.trajectory_basis = to_load['traj_basis']
208 |
209 | return to_load['global_step']
210 |
211 | def load_fine_model(self, filename, load_opt=True, load_scheduler=True):
212 | """Load fine stage dynibar model."""
213 | if self.args.distributed:
214 | to_load = torch.load(
215 | filename, map_location='cuda:{}'.format(self.args.local_rank)
216 | )
217 | else:
218 | to_load = torch.load(filename)
219 |
220 | if load_opt:
221 | self.optimizer.load_state_dict(to_load['optimizer'])
222 | if load_scheduler:
223 | self.scheduler.load_state_dict(to_load['scheduler'])
224 |
225 | self.net_fine_st.load_state_dict(to_load['net_fine_st'])
226 | self.net_fine_dy.load_state_dict(to_load['net_fine_dy'])
227 |
228 | self.feature_net_fine.load_state_dict(to_load['feature_net_fine'])
229 |
230 | self.motion_mlp_fine.load_state_dict(to_load['motion_mlp_fine'])
231 | self.trajectory_basis_fine = to_load['traj_basis_fine']
232 |
233 | return to_load['global_step']
234 |
235 | def load_coarse_from_ckpt(
236 | self,
237 | out_folder
238 | ):
239 | """Load coarse model from existing checkpoints and return the current step."""
240 |
241 | # all existing ckpts
242 | ckpts = []
243 | if os.path.exists(out_folder):
244 | ckpts = [
245 | os.path.join(out_folder, f)
246 | for f in sorted(os.listdir(out_folder))
247 | if f.endswith('.pth')
248 | ]
249 |
250 |     fpath = ckpts[-1]  # assumes a pretrained coarse checkpoint exists in coarse_dir
251 | num_steps = self.load_coarse_model(fpath)
252 |
253 | step = num_steps
254 | print('Reloading from {}, starting at step={}'.format(fpath, step))
255 |
256 | return step
257 |
258 | def load_fine_from_ckpt(
259 | self,
260 | out_folder,
261 | load_opt=True,
262 | load_scheduler=True
263 | ):
264 | """Load fine model from existing checkpoints and return the current step."""
265 |
266 | # all existing ckpts
267 | ckpts = []
268 | if os.path.exists(out_folder):
269 | ckpts = [
270 | os.path.join(out_folder, f)
271 | for f in sorted(os.listdir(out_folder))
272 | if f.endswith('.pth')
273 | ]
274 |
275 | if self.args.ckpt_path is not None:
276 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt
277 | ckpts = [self.args.ckpt_path]
278 |
279 | if len(ckpts) > 0 and not self.args.no_reload:
280 | fpath = ckpts[-1]
281 | num_steps = self.load_fine_model(fpath, load_opt, load_scheduler)
282 | step = num_steps
283 | print('Reloading from {}, starting at step={}'.format(fpath, step))
284 | else:
285 | print('No ckpts found, training from scratch...')
286 | step = 0
287 |
288 | return step
289 |
290 |
291 | class DynibarMono(object):
292 | """Main Dynibar model for monocular video."""
293 |
294 | def __init__(self, args):
295 | self.args = args
296 | self.device = torch.device('cuda:{}'.format(args.local_rank))
297 | # create Dynibar models for monocular videos
298 | self.net_coarse_st = DynibarStatic(
299 | args,
300 | in_feat_ch=self.args.coarse_feat_dim,
301 | n_samples=self.args.N_samples,
302 | ).to(self.device)
303 | self.net_coarse_dy = DynibarDynamic(
304 | args,
305 | in_feat_ch=self.args.coarse_feat_dim,
306 | n_samples=self.args.N_samples,
307 | shift=5.0,
308 | ).to(self.device)
309 |
310 | self.net_fine = None
311 |
312 | # create feature extraction network used for dynamic model.
313 | self.feature_net = ResNet(
314 | coarse_out_ch=self.args.coarse_feat_dim,
315 | fine_out_ch=self.args.fine_feat_dim,
316 | coarse_only=False,
317 | ).to(self.device)
318 |
319 | # create feature extraction network used for static model.
320 | self.feature_net_st = ResNet(
321 | coarse_out_ch=self.args.coarse_feat_dim,
322 | fine_out_ch=self.args.fine_feat_dim,
323 | coarse_only=False,
324 | ).to(self.device)
325 |
326 | # Motion trajectory model with MLP.
327 | self.motion_mlp = (
328 | MotionMLP(num_basis=args.num_basis).float().to(self.device)
329 | )
330 |
331 | # basis
332 | dct_basis = init_dct_basis(args.num_basis, args.num_frames)
333 | self.trajectory_basis = (
334 | torch.nn.parameter.Parameter(dct_basis)
335 | .float()
336 | .to(self.device)
337 | .detach()
338 | .requires_grad_(True)
339 | )
340 |
341 | self.optimizer = torch.optim.Adam([
342 | {'params': self.net_coarse_st.parameters(), 'lr': args.lrate_mlp * 0.5},
343 | {
344 | 'params': self.feature_net_st.parameters(),
345 | 'lr': args.lrate_feature * 0.5,
346 | },
347 | {'params': self.net_coarse_dy.parameters(), 'lr': args.lrate_mlp},
348 | {'params': self.feature_net.parameters(), 'lr': args.lrate_feature},
349 | {'params': self.motion_mlp.parameters(), 'lr': args.lrate_mlp},
350 | {'params': self.trajectory_basis, 'lr': args.lrate_mlp * 0.25},
351 | ])
352 |
353 | print(
354 | 'lrate_decay_steps ',
355 | args.lrate_decay_steps,
356 | ' lrate_decay_factor ',
357 | args.lrate_decay_factor,
358 | )
359 |
360 | self.scheduler = torch.optim.lr_scheduler.StepLR(
361 | self.optimizer,
362 | step_size=args.lrate_decay_steps,
363 | gamma=args.lrate_decay_factor,
364 | )
365 |
366 | out_folder = os.path.join(args.rootdir, 'out', args.expname)
367 |
368 | self.start_step = 0
369 |
370 | if args.pretrain_path == '':
371 | self.start_step = self.load_from_ckpt(
372 | out_folder, load_opt=True, load_scheduler=True
373 | )
374 |
375 | else:
376 | self.start_step = self.load_from_ckpt(
377 | args.pretrain_path, load_opt=True, load_scheduler=True
378 | )
379 |
380 | device_ids = list(range(torch.cuda.device_count()))
381 |
382 | self.net_coarse_st = torch.nn.DataParallel(
383 | self.net_coarse_st, device_ids=device_ids
384 | )
385 | self.net_coarse_dy = torch.nn.DataParallel(
386 | self.net_coarse_dy, device_ids=device_ids
387 | )
388 | self.feature_net = torch.nn.DataParallel(
389 | self.feature_net, device_ids=device_ids
390 | )
391 | self.feature_net_st = torch.nn.DataParallel(
392 | self.feature_net_st, device_ids=device_ids
393 | )
394 |
395 | self.motion_mlp = torch.nn.DataParallel(
396 | self.motion_mlp, device_ids=device_ids
397 | )
398 |
399 | def switch_to_eval(self):
400 | """Switch models to evaluation mode."""
401 | self.net_coarse_st.eval()
402 | self.net_coarse_dy.eval()
403 |
404 | self.feature_net.eval()
405 | self.feature_net_st.eval()
406 | self.motion_mlp.eval()
407 |
408 | if self.net_fine is not None:
409 | self.net_fine.eval()
410 |
411 | def switch_to_train(self):
412 | """Switch models to training mode."""
413 |
414 | self.net_coarse_st.train()
415 | self.net_coarse_dy.train()
416 |
417 | self.feature_net.train()
418 | self.motion_mlp.train()
419 | self.feature_net_st.train()
420 |
421 | if self.net_fine is not None:
422 | self.net_fine.train()
423 |
424 | def save_model(self, filename, global_step):
425 | """Save Dynibar monocular model."""
426 | to_save = {
427 | 'optimizer': self.optimizer.state_dict(),
428 | 'scheduler': self.scheduler.state_dict(),
429 | 'net_coarse_st': de_parallel(self.net_coarse_st).state_dict(),
430 | 'net_coarse_dy': de_parallel(self.net_coarse_dy).state_dict(),
431 | 'feature_net': de_parallel(self.feature_net).state_dict(),
432 | 'feature_net_st': de_parallel(self.feature_net_st).state_dict(),
433 | 'motion_mlp': de_parallel(self.motion_mlp).state_dict(),
434 | 'traj_basis': self.trajectory_basis,
435 | 'global_step': int(global_step),
436 | }
437 |
438 | if self.net_fine is not None:
439 | to_save['net_fine'] = de_parallel(self.net_fine).state_dict()
440 |
441 | torch.save(to_save, filename)
442 |
443 | def load_model(self, filename, load_opt=True, load_scheduler=True):
444 | """Load Dynibar monocular model."""
445 | if self.args.distributed:
446 | to_load = torch.load(
447 | filename, map_location='cuda:{}'.format(self.args.local_rank)
448 | )
449 | else:
450 | to_load = torch.load(filename)
451 |
452 | if load_opt:
453 | self.optimizer.load_state_dict(to_load['optimizer'])
454 | if load_scheduler:
455 | self.scheduler.load_state_dict(to_load['scheduler'])
456 |
457 | self.net_coarse_st.load_state_dict(to_load['net_coarse_st'])
458 | self.net_coarse_dy.load_state_dict(to_load['net_coarse_dy'])
459 |
460 | self.feature_net.load_state_dict(to_load['feature_net'])
461 | self.feature_net_st.load_state_dict(to_load['feature_net_st'])
462 |
463 | self.motion_mlp.load_state_dict(to_load['motion_mlp'])
464 | self.trajectory_basis = to_load['traj_basis']
465 |
466 | return to_load['global_step']
467 |
468 | def load_from_ckpt(
469 | self,
470 | out_folder,
471 | load_opt=True,
472 | load_scheduler=True,
473 | ):
474 | """Load coarse model from existing checkpoints and return the current step."""
475 |
476 | # all existing ckpts
477 | ckpts = []
478 | if os.path.exists(out_folder):
479 | ckpts = [
480 | os.path.join(out_folder, f)
481 | for f in sorted(os.listdir(out_folder))
482 | if f.endswith('latest.pth')
483 | ]
484 |
485 | if self.args.ckpt_path is not None:
486 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt
487 | ckpts = [self.args.ckpt_path]
488 |
489 | if len(ckpts) > 0 and not self.args.no_reload:
490 | fpath = ckpts[-1]
491 | num_steps = self.load_model(fpath, True, True)
492 | print('=========== num_steps ', num_steps)
493 |
494 | step = num_steps
495 | print('Reloading from {}, starting at step={}'.format(fpath, step))
496 | else:
497 | print('No ckpts found, training from scratch...')
498 | step = 0
499 |
500 | return step
501 |
502 |
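# Illustrative sketch of the trajectory-basis idea used above. The repo's
# init_dct_basis is not shown in this file; the shape convention and
# normalization below are assumptions, not the checked-in implementation.
import math
import torch

def dct_basis_sketch(num_basis, num_frames):
  """Low-frequency DCT-II vectors sampled at every frame index."""
  t = torch.arange(num_frames, dtype=torch.float32)
  rows = [torch.cos(math.pi * (t + 0.5) * k / num_frames) for k in range(num_basis)]
  return torch.stack(rows, dim=-1)  # assumed shape: [num_frames, num_basis]

# Each scene point gets num_basis 3D coefficients from MotionMLP; its displacement
# at frame i is then a linear combination of the basis values at i, roughly:
#   disp_i = sum_k coeff[..., 3 * k:3 * k + 3] * basis[i, k]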
--------------------------------------------------------------------------------
/eval_nvidia.py:
--------------------------------------------------------------------------------
1 | """Evaluation script for the Nvidia Benchmark."""
2 |
3 | import collections
4 | import math
5 | import os
6 | import time
7 | from config import config_parser
8 | import cv2
9 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses
10 | from ibrnet.data_loaders.llff_data_utils import load_llff_data
11 | from ibrnet.model import DynibarFF
12 | from ibrnet.projection import Projector
13 | from ibrnet.render_image import render_single_image_nvi
14 | from ibrnet.sample_ray import RaySamplerSingleImage
15 | import imageio
16 | import models
17 | import numpy as np
18 | import skimage.metrics
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from torch.utils.data import Dataset
22 |
23 |
24 | class DynamicVideoDataset(Dataset):
25 |   """Loads Nvidia benchmark data, including scene cameras and source-view images."""
26 |
27 | def __init__(self, render_idx, args, scenes, **kwargs):
28 | self.folder_path = args.folder_path
29 | self.render_idx = render_idx
30 | self.mask_static = args.mask_static
31 |
32 | print('loading {} for rendering'.format(scenes))
33 | assert len(scenes) == 1
34 |
35 | scene = scenes[0]
36 | self.scene_path = os.path.join(
37 | self.folder_path, scene, 'dense'
38 | )
39 | _, poses, bds, _, i_test, rgb_files, _ = load_llff_data(
40 | self.scene_path,
41 | height=288,
42 | num_avg_imgs=12,
43 | render_idx=self.render_idx,
44 | load_imgs=False,
45 | )
46 | near_depth = np.min(bds)
47 |     # Add 15 to the far bound to ensure distant scene content is covered
48 | far_depth = np.max(bds) + 15.0
49 | self.num_frames = len(rgb_files)
50 |
51 | intrinsics, c2w_mats = batch_parse_llff_poses(poses)
52 | h, w = poses[0][:2, -1]
53 | render_intrinsics, render_c2w_mats = (
54 | intrinsics,
55 | c2w_mats,
56 | )
57 |
58 | self.train_intrinsics = intrinsics
59 | self.train_poses = c2w_mats
60 | self.train_rgb_files = rgb_files
61 | self.render_intrinsics = render_intrinsics
62 |
63 | self.render_poses = render_c2w_mats
64 | self.render_depth_range = [[near_depth, far_depth]] * self.num_frames
65 | self.h = [int(h)] * self.num_frames
66 | self.w = [int(w)] * self.num_frames
67 |
68 | def __len__(self):
69 | return 12 # number of viewpoints
70 |
71 | def __getitem__(self, idx):
72 | render_pose = self.render_poses[idx]
73 | intrinsics = self.render_intrinsics[idx]
74 | depth_range = self.render_depth_range[idx]
75 |
76 | train_rgb_files = self.train_rgb_files
77 | train_poses = self.train_poses
78 | train_intrinsics = self.train_intrinsics
79 |
80 | h, w = self.h[idx], self.w[idx]
81 | camera = np.concatenate(
82 | ([h, w], intrinsics.flatten(), render_pose.flatten())
83 | ).astype(np.float32)
84 |
85 | gt_img_path = os.path.join(
86 | self.scene_path,
87 | 'mv_images',
88 | '%05d' % self.render_idx,
89 | 'cam%02d.jpg' % (idx + 1),
90 | )
91 |
92 | nearest_pose_ids = np.sort(
93 | [self.render_idx + offset for offset in [1, 2, 3, 0, -1, -2, -3]]
94 | )
95 |     # 12 is the number of viewpoints we sample from the input cameras
96 | num_imgs_per_cycle = 12
97 |
98 |     # Get the camera viewpoint closest to the target view by index, since the
99 |     # benchmark uses fixed viewpoints in a round-robin manner.
100 | static_pose_ids = np.array(list(range(0, train_poses.shape[0])))
101 | static_id_dict = collections.defaultdict(list)
102 | for static_pose_id in static_pose_ids:
103 |       # do not include images from the same viewpoint as the target
104 | if (
105 | static_pose_id % num_imgs_per_cycle
106 | == self.render_idx % num_imgs_per_cycle
107 | ):
108 | continue
109 |
110 | static_id_dict[static_pose_id % num_imgs_per_cycle].append(static_pose_id)
111 |
112 | static_pose_ids = []
113 | for key in static_id_dict:
114 | min_idx = np.argmin(
115 | np.abs(np.array(static_id_dict[key]) - self.render_idx)
116 | )
117 | static_pose_ids.append(static_id_dict[key][min_idx])
118 |
119 | static_pose_ids = np.sort(static_pose_ids)
120 |
121 | src_rgbs = []
122 | src_cameras = []
123 | for src_idx in nearest_pose_ids:
124 | src_rgb = (
125 | imageio.v2.imread(train_rgb_files[src_idx]).astype(np.float32) / 255.0
126 | )
127 | train_pose = train_poses[src_idx]
128 | train_intrinsics_ = train_intrinsics[src_idx]
129 | src_rgbs.append(src_rgb)
130 | img_size = src_rgb.shape[:2]
131 | src_camera = np.concatenate(
132 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten())
133 | ).astype(np.float32)
134 |
135 | src_cameras.append(src_camera)
136 |
137 | src_rgbs = np.stack(src_rgbs, axis=0)
138 | src_cameras = np.stack(src_cameras, axis=0)
139 |
140 | static_src_rgbs = []
141 | static_src_cameras = []
142 | static_src_masks = []
143 |
144 | # load src rgb for static view
145 | for st_near_id in static_pose_ids:
146 | src_rgb = (
147 | imageio.v2.imread(train_rgb_files[st_near_id]).astype(np.float32)
148 | / 255.0
149 | )
150 | train_pose = train_poses[st_near_id]
151 | train_intrinsics_ = train_intrinsics[st_near_id]
152 |
153 | static_src_rgbs.append(src_rgb)
154 |
155 | # load coarse mask
156 | if self.mask_static and 3 <= st_near_id < self.num_frames - 3:
157 | st_mask_path = os.path.join(
158 | '/'.join(train_rgb_files[st_near_id].split('/')[:-2]),
159 | 'coarse_masks',
160 | '%05d.png' % st_near_id,
161 | )
162 | st_mask = imageio.v2.imread(st_mask_path).astype(np.float32) / 255.0
163 | st_mask = cv2.resize(
164 | st_mask,
165 | (src_rgb.shape[1], src_rgb.shape[0]),
166 | interpolation=cv2.INTER_NEAREST,
167 | )
168 | else:
169 | st_mask = np.ones_like(src_rgb[..., 0])
170 |
171 | static_src_masks.append(st_mask)
172 |
173 | img_size = src_rgb.shape[:2]
174 | src_camera = np.concatenate(
175 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten())
176 | ).astype(np.float32)
177 |
178 | static_src_cameras.append(src_camera)
179 |
180 | static_src_rgbs = np.stack(static_src_rgbs, axis=0)
181 | static_src_cameras = np.stack(static_src_cameras, axis=0)
182 | static_src_masks = np.stack(static_src_masks, axis=0)
183 |
184 | depth_range = torch.tensor([depth_range[0] * 0.9, depth_range[1] * 1.5])
185 |
186 | return {
187 | 'camera': torch.from_numpy(camera),
188 | 'rgb_path': gt_img_path,
189 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(),
190 | 'src_cameras': torch.from_numpy(src_cameras).float(),
191 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(),
192 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(),
193 | 'static_src_masks': torch.from_numpy(static_src_masks).float(),
194 | 'depth_range': depth_range,
195 | 'ref_time': float(self.render_idx / float(self.num_frames)),
196 | 'id': self.render_idx,
197 | 'nearest_pose_ids': nearest_pose_ids,
198 | }
199 |
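# Worked example of the static source-view selection above, assuming a capture of
# 48 frames taken round-robin over 12 cameras and render_idx = 25 (the numbers
# here are made up for illustration):
import numpy as np

render_idx, num_imgs_per_cycle, num_frames = 25, 12, 48
chosen = []
for viewpoint in range(num_imgs_per_cycle):
  if viewpoint == render_idx % num_imgs_per_cycle:
    continue  # never borrow a source image from the target viewpoint itself
  candidates = np.arange(viewpoint, num_frames, num_imgs_per_cycle)
  chosen.append(int(candidates[np.argmin(np.abs(candidates - render_idx))]))
print(sorted(chosen))  # [19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30]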
200 |
201 | def calculate_psnr(img1, img2, mask):
202 | """Compute PSNR between two images.
203 |
204 | Args:
205 | img1: image 1
206 | img2: image 2
207 | mask: mask indicating which region is valid.
208 |
209 | Returns:
210 |     PSNR: the PSNR value computed over the valid region
211 | """
212 |
213 | # img1 and img2 have range [0, 1]
214 | img1 = img1.astype(np.float64)
215 | img2 = img2.astype(np.float64)
216 | mask = mask.astype(np.float64)
217 |
218 | num_valid = np.sum(mask) + 1e-8
219 |
220 | mse = np.sum((img1 - img2) ** 2 * mask) / num_valid
221 |
222 | if mse == 0:
223 | return 0 # float('inf')
224 |
225 | return 10 * math.log10(1.0 / mse)
226 |
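# Quick sanity check of the masked PSNR above: for images in [0, 1], a uniform
# error of 0.1 over the valid region gives MSE = 0.01 and
# PSNR = 10 * log10(1 / 0.01) = 20 dB. Values and shapes here are illustrative.
import numpy as np

gt = np.full((4, 4, 3), 0.5)
pred = np.full((4, 4, 3), 0.6)
mask = np.ones((4, 4, 3))
assert abs(calculate_psnr(gt, pred, mask) - 20.0) < 1e-6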
227 |
228 | def calculate_ssim(img1, img2, mask):
229 | """Compute SSIM between two images.
230 |
231 | Args:
232 | img1: image 1
233 | img2: image 2
234 | mask: mask indicating which region is valid.
235 |
236 | Returns:
237 |     SSIM: the SSIM value computed over the valid region
238 | """
239 | if img1.shape != img2.shape:
240 | raise ValueError('Input images must have the same dimensions.')
241 |
242 | _, ssim_map = skimage.metrics.structural_similarity(
243 | img1, img2, multichannel=True, full=True
244 | )
245 | num_valid = np.sum(mask) + 1e-8
246 |
247 | return np.sum(ssim_map * mask) / num_valid
248 |
249 |
250 | def im2tensor(image, cent=1.0, factor=1.0 / 2.0):
251 | """Convert image to Pytorch tensor.
252 |
253 | Args:
254 | image: input image
255 | cent: shift
256 | factor: scale
257 |
258 | Returns:
259 | Pytorch tensor
260 | """
261 | return torch.Tensor(
262 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
263 | )
264 |
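# Shape/scale check for im2tensor: with the defaults cent=1.0 and factor=0.5, an
# (H, W, 3) image in [0, 1] is rescaled to [-1, 1] and laid out as (1, 3, H, W),
# which is the input format the LPIPS perceptual model used below expects. The
# array here is illustrative only.
import numpy as np

img = np.random.rand(4, 5, 3).astype(np.float32)
t = im2tensor(img)
assert tuple(t.shape) == (1, 3, 4, 5)
assert float(t.min()) >= -1.0 and float(t.max()) <= 1.0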
265 |
266 | if __name__ == '__main__':
267 | parser = config_parser()
268 | args = parser.parse_args()
269 | args.distributed = False
270 | # Construct a dataset to get number of frames for evaluation
271 | test_dataset = DynamicVideoDataset(0, args, scenes=args.eval_scenes)
272 | args.num_frames = test_dataset.num_frames
273 | print('args.num_frames ', args.num_frames)
274 | # Create ibrnet model
275 | model = DynibarFF(args, load_scheduler=False, load_opt=False)
276 | eval_dataset_name = args.eval_dataset
277 | # extra_out_dir = '{}/{}'.format(eval_dataset_name, args.expname)
278 | # print('saving results to {}...'.format(extra_out_dir))
279 | # os.makedirs(extra_out_dir, exist_ok=True)
280 |
281 | projector = Projector(device='cuda:0')
282 |
283 | assert len(args.eval_scenes) == 1, 'only accept single scene'
284 | scene_name = args.eval_scenes[0]
285 | # out_scene_dir = os.path.join(extra_out_dir, 'renderings')
286 | # print('saving results to {}'.format(out_scene_dir))
287 | # os.makedirs(out_scene_dir, exist_ok=True)
288 |
289 | lpips_model = models.PerceptualLoss(
290 | model='net-lin', net='alex', use_gpu=True, version=0.1
291 | )
292 |
293 | psnr_list = []
294 | ssim_list = []
295 | lpips_list = []
296 |
297 | dy_psnr_list = []
298 | dy_ssim_list = []
299 | dy_lpips_list = []
300 |
301 | st_psnr_list = []
302 | st_ssim_list = []
303 | st_lpips_list = []
304 |
305 | for img_i in range(3, args.num_frames - 3):
306 | test_dataset = DynamicVideoDataset(img_i, args, scenes=args.eval_scenes)
307 | save_prefix = scene_name
308 | test_loader = DataLoader(
309 | test_dataset, batch_size=1, num_workers=12, shuffle=False
310 | )
311 | total_num = len(test_loader)
312 | out_frames = []
313 |
314 | for i, data in enumerate(test_loader):
315 | print('img_i ', img_i, i)
316 |
317 | if img_i % 12 == i:
318 | continue
319 |
320 | # idx = int(data['id'].item())
321 | start = time.time()
322 |
323 | ref_time_embedding = data['ref_time'].cuda()
324 | ref_frame_idx = int(data['id'].item())
325 | ref_time_offset = [
326 | int(near_idx - ref_frame_idx)
327 | for near_idx in data['nearest_pose_ids'].squeeze().tolist()
328 | ]
329 |
330 | model.switch_to_eval()
331 | with torch.no_grad():
332 | ray_sampler = RaySamplerSingleImage(data, device='cuda:0')
333 | ray_batch = ray_sampler.get_all()
334 |
335 | cb_featmaps_1, cb_featmaps_2 = model.feature_net(
336 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2)
337 | )
338 | ref_featmaps = cb_featmaps_1
339 |
340 | static_src_rgbs = (
341 | ray_batch['static_src_rgbs'].squeeze(0).permute(0, 3, 1, 2)
342 | )
343 | _, static_featmaps = model.feature_net(static_src_rgbs)
344 |
345 | cb_featmaps_1_fine, _ = model.feature_net_fine(
346 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2)
347 | )
348 | ref_featmaps_fine = cb_featmaps_1_fine
349 |
350 | if args.mask_static:
351 | static_src_rgbs_ = (
352 | static_src_rgbs
353 | * ray_batch['static_src_masks'].squeeze(0)[:, None, ...]
354 | )
355 | else:
356 | static_src_rgbs_ = static_src_rgbs
357 |
358 | _, static_featmaps_fine = model.feature_net_fine(static_src_rgbs_)
359 |
360 | ret = render_single_image_nvi(
361 | frame_idx=(ref_frame_idx, None),
362 | time_embedding=(ref_time_embedding, None),
363 | time_offset=(ref_time_offset, None),
364 | ray_sampler=ray_sampler,
365 | ray_batch=ray_batch,
366 | model=model,
367 | projector=projector,
368 | chunk_size=args.chunk_size,
369 | det=True,
370 | N_samples=args.N_samples,
371 | args=args,
372 | inv_uniform=args.inv_uniform,
373 | N_importance=args.N_importance,
374 | white_bkgd=args.white_bkgd,
375 | coarse_featmaps=(ref_featmaps, None, static_featmaps),
376 | fine_featmaps=(ref_featmaps_fine, None, static_featmaps_fine),
377 | is_train=False,
378 | )
379 |
380 | fine_pred_rgb = ret['outputs_fine_ref']['rgb'].detach().cpu().numpy()
381 | fine_pred_depth = ret['outputs_fine_ref']['depth'].detach().cpu().numpy()
382 |
383 | valid_mask = np.float32(
384 | np.sum(fine_pred_rgb, axis=-1, keepdims=True) > 1e-3
385 | )
386 | valid_mask = np.tile(valid_mask, (1, 1, 3))
387 | gt_img = cv2.imread(data['rgb_path'][0])[:, :, ::-1]
388 | gt_img = cv2.resize(
389 | gt_img,
390 | dsize=(fine_pred_rgb.shape[1], fine_pred_rgb.shape[0]),
391 | interpolation=cv2.INTER_AREA,
392 | )
393 | gt_img = np.float32(gt_img) / 255
394 |
395 | gt_img = gt_img * valid_mask
396 | fine_pred_rgb = fine_pred_rgb * valid_mask
397 |
398 | dynamic_mask = valid_mask
399 | ssim = calculate_ssim(gt_img, fine_pred_rgb, dynamic_mask)
400 | psnr = calculate_psnr(gt_img, fine_pred_rgb, dynamic_mask)
401 |
402 | gt_img_0 = im2tensor(gt_img).cuda()
403 | fine_pred_rgb_0 = im2tensor(fine_pred_rgb).cuda()
404 | dynamic_mask_0 = torch.Tensor(
405 | dynamic_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
406 | )
407 |
408 | lpips = lpips_model.forward(
409 | gt_img_0, fine_pred_rgb_0, dynamic_mask_0
410 | ).item()
411 | print(psnr, ssim, lpips)
412 | psnr_list.append(psnr)
413 | ssim_list.append(ssim)
414 | lpips_list.append(lpips)
415 |
416 | dynamic_mask_path = os.path.join(
417 | test_dataset.scene_path,
418 | 'mv_masks',
419 | '%05d' % img_i,
420 | 'cam%02d.png' % (i + 1),
421 | )
422 |
423 | dynamic_mask = np.float32(cv2.imread(dynamic_mask_path) > 1e-3) # /255.
424 | dynamic_mask = cv2.resize(
425 | dynamic_mask,
426 | dsize=(gt_img.shape[1], gt_img.shape[0]),
427 | interpolation=cv2.INTER_NEAREST,
428 | )
429 |
430 | dynamic_mask_0 = torch.Tensor(
431 | dynamic_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
432 | )
433 | dynamic_ssim = calculate_ssim(gt_img, fine_pred_rgb, dynamic_mask)
434 | dynamic_psnr = calculate_psnr(gt_img, fine_pred_rgb, dynamic_mask)
435 | dynamic_lpips = lpips_model.forward(
436 | gt_img_0, fine_pred_rgb_0, dynamic_mask_0
437 | ).item()
438 | print(dynamic_psnr, dynamic_ssim, dynamic_lpips)
439 |
440 | dy_psnr_list.append(dynamic_psnr)
441 | dy_ssim_list.append(dynamic_ssim)
442 | dy_lpips_list.append(dynamic_lpips)
443 |
444 | static_mask = 1 - dynamic_mask
445 | static_mask_0 = torch.Tensor(
446 | static_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
447 | )
448 | static_ssim = calculate_ssim(gt_img, fine_pred_rgb, static_mask)
449 | static_psnr = calculate_psnr(gt_img, fine_pred_rgb, static_mask)
450 | static_lpips = lpips_model.forward(
451 | gt_img_0, fine_pred_rgb_0, static_mask_0
452 | ).item()
453 | print(static_psnr, static_ssim, static_lpips)
454 |
455 | st_psnr_list.append(static_psnr)
456 | st_ssim_list.append(static_ssim)
457 | st_lpips_list.append(static_lpips)
458 |
459 | print('MOVING PSNR ', np.mean(np.array(psnr_list)))
460 | print('MOVING SSIM ', np.mean(np.array(ssim_list)))
461 | print('MOVING LPIPS ', np.mean(np.array(lpips_list)))
462 |
463 | print('MOVING DYNAMIC PSNR ', np.mean(np.array(dy_psnr_list)))
464 | print('MOVING DYNAMIC SSIM ', np.mean(np.array(dy_ssim_list)))
465 | print('MOVING DYNAMIC LPIPS ', np.mean(np.array(dy_lpips_list)))
466 |
467 | print('MOVING Static PSNR ', np.mean(np.array(st_psnr_list)))
468 | print('MOVING Static SSIM ', np.mean(np.array(st_ssim_list)))
469 | print('MOVING Static LPIPS ', np.mean(np.array(st_lpips_list)))
470 |
471 | print('AVG PSNR ', np.mean(np.array(psnr_list)))
472 | print('AVG SSIM ', np.mean(np.array(ssim_list)))
473 | print('AVG LPIPS ', np.mean(np.array(lpips_list)))
474 |
475 | print('AVG DYNAMIC PSNR ', np.mean(np.array(dy_psnr_list)))
476 | print('AVG DYNAMIC SSIM ', np.mean(np.array(dy_ssim_list)))
477 | print('AVG DYNAMIC LPIPS ', np.mean(np.array(dy_lpips_list)))
478 |
479 | print('AVG Static PSNR ', np.mean(np.array(st_psnr_list)))
480 | print('AVG Static SSIM ', np.mean(np.array(st_ssim_list)))
481 | print('AVG Static LPIPS ', np.mean(np.array(st_lpips_list)))
482 |
--------------------------------------------------------------------------------
/ibrnet/mlp_network.py:
--------------------------------------------------------------------------------
1 | """Class definition for MLP Network."""
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | torch._C._jit_set_profiling_executor(False)
10 | torch._C._jit_set_profiling_mode(False)
11 |
12 |
13 | class ScaledDotProductAttention(nn.Module):
14 | """Dot-Product Attention Layer."""
15 |
16 | def __init__(self, temperature, attn_dropout=0.1):
17 | super().__init__()
18 | self.temperature = temperature
19 |
20 | def forward(self, q, k, v, mask=None):
21 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
22 |
23 | if mask is not None:
24 | attn = attn.masked_fill(mask == 0, -1e9)
25 | # attn = attn * mask
26 |
27 | attn = F.softmax(attn, dim=-1)
28 | # attn = self.dropout(F.softmax(attn, dim=-1))
29 | output = torch.matmul(attn, v)
30 |
31 | return output, attn
32 |
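# Minimal usage sketch for ScaledDotProductAttention: it computes
# softmax(q @ k^T / temperature) @ v, pushing masked positions to -1e9 before the
# softmax so they receive (near-)zero weight. Shapes below are illustrative.
import torch

attn_layer = ScaledDotProductAttention(temperature=32 ** 0.5)
q = torch.randn(2, 4, 8, 32)  # [batch, heads, query_len, d_k]
k = torch.randn(2, 4, 8, 32)  # [batch, heads, key_len, d_k]
v = torch.randn(2, 4, 8, 32)  # [batch, heads, key_len, d_v]
out, attn = attn_layer(q, k, v)
assert out.shape == (2, 4, 8, 32) and attn.shape == (2, 4, 8, 8)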
33 |
34 | class PositionwiseFeedForward(nn.Module):
35 | """A two-feed-forward-layer module."""
36 |
37 | def __init__(self, d_in, d_hid, dropout=0.1):
38 | super().__init__()
39 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise
40 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise
41 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
42 | # self.dropout = nn.Dropout(dropout)
43 |
44 | def forward(self, x):
45 | residual = x
46 |
47 | x = self.w_2(F.relu(self.w_1(x)))
48 | # x = self.dropout(x)
49 | x += residual
50 |
51 | x = self.layer_norm(x)
52 |
53 | return x
54 |
55 |
56 | class MultiHeadAttention(nn.Module):
57 | """Multi-Head Attention module."""
58 |
59 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
60 | super().__init__()
61 |
62 | self.n_head = n_head
63 | self.d_k = d_k
64 | self.d_v = d_v
65 |
66 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
67 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
68 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
69 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
70 |
71 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
72 |
73 | # self.dropout = nn.Dropout(dropout)
74 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
75 |
76 | def forward(self, q, k, v, mask=None):
77 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
78 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
79 |
80 | residual = q
81 |
82 | # Pass through the pre-attention projection: b x lq x (n*dv)
83 | # Separate different heads: b x lq x n x dv
84 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
85 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
86 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
87 |
88 | # Transpose for attention dot product: b x n x lq x dv
89 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
90 |
91 | if mask is not None:
92 | mask = mask.unsqueeze(1) # For head axis broadcasting.
93 |
94 | q, attn = self.attention(q, k, v, mask=mask)
95 |
96 | # Transpose to move the head dimension back: b x lq x n x dv
97 | # Combine the last two dimensions to concatenate all the heads together
98 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
99 | q = self.fc(q)
100 | q += residual
101 |
102 | q = self.layer_norm(q)
103 |
104 | return q, attn
105 |
106 |
107 | def weights_init(m):
108 | """Default initialization of linear layers."""
109 | if isinstance(m, nn.Linear):
110 | nn.init.kaiming_normal_(m.weight.data)
111 | if m.bias is not None:
112 | nn.init.zeros_(m.bias.data)
113 |
114 |
115 | @torch.jit.script
116 | def fused_mean_variance(x, weight):
117 | mean = torch.sum(x * weight, dim=2, keepdim=True)
118 | var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True)
119 | return mean, var
120 |
121 |
122 | @torch.jit.script
123 | def epipolar_fused_mean_variance(x, weight):
124 | mean = torch.sum(x * weight, dim=1, keepdim=True)
125 | var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True)
126 | return mean, var
127 |
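# Quick check of fused_mean_variance: with weights that sum to 1 over dim=2 (the
# source-view axis), it returns the weighted mean and biased weighted variance of
# the per-view features. Shapes below are illustrative.
import torch

x = torch.randn(2, 16, 7, 35)           # [n_rays, n_samples, n_views, n_feat]
w = torch.full((2, 16, 7, 1), 1.0 / 7)  # uniform weights over the 7 views
mean, var = fused_mean_variance(x, w)
assert mean.shape == (2, 16, 1, 35)
assert torch.allclose(mean, x.mean(dim=2, keepdim=True), atol=1e-5)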
128 |
129 | class DynibarDynamic(nn.Module):
130 | """Dynibar time-varying dynamic model."""
131 |
132 | def __init__(self, args, in_feat_ch=32, n_samples=64, shift=0.0, **kwargs):
133 | super(DynibarDynamic, self).__init__()
134 | self.args = args
135 |     self.anti_alias_pooling = False  # disabled for the dynamic model; ignores args.anti_alias_pooling
136 | self.input_dir = args.input_dir
137 | self.input_xyz = args.input_xyz
138 |
139 | if self.anti_alias_pooling:
140 | self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
141 |
142 | activation_func = nn.ELU(inplace=True)
143 | self.shift = shift
144 | t_num_freqs = 10
145 | self.t_embed = PeriodicEmbed(
146 | max_freq=t_num_freqs, N_freq=t_num_freqs, linspace=False
147 | ).float()
148 | dir_num_freqs = 4
149 | self.dir_embed = PeriodicEmbed(
150 | max_freq=dir_num_freqs, N_freq=dir_num_freqs, linspace=False
151 | ).float()
152 |
153 | pts_num_freqs = 5
154 | self.pts_embed = PeriodicEmbed(
155 | max_freq=pts_num_freqs, N_freq=pts_num_freqs, linspace=False
156 | ).float()
157 |
158 | self.n_samples = n_samples
159 | self.ray_dir_fc = nn.Sequential(
160 | nn.Linear(t_num_freqs * 2 + 1, 256),
161 | activation_func,
162 | nn.Linear(256, in_feat_ch + 3),
163 | activation_func,
164 | )
165 |
166 | self.base_fc = nn.Sequential(
167 | nn.Linear((in_feat_ch + 3) * 3, 256),
168 | activation_func,
169 | nn.Linear(256, 128),
170 | activation_func,
171 | )
172 |
173 | self.vis_fc = nn.Sequential(
174 | nn.Linear(128, 128),
175 | activation_func,
176 | nn.Linear(128, 128 + 1),
177 | activation_func,
178 | )
179 |
180 | self.vis_fc2 = nn.Sequential(
181 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1), nn.Sigmoid()
182 | )
183 |
184 | self.geometry_fc = nn.Sequential(
185 | nn.Linear(128 * 2 + 1, 256),
186 | activation_func,
187 | nn.Linear(256, 128),
188 | activation_func,
189 | )
190 |
191 | self.ray_attention = MultiHeadAttention(4, 128, 32, 32)
192 |
193 | num_c_xyz = (pts_num_freqs * 2 + 1) * 3
194 |
195 | self.ref_pts_fc = nn.Sequential(
196 | nn.Linear(num_c_xyz + 128, 256),
197 | activation_func,
198 | nn.Linear(256, 128),
199 | activation_func,
200 | )
201 |
202 | self.out_geometry_fc = nn.Sequential(
203 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1)
204 | )
205 |
206 | if self.input_dir:
207 | self.rgb_fc = nn.Sequential(
208 | nn.Linear(128 + (dir_num_freqs * 2 + 1) * 3, 128),
209 | activation_func,
210 | nn.Linear(128, 64),
211 | activation_func,
212 | nn.Linear(64, 3),
213 | nn.Sigmoid(),
214 | )
215 | else:
216 | raise NotImplementedError
217 |
218 | self.pos_encoding = self.posenc(d_hid=128, n_samples=self.n_samples)
219 |
220 | def posenc(self, d_hid, n_samples):
221 | def get_position_angle_vec(position):
222 | return [
223 | position / np.power(10000, 2 * (hid_j // 2) / d_hid)
224 | for hid_j in range(d_hid)
225 | ]
226 |
227 | sinusoid_table = np.array(
228 | [get_position_angle_vec(pos_i) for pos_i in range(n_samples)]
229 | )
230 |
231 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
232 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
233 | sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0)
234 | return sinusoid_table
235 |
236 | def forward(
237 | self, pts_xyz, rgb_feat, glb_ray_dir, ray_diff, time_diff, mask, time
238 | ):
239 | num_views = rgb_feat.shape[2]
240 | time_pe = (
241 | self.t_embed(time)[..., None, :].repeat(1, 1, num_views, 1).float()
242 | )
243 |
244 | direction_feat = self.ray_dir_fc(time_pe)
245 |
246 | # rgb_in = rgb_feat[..., :3]
247 | rgb_feat = rgb_feat + direction_feat
248 |
249 | if self.anti_alias_pooling:
250 | _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
251 | exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
252 | weight = (
253 | exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]
254 | ) * mask
255 | weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
256 | else:
257 | weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
258 |
259 | # compute mean and variance across different views for each point
260 | mean, var = fused_mean_variance(
261 | rgb_feat, weight
262 | ) # [n_rays, n_samples, 1, n_feat]
263 | globalfeat = torch.cat(
264 | [mean, var], dim=-1
265 | ) # [n_rays, n_samples, 1, 2*n_feat]
266 |
267 | x = torch.cat(
268 | [globalfeat.expand(-1, -1, num_views, -1), rgb_feat], dim=-1
269 | ) # [n_rays, n_samples, n_views, 3*n_feat]
270 | x = self.base_fc(x)
271 |
272 | x_vis = self.vis_fc(x * weight)
273 | x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
274 | vis = F.sigmoid(vis) * mask
275 | x = x + x_res
276 | vis = self.vis_fc2(x * vis) * mask
277 | weight = vis / (torch.sum(vis, dim=2, keepdim=True) + 1e-8)
278 |
279 | mean, var = fused_mean_variance(x, weight)
280 | globalfeat = torch.cat(
281 | [mean.squeeze(2), var.squeeze(2), weight.mean(dim=2)], dim=-1
282 | ) # [n_rays, n_samples, 32*2+1]
283 | globalfeat = self.geometry_fc(globalfeat) # [n_rays, n_samples, 16]
284 | num_valid_obs = torch.sum(mask, dim=2)
285 |
286 | globalfeat = globalfeat + self.pos_encoding.to(globalfeat.device)
287 | globalfeat, _ = self.ray_attention(
288 | globalfeat, globalfeat, globalfeat, mask=(num_valid_obs > 1).float()
289 | ) # [n_rays, n_samples, 16]
290 |
291 | pts_xyz_pe = self.pts_embed(pts_xyz)
292 | globalfeat = self.ref_pts_fc(torch.cat([globalfeat, pts_xyz_pe], dim=-1))
293 |
294 | sigma = (
295 | self.out_geometry_fc(globalfeat) - self.shift
296 | ) # [n_rays, n_samples, 1]
297 | sigma_out = sigma.masked_fill(
298 | num_valid_obs < 1, -1e9
299 |     )  # push sigma at invalid points to -1e9 (near-zero density after activation)
300 |
301 | if self.input_dir:
302 | glb_ray_dir_pe = self.dir_embed(glb_ray_dir).float()
303 | h = torch.cat(
304 | [
305 | globalfeat,
306 | glb_ray_dir_pe[:, None, :].repeat(1, globalfeat.shape[1], 1),
307 | ],
308 | dim=-1,
309 | )
310 | else:
311 | h = globalfeat
312 |
313 | rgb_out = self.rgb_fc(h)
314 | rgb_out = rgb_out.masked_fill(torch.sum(mask.repeat(1, 1, 1, 3), 2) == 0, 0)
315 | out = torch.cat([rgb_out, sigma_out], dim=-1)
316 | return out
317 |
318 |
319 | class DynibarStatic(nn.Module):
320 | """Dynibar time-invariant static model."""
321 |
322 | def __init__(self, args, in_feat_ch=32, n_samples=64, **kwargs):
323 | super(DynibarStatic, self).__init__()
324 | self.args = args
325 |     self.anti_alias_pooling = args.anti_alias_pooling  # unlike DynibarDynamic, which disables it
326 | self.mask_rgb = args.mask_rgb
327 | self.input_dir = args.input_dir
328 | self.input_xyz = args.input_xyz
329 |
330 | if self.anti_alias_pooling:
331 | self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
332 |
333 | activation_func = nn.ELU(inplace=True)
334 |
335 | ray_num_freqs = 5
336 | self.ray_embed = PeriodicEmbed(
337 | max_freq=ray_num_freqs, N_freq=ray_num_freqs, linspace=False
338 | )
339 | pts_num_freqs = 5
340 | self.pts_embed = PeriodicEmbed(
341 | max_freq=pts_num_freqs, N_freq=pts_num_freqs, linspace=False
342 | )
343 |
344 | num_c_xyz = (pts_num_freqs * 2 + 1) * 3
345 | num_c_ray = (ray_num_freqs * 2 + 1) * 6
346 |
347 | self.n_samples = n_samples
348 |
349 | self.ray_dir_fc = nn.Sequential(
350 | nn.Linear(4 + num_c_xyz + num_c_ray, 256),
351 | activation_func,
352 | nn.Linear(256, in_feat_ch + 3),
353 | )
354 |
355 | self.ref_feature_fc = nn.Sequential(nn.Linear(num_c_ray, in_feat_ch + 3))
356 |
357 | self.base_fc = nn.Sequential(
358 | nn.Linear((in_feat_ch + 3) * 6, 256),
359 | activation_func,
360 | nn.Linear(256, 128),
361 | activation_func,
362 | )
363 |
364 | self.vis_fc = nn.Sequential(
365 | nn.Linear(128, 128),
366 | activation_func,
367 | nn.Linear(128, 128 + 1),
368 | activation_func,
369 | )
370 |
371 | self.vis_fc2 = nn.Sequential(
372 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1), nn.Sigmoid()
373 | )
374 |
375 | self.geometry_fc = nn.Sequential(
376 | nn.Linear(128 * 2 + 1, 256),
377 | activation_func,
378 | nn.Linear(256, 128),
379 | activation_func,
380 | )
381 |
382 | self.ray_attention = MultiHeadAttention(4, 128, 32, 32)
383 | self.out_geometry_fc = nn.Sequential(
384 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1)
385 | )
386 |
387 | if self.input_dir:
388 | self.rgb_fc = nn.Sequential(
389 | nn.Linear(128 * 2 + 1 + 4, 128),
390 | activation_func,
391 | nn.Linear(128, 64),
392 | activation_func,
393 | nn.Linear(64, 1),
394 | )
395 |
396 | else:
397 | self.rgb_fc = nn.Sequential(
398 | nn.Linear(32 + 1, 32),
399 | activation_func,
400 | nn.Linear(32, 16),
401 | activation_func,
402 | nn.Linear(16, 1),
403 | )
404 |
405 | self.pos_encoding = self.posenc(d_hid=128, n_samples=self.n_samples)
406 |
407 | def posenc(self, d_hid, n_samples):
408 | def get_position_angle_vec(position):
409 | return [
410 | position / np.power(10000, 2 * (hid_j // 2) / d_hid)
411 | for hid_j in range(d_hid)
412 | ]
413 |
414 | sinusoid_table = np.array(
415 | [get_position_angle_vec(pos_i) for pos_i in range(n_samples)]
416 | )
417 |
418 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
419 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
420 | sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0)
421 | return sinusoid_table
422 |
423 | def forward(
424 | self,
425 | pts,
426 | ref_rays_coords,
427 | src_rays_coords,
428 | rgb_feat,
429 | glb_ray_dir,
430 | ray_diff,
431 | mask,
432 | ):
433 | num_views = rgb_feat.shape[2]
434 | ref_rays_pe = self.ray_embed(ref_rays_coords)
435 | src_rays_pe = self.ray_embed(src_rays_coords)
436 | pts_pe = self.pts_embed(pts)
437 |
438 | ref_features = ref_rays_pe[:, None, None, :].expand(
439 | -1, src_rays_pe.shape[1], src_rays_pe.shape[2], -1
440 | )
441 | src_features = torch.cat(
442 | [
443 | pts_pe.unsqueeze(2).expand(-1, -1, src_rays_pe.shape[2], -1),
444 | src_rays_pe,
445 | ],
446 | dim=-1,
447 | )
448 |
449 | src_feat = self.ray_dir_fc(torch.cat([src_features, ray_diff], dim=-1))
450 | ref_feat = self.ref_feature_fc(ref_features)
451 |
452 | rgb_in = rgb_feat[..., :3]
453 |
454 | if self.mask_rgb:
455 | rgb_in_sum = torch.sum(rgb_in, dim=-1, keepdim=True)
456 | rgb_mask = (rgb_in_sum > 1e-3).float().detach()
457 | mask = mask * rgb_mask
458 |
459 | rgb_feat = torch.cat([rgb_feat, src_feat * ref_feat], dim=-1)
460 |
461 | if self.anti_alias_pooling:
462 | _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
463 | exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
464 | weight = (
465 | exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]
466 | ) * mask
467 | weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
468 | else:
469 | weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
470 |
471 | # compute mean and variance across different views for each point
472 | mean, var = fused_mean_variance(
473 | rgb_feat, weight
474 | ) # [n_rays, n_samples, 1, n_feat]
475 | globalfeat = torch.cat(
476 | [mean, var], dim=-1
477 | ) # [n_rays, n_samples, 1, 2*n_feat]
478 |
479 | x = torch.cat(
480 | [globalfeat.expand(-1, -1, num_views, -1), rgb_feat], dim=-1
481 | ) # [n_rays, n_samples, n_views, 3*n_feat]
482 |
483 | x = self.base_fc(x)
484 |
485 | x_vis = self.vis_fc(x * weight)
486 | x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
487 | vis = F.sigmoid(vis) * mask
488 | x = x + x_res
489 | vis = self.vis_fc2(x * vis) * mask
490 | weight = vis / (torch.sum(vis, dim=2, keepdim=True) + 1e-8)
491 |
492 | mean, var = fused_mean_variance(x, weight)
493 | globalfeat = torch.cat(
494 | [mean.squeeze(2), var.squeeze(2), weight.mean(dim=2)], dim=-1
495 | ) # [n_rays, n_samples, 32*2+1]
496 | globalfeat = self.geometry_fc(globalfeat) # [n_rays, n_samples, 16]
497 | num_valid_obs = torch.sum(mask, dim=2)
498 |
499 | # globalfeat = globalfeat #+ self.pos_encoding.to(globalfeat.device)
500 | globalfeat, _ = self.ray_attention(
501 | globalfeat, globalfeat, globalfeat, mask=(num_valid_obs > 1).float()
502 | ) # [n_rays, n_samples, 16]
503 | sigma = self.out_geometry_fc(globalfeat) # [n_rays, n_samples, 1]
504 | sigma_out = sigma.masked_fill(
505 | num_valid_obs < 1, -1e9
506 |     )  # push sigma at invalid points to -1e9 (near-zero density after activation)
507 |
508 | if self.input_dir:
509 | x = torch.cat(
510 | [
511 | globalfeat[:, :, None, :].expand(-1, -1, x.shape[2], -1),
512 | x,
513 | vis,
514 | ray_diff,
515 | ],
516 | dim=-1,
517 | )
518 | else:
519 | x = torch.cat([globalfeat, vis], dim=-1)
520 |
521 | x = self.rgb_fc(x)
522 |
523 | x = x.masked_fill(mask == 0, -1e9)
524 | blending_weights_valid = F.softmax(x, dim=2) # color blending
525 | rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)
526 | out = torch.cat([rgb_out, sigma_out], dim=-1)
527 | return out
528 |
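# Unlike DynibarDynamic, which regresses RGB directly through a sigmoid head, the
# static model above predicts one logit per source view and blends the source
# colors with a masked softmax over the view axis. A minimal sketch of that
# blending step (shapes are illustrative):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 16, 7, 1)  # [n_rays, n_samples, n_views, 1]
rgb_in = torch.rand(2, 16, 7, 3)   # source-view colors
weights = F.softmax(logits, dim=2)
rgb_out = torch.sum(rgb_in * weights, dim=2)  # [n_rays, n_samples, 3]
assert rgb_out.shape == (2, 16, 3)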
529 |
530 | class PeriodicEmbed(nn.Module):
531 | """Fourier Position encoding module."""
532 |
533 | def __init__(self, max_freq, N_freq, linspace=True):
534 | """Init function for position encoding.
535 |
536 | Args:
537 | max_freq: max frequency band
538 |       N_freq: number of frequency bands
539 |       linspace: whether frequencies are linearly spaced (otherwise powers of 2)
540 | """
541 | super().__init__()
542 | self.embed_functions = [torch.cos, torch.sin]
543 | if linspace:
544 | self.freqs = torch.linspace(1, max_freq + 1, steps=N_freq)
545 | else:
546 | exps = torch.linspace(0, N_freq - 1, steps=N_freq)
547 | self.freqs = 2**exps
548 |
549 | def forward(self, x):
550 | output = [x]
551 | for f in self.embed_functions:
552 | for freq in self.freqs:
553 | output.append(f(freq * x))
554 |
555 | return torch.cat(output, -1)
556 |
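# Channel-count check for PeriodicEmbed: each input channel is kept and augmented
# with cos/sin at N_freq frequencies, so the output has in_ch * (1 + 2 * N_freq)
# channels. With input_ch=4 and num_freqs=16 this gives 4 * 33 = 132, matching
# self.input_ch computed in MotionMLP below. The tensor here is illustrative.
import torch

embed = PeriodicEmbed(max_freq=16, N_freq=16)
out = embed(torch.randn(5, 4))
assert out.shape == (5, 4 * (1 + 2 * 16))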
557 |
558 | class MotionMLP(nn.Module):
559 | """Motion trajectory MLP module."""
560 |
561 | def __init__(
562 | self,
563 | num_basis=4,
564 | D=8,
565 | W=256,
566 | input_ch=4,
567 | num_freqs=16,
568 | skips=[4],
569 | sf_mag_div=1.0,
570 | ):
571 | """Init function for motion MLP.
572 |
573 | Args:
574 |       num_basis: number of motion basis functions
575 |       D: number of MLP layers
576 |       W: feature dimension of MLP layers
577 |       input_ch: number of input channels
578 |       num_freqs: number of frequencies for position encoding
579 |       skips: layers at which to inject skip connections
580 |       sf_mag_div: motion scaling factor
581 | """
582 | super(MotionMLP, self).__init__()
583 | self.D = D
584 | self.W = W
585 | self.input_ch = int(input_ch + input_ch * num_freqs * 2)
586 | self.skips = skips
587 | self.sf_mag_div = sf_mag_div
588 |
589 | self.xyzt_embed = PeriodicEmbed(max_freq=num_freqs, N_freq=num_freqs)
590 |
591 | self.pts_linears = nn.ModuleList(
592 | [nn.Linear(self.input_ch, W)]
593 | + [
594 | nn.Linear(W, W)
595 | if i not in self.skips
596 | else nn.Linear(W + self.input_ch, W)
597 | for i in range(D - 1)
598 | ]
599 | )
600 |
601 | self.coeff_linear = nn.Linear(W, num_basis * 3)
602 | self.coeff_linear.weight.data.fill_(0.0)
603 | self.coeff_linear.bias.data.fill_(0.0)
604 |
605 | def forward(self, x):
606 | input_pts = self.xyzt_embed(x)
607 |
608 | h = input_pts
609 | for i, l in enumerate(self.pts_linears):
610 | h = self.pts_linears[i](h)
611 | h = F.relu(h)
612 | if i in self.skips:
613 | h = torch.cat([input_pts, h], -1)
614 |
615 | # sf = nn.functional.tanh(self.sf_linear(h))
616 | pred_coeff = self.coeff_linear(h)
617 |
618 | return pred_coeff / self.sf_mag_div
619 |
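# Usage sketch for MotionMLP: it maps an (x, y, z, t) sample to num_basis * 3
# trajectory coefficients (all zero at initialization because of the weight/bias
# fill above). How the coefficients combine with the learned basis lives in the
# rendering code and is only sketched in the trailing comment; shapes here are
# illustrative.
import torch

mlp = MotionMLP(num_basis=6)
xyzt = torch.randn(1024, 4)
coeff = mlp(xyzt)
assert coeff.shape == (1024, 6 * 3)
# With a basis of assumed shape [num_frames, num_basis], a point's 3D offset at
# frame i could be recovered as:
#   offset_i = (coeff.view(-1, 6, 3) * basis[i][None, :, None]).sum(dim=1)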
--------------------------------------------------------------------------------