├── LICENSE
├── README.md
├── ckpts
│   └── Exp-syndney
│       ├── MPI_rendered_views.gif
│       ├── MPI_rendered_views.mp4
│       ├── canvas.png
│       └── canvas_depth.png
├── dataloaders
│   └── single_img_data.py
├── model
│   ├── Inpainter.py
│   ├── MPF.py
│   ├── TrainableFilter.py
│   └── VitExtractor.py
├── mpi
│   ├── homography_sampler.py
│   ├── mpi_rendering.py
│   └── rendering_utils.py
├── outpaint_rgbd.py
├── requirements.txt
├── scripts
│   └── train_all.sh
├── test_images
│   └── Syndney.jpg
├── train_inpainting.py
├── train_mpi.py
├── utils_for_train.py
└── warpback
    ├── networks.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tricky
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Welcome to SinMPI!
2 |
3 | ["SinMPI: Novel View Synthesis from a Single Image with Expanded Multiplane Images" (SIGGRAPH Asia 2023)](https://arxiv.org/abs/2312.11037).
4 |
5 | ## Quick demo
6 |
7 | ### 1. Prepare
8 |
9 | (1) Create a new conda environment and install the packages listed in requirements.txt.
10 |
11 | (2) Download the pretrained weights of the depth-aware inpainter from [ecweights](https://drive.google.com/drive/folders/1FZZ6laPuqEMSfrGvEWYaDZWEPaHvGm6r) and place them under 'warpback/ecweights/' (e.g., 'warpback/ecweights/xxx.pth').
12 |
13 | ### 2. Run demo
14 | ```
15 | sh scripts/train_all.sh
16 | ```
17 | This demo converts 'test_images/Syndney.jpg' into an expanded MPI and renders novel views, as in 'ckpts/Exp-Syndney-new/MPI_rendered_views.mp4'.
18 |
19 | ## What happens when running the demo?
20 |
21 | ### 1. Outpaint the input image
22 |
23 | In the above demo, we specify 'test_images/Syndney.jpg'
24 |
25 |
26 |
27 | as the input image, and then progressively outpaint it:
28 | ```
29 | CUDA_VISIBLE_DEVICES=$cuda python outpaint_rgbd.py \
30 | --width $width \
31 | --height $height \
32 | --ckpt_path $ckpt_path \
33 | --img_path $img_path \
34 | --extrapolate_times $extrapolate_times
35 | ```
36 | Then we get the outpainted image and its depth estimated by a monocular depth estimator (DPT):
37 |
38 |
39 |
40 | ### 2. Fine-tune the depth-aware inpainter and create pseudo-multi-view images
41 |
42 | ```
43 | CUDA_VISIBLE_DEVICES=$cuda python train_inpainting.py \
44 | --width $width \
45 | --height $height \
46 | --ckpt_path $ckpt_path \
47 | --img_path $img_path \
48 | --num_epochs 10 \
49 | --extrapolate_times $extrapolate_times \
50 | --batch_size 1 #--load_warp_pairs --debugging
51 | ```
52 |
53 | ### 3. Optimize the expanded MPI
54 |
55 | ```
56 | CUDA_VISIBLE_DEVICES=$cuda python train_mpi.py \
57 | --width $width \
58 | --height $height \
59 | --ckpt_path $ckpt_path \
60 | --img_path $img_path \
61 | --num_epochs 10 \
62 | --extrapolate_times $extrapolate_times \
63 | --batch_size 1 #--debugging #--resume
64 | ```
65 |
66 | After optimization, we render novel views:
67 |
68 |
69 |
70 | ## Toward better quality and robustness
71 |
72 | Note that the above demo is configured for a quick illustration (the rendering FPS is low). For better quality:
73 |
74 | #### Use more pseudo-multi-view images that cover larger areas.
75 | Increasing the sampling rate and the sampled areas helps optimize the MPI to higher quality. To add training and rendering view trajectories, modify 'dataloaders/single_img_data.py'.
76 | #### Train for more epochs.
77 |
78 | ## Cite our paper
79 |
80 | If you find our work helpful, please cite our paper. Thank you!
81 |
82 | ACM Reference Format:
83 | ```
84 | Guo Pu, Peng-Shuai Wang, and Zhouhui Lian. 2023. SinMPI: Novel View
85 | Synthesis from a Single Image with Expanded Multiplane Images. In SIGGRAPH Asia 2023 Conference Papers (SA Conference Papers '23), December
86 | 12–15, 2023, Sydney, NSW, Australia. ACM, New York, NY, USA, 10 pages.
87 | https://doi.org/10.1145/3610548.3618155
88 | ```
89 |
--------------------------------------------------------------------------------
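
A minimal sketch, assuming the demo above has been run via scripts/train_all.sh (which sets ckpt_path='Exp-Syndney'), of inspecting the cached output of the outpainting stage; the commented shapes are expectations based on outpaint_rgbd.py, not verified output:

```python
import torch

# outpaint_rgbd.py saves the expanded RGB canvas and its DPT depth as a tuple.
canvas, depth = torch.load("ckpts/Exp-Syndney/extrapolate_RGBDs.pkl", map_location="cpu")
print(canvas.shape)  # expected [1, 3, height * extrapolate_times, width * extrapolate_times]
print(depth.shape)   # expected [1, 1, height * extrapolate_times, width * extrapolate_times]
```
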
/ckpts/Exp-syndney/MPI_rendered_views.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/MPI_rendered_views.gif
--------------------------------------------------------------------------------
/ckpts/Exp-syndney/MPI_rendered_views.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/MPI_rendered_views.mp4
--------------------------------------------------------------------------------
/ckpts/Exp-syndney/canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/canvas.png
--------------------------------------------------------------------------------
/ckpts/Exp-syndney/canvas_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/canvas_depth.png
--------------------------------------------------------------------------------
/dataloaders/single_img_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | sys.path.append("..")
4 | import math
5 | import numpy as np
6 | import torch
7 | from torch.utils.data.dataset import Dataset
8 | from torchvision import transforms
9 | from PIL import Image
10 |
11 | def gen_swing_path(init_pose=torch.eye(4), num_frames=10, r_x=0.2, r_y=0, r_z=-0.2):
12 |     """Return a list of [4, 4] camera-to-world pose matrices."""
13 | t = torch.arange(num_frames) / (num_frames - 1)
14 | poses = init_pose.repeat(num_frames, 1, 1)
15 |
16 | swing = torch.eye(4).repeat(num_frames, 1, 1)
17 | swing[:, 0, 3] = r_x * torch.sin(2. * math.pi * t)
18 | swing[:, 1, 3] = r_y * torch.cos(2. * math.pi * t)
19 | swing[:, 2, 3] = r_z * (torch.cos(2. * math.pi * t) - 1.)
20 |
21 | for i in range(num_frames):
22 | poses[i, :, :] = poses[i, :, :] @ swing[i, :, :]
23 | return list(poses.unbind())
24 |
25 | def create_spheric_poses_along_y(n_poses=10):
26 |
27 | def spheric_pose_y(phi, radius=10):
28 | trans_t = lambda t : np.array([
29 | [1,0,0, math.sin(2. * math.pi * t) * radius],
30 | [0,1,0,0],
31 | [0,0,1,0],
32 | [0,0,0,1],
33 | ])
34 |
35 | # rotation along y
36 | rot_phi = lambda phi : np.array([
37 | [np.cos(phi),0, np.sin(phi),0],
38 | [0,1,0,0],
39 | [-np.sin(phi),0,np.cos(phi),0],
40 | [0,0,0,1],
41 | ])
42 |
43 | c2w = rot_phi(phi) @ trans_t(phi)
44 | c2w = np.array([[1,0,0,0],
45 | [0,1,0,0],
46 | [0,0,1,0],
47 | [0,0,0,1]]) @ c2w
48 | c2w = torch.tensor(c2w).float()
49 | return c2w
50 |
51 | def spheric_pose_x(phi, radius=10):
52 | trans_t = lambda t : np.array([
53 | [1,0,0,0],
54 | [0,1,0,math.sin(2. * math.pi * t) * radius * -1],
55 | [0,0,1,0],
56 | [0,0,0,1],
57 | ])
58 |
59 | # rotation along x
60 | rot_theta = lambda th : np.array([
61 | [1,0,0,0],
62 | [0,np.cos(th),-np.sin(th),0],
63 | [0,np.sin(th), np.cos(th),0],
64 | [0,0,0,1],
65 | ])
66 |
67 | c2w = rot_theta(phi) @ trans_t(phi)
68 | c2w = np.array([[1,0,0,0],
69 | [0,1,0,0],
70 | [0,0,1,0],
71 | [0,0,0,1]]) @ c2w
72 | c2w = torch.tensor(c2w).float()
73 | return c2w
74 |
75 | spheric_poses = []
76 | poses = gen_swing_path()
77 | spheric_poses += poses
78 |
79 | factor = 1
80 | y_angle = (1/16) * np.pi * factor
81 | x_angle = (1/16) * np.pi * factor
82 | x_radius = 0.1 * factor
83 | y_radius = 0.1 * factor
84 |
85 | # rotate left and right
86 | for th in np.linspace(0, -1 * y_angle, n_poses//2):
87 | spheric_poses += [spheric_pose_y(th, y_radius)]
88 |
89 | poses = gen_swing_path(spheric_poses[-1])
90 | spheric_poses += poses
91 |
92 | for th in np.linspace(-1 * y_angle, y_angle, n_poses)[:-1]:
93 | spheric_poses += [spheric_pose_y(th, y_radius)]
94 |
95 | poses = gen_swing_path(spheric_poses[-1])
96 | spheric_poses += poses
97 |
98 | for th in np.linspace(y_angle, 0, n_poses//2)[:-1]:
99 | spheric_poses += [spheric_pose_y(th, y_radius)]
100 |
101 | # rotate up and down
102 | for th in np.linspace(0, -1 * x_angle, n_poses//2):
103 | spheric_poses += [spheric_pose_x(th, x_radius)]
104 |
105 | poses = gen_swing_path(spheric_poses[-1])
106 | spheric_poses += poses
107 |
108 | for th in np.linspace(-1 * x_angle, x_angle, n_poses)[:-1]:
109 | spheric_poses += [spheric_pose_x(th, x_radius)]
110 |
111 | poses = gen_swing_path(spheric_poses[-1])
112 | spheric_poses += poses
113 |
114 | for th in np.linspace(x_angle, 0, n_poses//2)[:-1]:
115 | spheric_poses += [spheric_pose_x(th, x_radius)]
116 |
117 | return spheric_poses
118 |
119 |
120 | def convert(c2w, phi=0):
121 | # rot_along_y
122 | c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)
123 | rot = np.array([
124 | [np.cos(phi),0, np.sin(phi),0],
125 | [0,1,0,0],
126 | [-np.sin(phi),0,np.cos(phi),0],
127 | [0,0,0,1],
128 | ])
129 | return rot @ c2w
130 |
131 | class SinImgDataset(Dataset):
132 | def __init__(
133 | self,
134 | img_path,
135 | width=512,
136 | height=512,
137 | repeat_times = 1
138 | ):
139 | self.repeat_times = repeat_times
140 | self.img_wh = [width, height]
141 | self.transform = transforms.ToTensor()
142 | self.img_path = img_path
143 | self.read_meta()
144 |
145 | def read_meta(self):
146 |
147 | self.all_rgbs = []
148 | self.all_poses = []
149 | self.all_poses += create_spheric_poses_along_y()
150 |
151 | img_path = "test_images/" + self.img_path
152 | self.ref_img = self.load_img(img_path)
153 |
154 |
155 | def load_img(self, image_path):
156 | img = Image.open(image_path).convert('RGB')
157 | img = img.resize(self.img_wh, Image.LANCZOS)
158 | img = self.transform(img) # (3, h, w)
159 | img = img.unsqueeze(0).cuda()
160 | return img # [img:1*3*h*w]
161 |
162 |
163 | def __len__(self):
164 | return len(self.all_poses) * self.repeat_times
165 |
166 | def __getitem__(self, idx):
167 | sample = {
168 | 'cur_pose':self.all_poses[idx % len(self.all_poses)]
169 | }
170 | return sample
171 |
--------------------------------------------------------------------------------
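
A minimal usage sketch for dataloaders/single_img_data.py; a CUDA device is assumed because load_img moves the reference image to GPU, and img_path is resolved relative to 'test_images/':

```python
from dataloaders.single_img_data import SinImgDataset, gen_swing_path

# gen_swing_path returns a list of [4, 4] camera-to-world pose matrices.
poses = gen_swing_path(num_frames=10, r_x=0.2, r_y=0, r_z=-0.2)
assert len(poses) == 10 and poses[0].shape == (4, 4)

# SinImgDataset yields one training/rendering pose per item.
dataset = SinImgDataset(img_path="Syndney.jpg", width=512, height=512)
print(len(dataset), dataset[0]['cur_pose'].shape)  # pose count, torch.Size([4, 4])
```
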
/model/Inpainter.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | sys.path.append("..")
4 | import os
5 | import glob
6 | import numpy as np
7 | from skimage.feature import canny
8 | import torch
9 | import torch.nn.functional as F
10 | from torchvision import transforms
11 |
12 | from warpback.networks import get_edge_connect
13 |
14 | class InpaintingModule(torch.nn.Module):
15 | def __init__(
16 | self,
17 | data_root='',
18 | width=512,
19 | height=512,
20 | depth_dir_name="dpt_depth",
21 | device="cuda:0",
22 | trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle
23 | ec_weight_dir="warpback/ecweight",
24 | ):
25 | super(InpaintingModule, self).__init__()
26 | self.data_root = data_root
27 | self.depth_dir_name = depth_dir_name
28 | self.width = width
29 | self.height = height
30 | self.device = device
31 | self.trans_range = trans_range
32 | self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg"))
33 | self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png"))
34 |
35 | self.edge_model, self.inpaint_model, self.disp_model = get_edge_connect(ec_weight_dir)
36 | self.edge_model = self.edge_model.to(self.device)
37 | self.inpaint_model = self.inpaint_model.to(self.device)
38 | self.disp_model = self.disp_model.to(self.device)
39 |
40 | def preprocess_rgbd(self, image, disp):
41 | image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
42 | disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
43 | return image, disp
44 |
45 | def forward(self, image, disp, mask):
46 |
47 | image_gray = transforms.Grayscale()(image)
48 | edge = self.get_edge(image_gray, mask)
49 |
50 | mask_hole = 1 - mask
51 |
52 | # inpaint edge
53 |         edge_model_input = torch.cat([image_gray, edge, mask_hole], dim=1) # [b,3,h,w]: grayscale + edge + hole mask
54 | edge_inpaint = self.edge_model(edge_model_input) # [b,1,h,w]
55 |
56 | # inpaint RGB
57 | inpaint_model_input = torch.cat([image + mask_hole, edge_inpaint], dim=1)
58 | image_inpaint = self.inpaint_model(inpaint_model_input)
59 | image_merged = image * (1 - mask_hole) + image_inpaint * mask_hole
60 |
61 | # inpaint Disparity
62 | disp_model_input = torch.cat([disp + mask_hole, edge_inpaint], dim=1)
63 | disp_inpaint = self.disp_model(disp_model_input)
64 | disp_merged = disp * (1 - mask_hole) + disp_inpaint * mask_hole
65 |
66 | return image_merged, disp_merged
67 |
68 | def get_edge(self, image_gray, mask):
69 | image_gray_np = image_gray.squeeze(1).cpu().numpy() # [b,h,w]
70 | mask_bool_np = np.array(mask.squeeze(1).cpu(), dtype=np.bool_) # [b,h,w]
71 | edges = []
72 | for i in range(mask.shape[0]):
73 | cur_edge = canny(image_gray_np[i], sigma=2, mask=mask_bool_np[i])
74 | edges.append(torch.from_numpy(cur_edge).unsqueeze(0)) # [1,h,w]
75 | edge = torch.cat(edges, dim=0).unsqueeze(1).float() # [b,1,h,w]
76 | return edge.to(self.device)
77 |
78 |
--------------------------------------------------------------------------------
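
A minimal interface sketch for model/Inpainter.py; it assumes a CUDA device and that the pretrained EdgeConnect-style weights are already placed where warpback.networks.get_edge_connect expects them (see the README), and the random tensors stand in for warped RGB-D inputs:

```python
import torch
from model.Inpainter import InpaintingModule

inpainter = InpaintingModule(width=512, height=512)   # loads edge/inpaint/disp models onto cuda:0
image = torch.rand(1, 3, 512, 512, device="cuda:0")   # warped RGB with holes
disp = torch.rand(1, 1, 512, 512, device="cuda:0")    # warped disparity with holes
mask = torch.ones(1, 1, 512, 512, device="cuda:0")    # 1 = valid pixel, 0 = hole
image_filled, disp_filled = inpainter(image, disp, mask)
print(image_filled.shape, disp_filled.shape)          # expected [1, 3, 512, 512], [1, 1, 512, 512]
```
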
/model/MPF.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class MultiPlaneField(torch.nn.Module):
4 | def __init__(
5 | self,
6 | image_size=(256, 384),
7 | num_planes=12,
8 | assign_origin_planes=None,
9 | depth_range=None
10 | ):
11 | super(MultiPlaneField, self).__init__()
12 | # x:horizontal axis, y:vertical axis, z:inward axis
13 |
14 | self.num_planes = num_planes
15 | self.near, self.far = depth_range
16 |
17 | (self.plane_h, self.plane_w) = image_size
18 |
19 | self.planes_disp = torch.linspace(self.near, self.far, num_planes, requires_grad=False).unsqueeze(0).cuda() # [b,s]
20 | #[S:num_planes, H:plane_h , W:plane_w, 4:rgb+transparency]
21 |
22 | self.extrapolate_RGBDs = assign_origin_planes
23 |
24 | init_planes = self.assign_image_to_planes(self.extrapolate_RGBDs[0], self.extrapolate_RGBDs[1])
25 | self.planes_mid2 = init_planes
26 |
27 | init_val = torch.ones(1, num_planes, 4, self.plane_h, self.plane_w) * 0.1
28 | init_val[:,:,:4,:,:] *= 0.01
29 |
30 | self.planes_residual = torch.nn.Parameter(init_val)
31 |
32 |
33 | def assign_image_to_planes(self, ref_img, ref_disp):
34 | planes = torch.zeros(1, self.num_planes, 4, self.plane_h, self.plane_w, requires_grad=False).cuda()
35 |         # append a fourth (alpha/visibility) channel filled with self.far to the reference image
36 |         ref_img = torch.cat([ref_img, torch.ones_like(ref_disp) * self.far], dim=1) # [1,3+1,h,w]
37 | depth_levels = torch.linspace(self.near, self.far, self.num_planes).cuda()
38 |
39 | planes_masks = []
40 | for i in range(len(depth_levels)):
41 | cur_depth_mask = torch.where(ref_disp < depth_levels[i],
42 | torch.ones_like(ref_disp).cuda(),
43 | torch.zeros_like(ref_disp).cuda())
44 | planes_masks.append(cur_depth_mask)
45 | cur_depth_pixels = ref_img * cur_depth_mask.repeat(1,4,1,1)
46 | cur_depth_pixels = cur_depth_pixels.unsqueeze(0)
47 | planes[:,i:i+1,:,:,:] = cur_depth_pixels#[1,1,4,h,w]
48 |
49 |             ref_disp = ref_disp + cur_depth_mask * (self.far + 1)  # push already-assigned pixels beyond far so later planes skip them
50 |
51 | return planes
52 |
53 | def forward(self):
54 | planes = self.planes_mid2
55 | pred_planes = self.planes_residual
56 |
57 | residual_mask = torch.where(planes > 0,
58 | torch.zeros_like(planes).cuda(),
59 | torch.ones_like(planes).cuda())
60 |
61 | x = pred_planes * residual_mask + planes
62 | return x
63 |
--------------------------------------------------------------------------------
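
A minimal construction sketch for model/MPF.py; a CUDA device is assumed, and the normalized depth_range and random RGB-D canvas are illustrative placeholders rather than the values used by train_mpi.py:

```python
import torch
from model.MPF import MultiPlaneField

rgb = torch.rand(1, 3, 256, 384, device="cuda:0")     # expanded RGB canvas
disp = torch.rand(1, 1, 256, 384, device="cuda:0")    # its normalized depth/disparity map
mpf = MultiPlaneField(image_size=(256, 384),
                      num_planes=12,
                      assign_origin_planes=(rgb, disp),
                      depth_range=(0.0, 1.0)).cuda()
planes = mpf()   # [1, 12, 4, 256, 384]: RGB + transparency per plane
print(planes.shape)
```
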
/model/TrainableFilter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class GaussianFilter(nn.Module):
6 | def __init__(self, ksize=5, sigma=None):
7 | super(GaussianFilter, self).__init__()
8 |         # initialize the Gaussian kernel
9 | if sigma is None:
10 | sigma = 0.3 * ((ksize-1) / 2.0 - 1) + 0.8
11 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
12 | x_coord = torch.arange(ksize)
13 | x_grid = x_coord.repeat(ksize).view(ksize, ksize)
14 | y_grid = x_grid.t()
15 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
16 |
17 | # Calculate the 2-dimensional gaussian kernel
18 | center = ksize // 2
19 | weight = torch.exp(-torch.sum((xy_grid - center)**2., dim=-1) / (2*sigma**2))
20 | # Make sure sum of values in gaussian kernel equals 1.
21 | weight /= torch.sum(weight)
22 | self.gaussian_weight = weight
23 |
24 |     def forward(self, x):
25 |         return F.conv2d(x, self.gaussian_weight.to(x.device).expand(x.size(1), 1, -1, -1), padding=self.gaussian_weight.size(-1) // 2, groups=x.size(1))  # apply the kernel as a depthwise Gaussian blur
26 |
27 | class TrainableFilter(nn.Module):
28 | def __init__(self, ksize=5, sigma_space=None, sigma_density=1):
29 | super(TrainableFilter, self).__init__()
30 | # initialization
31 | if sigma_space is None:
32 | self.sigma_space = 0.3 * ((ksize-1) * 0.5 - 1) + 0.8
33 | else:
34 | self.sigma_space = sigma_space
35 | if sigma_density is None:
36 | self.sigma_density = self.sigma_space
37 | else:
38 | self.sigma_density = sigma_density
39 |
40 | self.pad = (ksize-1) // 2
41 | self.ksize = ksize
42 | # get the spatial gaussian weight
43 | self.weight_space = GaussianFilter(ksize=self.ksize, sigma=self.sigma_space).gaussian_weight.cuda()
44 |         # make the spatial Gaussian weight a trainable parameter
45 | self.weight_space = torch.nn.Parameter(self.weight_space)
46 |
47 | def forward(self, x):
48 | # Extracts sliding local patches from a batched input tensor.
49 | x_pad = F.pad(x, pad=[self.pad, self.pad, self.pad, self.pad], mode='reflect')
50 | x_patches = x_pad.unfold(2, self.ksize, 1).unfold(3, self.ksize, 1)
51 | patch_dim = x_patches.dim()
52 |
53 | # Calculate the 2-dimensional gaussian kernel
54 | diff_density = x_patches - x.unsqueeze(-1).unsqueeze(-1)
55 | weight_density = torch.exp(-(diff_density ** 2) / (2 * self.sigma_density ** 2))
56 | # Normalization
57 | # weight_density /= weight_density.sum(dim=(-1, -2), keepdim=True)
58 | weight_density = weight_density / weight_density.sum(dim=(-1, -2), keepdim=True)
59 | # print(weight_density.shape)
60 |
61 | # Keep same shape with weight_density
62 | weight_space_dim = (patch_dim - 2) * (1, ) + (self.ksize, self.ksize)
63 | weight_space = self.weight_space.view(*weight_space_dim).expand_as(weight_density)
64 |
65 | # get the final kernel weight
66 | weight = weight_density * weight_space
67 | weight_sum = weight.sum(dim=(-1, -2))
68 | x = (weight * x_patches).sum(dim=(-1, -2)) / weight_sum
69 | return x
70 |
--------------------------------------------------------------------------------
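
A minimal usage sketch for model/TrainableFilter.py, a bilateral-style filter whose 5x5 spatial kernel is a trainable parameter; a CUDA device is assumed because the kernel is created on GPU:

```python
import torch
from model.TrainableFilter import TrainableFilter

filt = TrainableFilter(ksize=5, sigma_density=1)
disp = torch.rand(1, 1, 64, 64, device="cuda:0")   # e.g. a depth map to be smoothed edge-awarely
out = filt(disp)                                   # same shape as the input
print(out.shape, filt.weight_space.shape)          # torch.Size([1, 1, 64, 64]) torch.Size([5, 5])
```
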
/model/VitExtractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def attn_cosine_sim(x, eps=1e-08):
6 | x = x[0] # TEMP: getting rid of redundant dimension, TBF
7 | norm1 = x.norm(dim=2, keepdim=True)
8 | factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
9 | sim_matrix = (x @ x.permute(0, 2, 1)) / factor
10 | return sim_matrix
11 |
12 |
13 | class VitExtractor(nn.Module):
14 | BLOCK_KEY = 'block'
15 | ATTN_KEY = 'attn'
16 | PATCH_IMD_KEY = 'patch_imd'
17 | QKV_KEY = 'qkv'
18 | KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
19 |
20 | def __init__(self, model_name, device):
21 | super().__init__()
22 | self.model = torch.hub.load(
23 | 'facebookresearch/dino:main', model_name).to(device)
24 | self.model.eval()
25 | self.model_name = model_name
26 | self.hook_handlers = []
27 | self.layers_dict = {}
28 | self.outputs_dict = {}
29 | for key in VitExtractor.KEY_LIST:
30 | self.layers_dict[key] = []
31 | self.outputs_dict[key] = []
32 | self._init_hooks_data()
33 |
34 | def _init_hooks_data(self):
35 | self.layers_dict[VitExtractor.BLOCK_KEY] = [
36 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
37 | self.layers_dict[VitExtractor.ATTN_KEY] = [
38 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
39 | self.layers_dict[VitExtractor.QKV_KEY] = [
40 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
41 | self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [
42 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
43 | for key in VitExtractor.KEY_LIST:
44 | # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
45 | self.outputs_dict[key] = []
46 |
47 | def _register_hooks(self, **kwargs):
48 | for block_idx, block in enumerate(self.model.blocks):
49 | if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
50 | self.hook_handlers.append(
51 | block.register_forward_hook(self._get_block_hook()))
52 | if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
53 | self.hook_handlers.append(
54 | block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
55 | if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
56 | self.hook_handlers.append(
57 | block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
58 | if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
59 | self.hook_handlers.append(
60 | block.attn.register_forward_hook(self._get_patch_imd_hook()))
61 |
62 | def _clear_hooks(self):
63 | for handler in self.hook_handlers:
64 | handler.remove()
65 | self.hook_handlers = []
66 |
67 | def _get_block_hook(self):
68 | def _get_block_output(model, input, output):
69 | self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
70 |
71 | return _get_block_output
72 |
73 | def _get_attn_hook(self):
74 | def _get_attn_output(model, inp, output):
75 | self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
76 |
77 | return _get_attn_output
78 |
79 | def _get_qkv_hook(self):
80 | def _get_qkv_output(model, inp, output):
81 | self.outputs_dict[VitExtractor.QKV_KEY].append(output)
82 |
83 | return _get_qkv_output
84 |
85 | # TODO: CHECK ATTN OUTPUT TUPLE
86 | def _get_patch_imd_hook(self):
87 | def _get_attn_output(model, inp, output):
88 | self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
89 |
90 | return _get_attn_output
91 |
92 | def get_feature_from_input(self, input_img): # List([B, N, D])
93 | self._register_hooks()
94 | self.model(input_img)
95 | feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
96 | self._clear_hooks()
97 | self._init_hooks_data()
98 | return feature
99 |
100 | def get_qkv_feature_from_input(self, input_img):
101 | self._register_hooks()
102 | self.model(input_img)
103 | feature = self.outputs_dict[VitExtractor.QKV_KEY]
104 | self._clear_hooks()
105 | self._init_hooks_data()
106 | return feature
107 |
108 | def get_attn_feature_from_input(self, input_img):
109 | self._register_hooks()
110 | self.model(input_img)
111 | feature = self.outputs_dict[VitExtractor.ATTN_KEY]
112 | self._clear_hooks()
113 | self._init_hooks_data()
114 | return feature
115 |
116 | def get_patch_size(self):
117 | return 8 if "8" in self.model_name else 16
118 |
119 | def get_width_patch_num(self, input_img_shape):
120 | b, c, h, w = input_img_shape
121 | patch_size = self.get_patch_size()
122 | return w // patch_size
123 |
124 | def get_height_patch_num(self, input_img_shape):
125 | b, c, h, w = input_img_shape
126 | patch_size = self.get_patch_size()
127 | return h // patch_size
128 |
129 | def get_patch_num(self, input_img_shape):
130 | patch_num = 1 + (self.get_height_patch_num(input_img_shape)
131 | * self.get_width_patch_num(input_img_shape))
132 | return patch_num
133 |
134 | def get_head_num(self):
135 | if "dino" in self.model_name:
136 | return 6 if "s" in self.model_name else 12
137 | return 6 if "small" in self.model_name else 12
138 |
139 | def get_embedding_dim(self):
140 | if "dino" in self.model_name:
141 | return 384 if "s" in self.model_name else 768
142 | return 384 if "small" in self.model_name else 768
143 |
144 | def get_queries_from_qkv(self, qkv, input_img_shape):
145 | patch_num = self.get_patch_num(input_img_shape)
146 | head_num = self.get_head_num()
147 | embedding_dim = self.get_embedding_dim()
148 | q = qkv.reshape(patch_num, 3, head_num, embedding_dim //
149 | head_num).permute(1, 2, 0, 3)[0]
150 | return q
151 |
152 | def get_keys_from_qkv(self, qkv, input_img_shape):
153 | patch_num = self.get_patch_num(input_img_shape)
154 | head_num = self.get_head_num()
155 | embedding_dim = self.get_embedding_dim()
156 | k = qkv.reshape(patch_num, 3, head_num, embedding_dim //
157 | head_num).permute(1, 2, 0, 3)[1]
158 | return k
159 |
160 | def get_values_from_qkv(self, qkv, input_img_shape):
161 | patch_num = self.get_patch_num(input_img_shape)
162 | head_num = self.get_head_num()
163 | embedding_dim = self.get_embedding_dim()
164 | v = qkv.reshape(patch_num, 3, head_num, embedding_dim //
165 | head_num).permute(1, 2, 0, 3)[2]
166 | return v
167 |
168 | def get_keys_from_input(self, input_img, layer_num):
169 | qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
170 | keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
171 | return keys
172 |
173 | def get_keys_self_sim_from_input(self, input_img, layer_num):
174 | keys = self.get_keys_from_input(input_img, layer_num=layer_num)
175 | h, t, d = keys.shape
176 | concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
177 | ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
178 | return ssim_map
179 |
--------------------------------------------------------------------------------
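
A minimal usage sketch for model/VitExtractor.py; the first call downloads the DINO 'dino_vits16' weights through torch.hub, so network access and a CUDA device are assumed, and in practice the input should be ImageNet-normalized:

```python
import torch
from model.VitExtractor import VitExtractor

extractor = VitExtractor(model_name='dino_vits16', device='cuda:0')
img = torch.rand(1, 3, 224, 224, device='cuda:0')
with torch.no_grad():
    self_sim = extractor.get_keys_self_sim_from_input(img, layer_num=11)
print(self_sim.shape)  # [1, N, N] with N = 1 + (224 // 16) ** 2 = 197 tokens
```
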
/mpi/homography_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.spatial.transform import Rotation
4 |
5 |
6 | def inverse(matrices):
7 | """
8 |     torch.inverse() sometimes produces outputs containing NaN when the batch size is 2.
9 |     Ref https://github.com/pytorch/pytorch/issues/47272
10 |     This function keeps inverting the matrices until the result is NaN-free or the maximum number of tries is reached.
11 |     :param matrices: Bx3x3
12 | """
13 | inverse = None
14 | max_tries = 5
15 | while (inverse is None) or (torch.isnan(inverse)).any():
16 | torch.cuda.synchronize()
17 | inverse = torch.inverse(matrices)
18 |
19 |         # Break out of the loop when the inverse succeeded or there are no more tries
20 | max_tries -= 1
21 | if max_tries == 0:
22 | break
23 |
24 | # Raise an Exception if the inverse contains nan
25 | if (torch.isnan(inverse)).any():
26 | raise Exception("Matrix inverse contains nan!")
27 | return inverse
28 |
29 |
30 | class HomographySample:
31 | def __init__(self, H_tgt, W_tgt, device=None):
32 | if device is None:
33 | self.device = torch.device("cpu")
34 | else:
35 | self.device = device
36 |
37 | self.Height_tgt = H_tgt
38 | self.Width_tgt = W_tgt
39 | self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
40 | self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous() # 3xHxW
41 |
42 | self.n = self.plane_normal_generation(self.device)
43 |
44 | @staticmethod
45 | def grid_generation(H, W, device):
46 | x = np.linspace(0, W-1, W)
47 | y = np.linspace(0, H-1, H)
48 | xv, yv = np.meshgrid(x, y) # HxW
49 | xv = torch.from_numpy(xv.astype(np.float32)).to(dtype=torch.float32, device=device)
50 | yv = torch.from_numpy(yv.astype(np.float32)).to(dtype=torch.float32, device=device)
51 | ones = torch.ones_like(xv)
52 | meshgrid = torch.stack((xv, yv, ones), dim=2) # HxWx3
53 | return meshgrid
54 |
55 | @staticmethod
56 | def plane_normal_generation(device):
57 | n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
58 | return n
59 |
60 | @staticmethod
61 | def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
62 | """
63 |         Return a rotation matrix rot_mtx that transforms tgt points into the src frame,
64 |         i.e., rot_mtx * p_tgt = p_src.
65 |         Therefore the x/y/z angles are negated before building the rotation.
66 |         :param x_angle: rotation about the x axis (roll)
67 |         :param y_angle: rotation about the y axis (pitch)
68 |         :param z_angle: rotation about the z axis (yaw)
69 |         :return: 3x3 rotation matrix (float32 numpy array)
70 | """
71 | r = Rotation.from_euler(seq,
72 | [-x_angle, -y_angle, -z_angle],
73 | degrees=degrees)
74 | rot_mtx = r.as_matrix().astype(np.float32)
75 | return rot_mtx
76 |
77 |
78 | def sample(self, src_BCHW, d_src_B,
79 | G_tgt_src,
80 | K_src_inv, K_tgt):
81 | """
82 | Coordinate system: x, y are the image directions, z is pointing to depth direction
83 | :param src_BCHW: torch tensor float, 0-1, rgb/rgba. BxCxHxW
84 | Assume to be at position P=[I|0]
85 | :param d_src_B: distance of image plane to src camera origin
86 | :param G_tgt_src: Bx4x4
87 | :param K_src_inv: Bx3x3
88 | :param K_tgt: Bx3x3
89 | :return: tgt_BCHW
90 | """
91 | # parameter processing ------ begin ------
92 | B, channels, Height_src, Width_src = src_BCHW.size(0), src_BCHW.size(1), src_BCHW.size(2), src_BCHW.size(3)
93 | R_tgt_src = G_tgt_src[:, 0:3, 0:3]
94 | t_tgt_src = G_tgt_src[:, 0:3, 3]
95 |
96 | Height_tgt = self.Height_tgt
97 | Width_tgt = self.Width_tgt
98 |
99 | R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
100 | t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
101 | K_src_inv = K_src_inv.to(device=src_BCHW.device)
102 | K_tgt = K_tgt.to(device=src_BCHW.device)
103 | # parameter processing ------ end ------
104 |
105 |         # the goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel
106 | # so we compute H_tgt_src first, and then inverse
107 | n = self.n.to(device=src_BCHW.device)
108 | n = n.unsqueeze(0).repeat(B, 1) # Bx3
109 | # Bx3x3 - (Bx3x1 * Bx1x3)
110 | # note here we use -d_src, because the plane function is n^T * X - d_src = 0
111 | d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3) # B -> Bx3x3
112 | R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
113 | H_tgt_src = torch.matmul(K_tgt,
114 | torch.matmul(R_tnd, K_src_inv))
115 |
116 | # TODO: fix cuda inverse
117 | with torch.no_grad():
118 | H_src_tgt = inverse(H_tgt_src)
119 |
120 | # create tgt image grid, and map to src
121 | meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
122 | # 3xHxW -> Bx3xHxW
123 | meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)
124 |
125 |         # warp meshgrid_tgt_homo to meshgrid_src
126 | meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1) # Bx3xHW
127 | meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N) # Bx3x3 * Bx3xHW -> Bx3xHW
128 | # Bx3xHW -> Bx3xHxW -> BxHxWx3
129 | meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
130 | meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:] # BxHxWx2
131 |
132 | valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
133 | meshgrid_src[:, :, :, 0] > -1)
134 | valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
135 | meshgrid_src[:, :, :, 1] > -1)
136 | valid_mask = torch.logical_and(valid_mask_x, valid_mask_y) # BxHxW
137 |
138 | # sample from src_BCHW
139 | # normalize meshgrid_src to [-1,1]
140 | meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0]+0.5) / (Width_src * 0.5) - 1
141 | meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1]+0.5) / (Height_src * 0.5) - 1
142 | tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
143 | align_corners=False)
144 | # BxCxHxW, BxHxW
145 | return tgt_BCHW, valid_mask
146 |
--------------------------------------------------------------------------------
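
A minimal usage sketch for mpi/homography_sampler.py; a CUDA device is assumed (the inverse helper synchronizes CUDA), and the pixel-space intrinsics are illustrative, obtained by scaling the repo's normalized 0.58/0.5 intrinsics by the image size:

```python
import torch
from mpi.homography_sampler import HomographySample

B, H, W = 1, 256, 384
device = torch.device("cuda:0")
sampler = HomographySample(H, W, device=device)

src = torch.rand(B, 4, H, W, device=device)             # one RGBA plane of the MPI
d_src = torch.full((B,), 2.0, device=device)            # its depth from the source camera
G_tgt_src = torch.eye(4, device=device).unsqueeze(0)    # target-from-source pose (identity here)
K = torch.tensor([[0.58 * W, 0.0, 0.5 * W],
                  [0.0, 0.58 * H, 0.5 * H],
                  [0.0, 0.0, 1.0]], device=device).unsqueeze(0)
tgt, valid = sampler.sample(src, d_src, G_tgt_src, torch.inverse(K), K)
print(tgt.shape, valid.shape)                           # [1, 4, 256, 384], [1, 256, 384]
```
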
/mpi/mpi_rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mpi.homography_sampler import HomographySample
4 | from mpi.rendering_utils import transform_G_xyz
5 |
6 |
7 | def render(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW):
8 | imgs_syn, weights = plane_volume_rendering(sigma_BS1HW, rgb_BS3HW)
9 | return imgs_syn
10 |
11 | def plane_volume_rendering(alpha_BK1HW, value_BKCHW):
12 | B, K, _, H, W = alpha_BK1HW.size()
13 | alpha_comp_cumprod = torch.cumprod(1 - alpha_BK1HW, dim=1) # BxKx1xHxW
14 | preserve_ratio = torch.cat((torch.ones((B, 1, 1, H, W), dtype=alpha_BK1HW.dtype, device=alpha_BK1HW.device),
15 | alpha_comp_cumprod[:, 0:K-1, :, :, :]), dim=1) # BxKx1xHxW
16 | weights = alpha_BK1HW * preserve_ratio # BxKx1xHxW
17 | value_composed = torch.sum(value_BKCHW * weights, dim=1, keepdim=False) # Bx3xHxW
18 | return value_composed, weights
19 |
20 | def get_src_xyz_from_plane_disparity(meshgrid_src_homo,
21 | mpi_disparity_src,
22 | K_src_inv):
23 | """
24 |
25 | :param meshgrid_src_homo: 3xHxW
26 | :param mpi_disparity_src: BxS
27 | :param K_src_inv: Bx3x3
28 | :return:
29 | """
30 | B, S = mpi_disparity_src.size()
31 | H, W = meshgrid_src_homo.size(1), meshgrid_src_homo.size(2)
32 | mpi_depth_src = torch.reciprocal(mpi_disparity_src) # BxS
33 |
34 | K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).reshape(B * S, 3, 3)
35 |
36 | # 3xHxW -> BxSx3xHxW
37 | meshgrid_src_homo = meshgrid_src_homo.unsqueeze(0).unsqueeze(1).repeat(B, S, 1, 1, 1)
38 | meshgrid_src_homo_Bs3N = meshgrid_src_homo.reshape(B * S, 3, -1)
39 | xyz_src = torch.matmul(K_src_inv_Bs33, meshgrid_src_homo_Bs3N) # BSx3xHW
40 | xyz_src = xyz_src.reshape(B, S, 3, H * W) * mpi_depth_src.unsqueeze(2).unsqueeze(3) # BxSx3xHW
41 | xyz_src_BS3HW = xyz_src.reshape(B, S, 3, H, W)
42 |
43 | return xyz_src_BS3HW
44 |
45 |
46 | def get_tgt_xyz_from_plane_disparity(xyz_src_BS3HW,
47 | G_tgt_src):
48 | """
49 |
50 | :param xyz_src_BS3HW: BxSx3xHxW
51 | :param G_tgt_src: Bx4x4
52 | :return:
53 | """
54 | B, S, _, H, W = xyz_src_BS3HW.size()
55 | G_tgt_src_Bs33 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).reshape(B*S, 4, 4)
56 | xyz_tgt = transform_G_xyz(G_tgt_src_Bs33, xyz_src_BS3HW.reshape(B*S, 3, H*W)) # Bsx3xHW
57 | xyz_tgt_BS3HW = xyz_tgt.reshape(B, S, 3, H, W) # BxSx3xHxW
58 | return xyz_tgt_BS3HW
59 |
60 |
61 | def render_tgt_rgb_depth(H_sampler: HomographySample,
62 | mpi_rgb_src,
63 | mpi_sigma_src,
64 | mpi_disparity_src,
65 | xyz_tgt_BS3HW,
66 | G_tgt_src,
67 | K_src_inv, K_tgt,
68 | only_render_in_fov=False,
69 | center_top_left=None,
70 | infov_size=512):
71 | """
72 | :param H_sampler:
73 | :param mpi_rgb_src: BxSx3xHxW
74 | :param mpi_sigma_src: BxSx1xHxW
75 | :param mpi_disparity_src: BxS
76 | :param xyz_tgt_BS3HW: BxSx3xHxW
77 | :param G_tgt_src: Bx4x4
78 | :param K_src_inv: Bx3x3
79 | :param K_tgt: Bx3x3
80 | :return:
81 | """
82 | B, S, _, H, W = mpi_rgb_src.size()
83 | mpi_depth_src = torch.reciprocal(mpi_disparity_src) # BxS
84 |
85 | # note that here we concat the mpi_src with xyz_tgt, because H_sampler will sample them for tgt frame
86 | # mpi_src is the same in whatever frame, but xyz has to be in tgt frame
87 | mpi_xyz_src = torch.cat((mpi_rgb_src, mpi_sigma_src, xyz_tgt_BS3HW), dim=2) # BxSx(3+1+3)xHxW
88 |
89 | # homography warping of mpi_src into tgt frame
90 | G_tgt_src_Bs44 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 4, 4) # Bsx4x4
91 | K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3) # Bsx3x3
92 | K_tgt_Bs33 = K_tgt.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3) # Bsx3x3
93 |
94 | # BsxCxHxW, BsxHxW
95 | tgt_mpi_xyz_BsCHW, tgt_mask_BsHW = H_sampler.sample(mpi_xyz_src.view(B*S, 7, H, W),
96 | mpi_depth_src.view(B*S),
97 | G_tgt_src_Bs44,
98 | K_src_inv_Bs33,
99 | K_tgt_Bs33)
100 |
101 | # mpi composition
102 | if only_render_in_fov:
103 | tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W)
104 | tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, center_top_left[0]:center_top_left[0] + infov_size, center_top_left[1]:center_top_left[1] + infov_size]
105 | tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, center_top_left[0]:center_top_left[0] + infov_size, center_top_left[1]:center_top_left[1] + infov_size]
106 | tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, center_top_left[0]:center_top_left[0] + infov_size, center_top_left[1]:center_top_left[1] + infov_size]
107 | H = W = infov_size
108 | else:
109 | tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W)
110 | tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, :, :]
111 | tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, :, :]
112 | tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, :, :]
113 |
114 | # Bx3xHxW, Bx1xHxW, Bx1xHxW
115 | tgt_z_BS1HW = tgt_xyz_BS3HW[:, :, -1:]
116 | tgt_sigma_BS1HW = torch.where(tgt_z_BS1HW >= 0,
117 | tgt_sigma_BS1HW,
118 | torch.zeros_like(tgt_sigma_BS1HW, device=tgt_sigma_BS1HW.device))
119 | tgt_rgb_syn = render(tgt_rgb_BS3HW, tgt_sigma_BS1HW, tgt_xyz_BS3HW)
120 |
121 | return tgt_rgb_syn
122 |
--------------------------------------------------------------------------------
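
A minimal sketch of the alpha compositing performed by plane_volume_rendering in mpi/mpi_rendering.py; the random planes are placeholders, and shapes follow the BxSxCxHxW convention used in this file:

```python
import torch
from mpi.mpi_rendering import plane_volume_rendering

B, S, H, W = 1, 12, 64, 64
rgb = torch.rand(B, S, 3, H, W)      # per-plane colors
alpha = torch.rand(B, S, 1, H, W)    # per-plane transparency in [0, 1]
rgb_composed, weights = plane_volume_rendering(alpha, rgb)
print(rgb_composed.shape, weights.shape)             # [1, 3, 64, 64], [1, 12, 1, 64, 64]
assert torch.all(weights.sum(dim=1) <= 1.0 + 1e-5)   # front-to-back weights never exceed 1
```
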
/mpi/rendering_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def transform_G_xyz(G, xyz, is_return_homo=False):
5 | """
6 |
7 | :param G: Bx4x4
8 | :param xyz: Bx3xN
9 | :return:
10 | """
11 | assert len(G.size()) == len(xyz.size())
12 | if len(G.size()) == 2:
13 | G_B44 = G.unsqueeze(0)
14 | xyz_B3N = xyz.unsqueeze(0)
15 | else:
16 | G_B44 = G
17 | xyz_B3N = xyz
18 | xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1)
19 | G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
20 | if is_return_homo:
21 | return G_xyz_B4N
22 | else:
23 | return G_xyz_B4N[:, 0:3, :]
24 |
25 |
26 | def gather_pixel_by_pxpy(img, pxpy):
27 | """
28 |
29 | :param img: Bx3xHxW
30 | :param pxpy: Bx2xN
31 | :return:
32 | """
33 | with torch.no_grad():
34 | B, C, H, W = img.size()
35 |         # round float coordinates to integer pixel indices; cast integer inputs directly
36 |         pxpy_int = (torch.round(pxpy).to(torch.int64) if pxpy.dtype == torch.float32
37 |                     else pxpy.to(torch.int64))
38 | pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W-1)
39 | pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H-1)
40 | pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :] # Bx1xN_pt
41 | rgb = torch.gather(img.view(B, C, H * W), dim=2,
42 | index=pxpy_idx.repeat(1, C, 1)) # BxCxN_pt
43 | return rgb
44 |
45 |
46 | def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
47 | """
48 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far)
49 |     :param batch_size:
50 |     :param disparity_np: bin edges (S+1,), from large to small disparity
51 |     :param device:
52 |     :return: BxS stratified disparity samples
53 | """
54 | assert disparity_np[0] > disparity_np[-1]
55 | S = disparity_np.shape[0] - 1
56 |
57 | B = batch_size
58 | bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device) # S+1
59 | interval = bin_edges[1:] - bin_edges[0:-1] # S
60 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS
61 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS
62 | interval = interval.unsqueeze(0).repeat(B, 1) # S -> BxS
63 |
64 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS
65 | disparity_array = bin_edges_start + interval * random_float
66 | return disparity_array # BxS
67 |
68 |
69 | def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
70 | """
71 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far)
72 | :param start:
73 | :param end:
74 | :param num_bins:
75 | :return:
76 | """
77 | assert start > end
78 |
79 | B, S = batch_size, num_bins
80 | bin_edges = torch.linspace(start, end, num_bins+1, dtype=torch.float32, device=device) # S+1
81 | interval = bin_edges[1] - bin_edges[0] # scalar
82 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS
83 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS
84 |
85 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS
86 | disparity_array = bin_edges_start + interval * random_float
87 | return disparity_array # BxS
88 |
89 |
90 | def sample_pdf(values, weights, N_samples):
91 | """
92 | draw samples from distribution approximated by values and weights.
93 | the probability distribution can be denoted as weights = p(values)
94 | :param values: Bx1xNxS
95 | :param weights: Bx1xNxS
96 | :param N_samples: number of sample to draw
97 | :return:
98 | """
99 | B, N, S = weights.size(0), weights.size(2), weights.size(3)
100 | assert values.size() == (B, 1, N, S)
101 |
102 | # convert values to bin edges
103 | bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5 # Bx1xNxS-1
104 | bin_edges = torch.cat((values[:, :, :, 0:1],
105 | bin_edges,
106 | values[:, :, :, -1:]), dim=3) # Bx1xNxS+1
107 |
108 | pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5) # Bx1xNxS
109 | cdf = torch.cumsum(pdf, dim=3) # Bx1xNxS
110 | cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
111 | cdf), dim=3) # Bx1xNxS+1
112 |
113 | # uniform sample over the cdf values
114 | u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device) # Bx1xNxN_samples
115 |
116 | # get the index on the cdf array
117 | cdf_idx = torch.searchsorted(cdf, u, right=True) # Bx1xNxN_samples
118 | cdf_idx_lower = torch.clamp(cdf_idx-1, min=0) # Bx1xNxN_samples
119 | cdf_idx_upper = torch.clamp(cdf_idx, max=S) # Bx1xNxN_samples
120 |
121 | # linear approximation for each bin
122 | cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3) # Bx1xNx(N_samplesx2)
123 | cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2)
124 | cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
125 | bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2)
126 | bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
127 |
128 | # avoid zero cdf_intervals
129 | cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples
130 | bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0] # Bx1xNxN_samples
131 | u_cdf_lower = u - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples
132 | # there is the case that cdf_interval = 0, caused by the cdf_idx_lower/upper clamp above, need special handling
133 | t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
134 | t = torch.where(cdf_intervals <= 1e-4,
135 | torch.full_like(u_cdf_lower, 0.5),
136 | t)
137 |
138 | samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals
139 | return samples
140 |
--------------------------------------------------------------------------------
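
A minimal sketch of drawing stratified disparity samples with mpi/rendering_utils.py; the start/end disparities are illustrative placeholders:

```python
import torch
from mpi.rendering_utils import uniformly_sample_disparity_from_linspace_bins

disp = uniformly_sample_disparity_from_linspace_bins(
    batch_size=1, num_bins=12, start=1.0, end=0.01, device=torch.device("cpu"))
print(disp.shape)                              # [1, 12]
assert torch.all(disp[:, :-1] > disp[:, 1:])   # disparity decreases from near to far planes
```
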
/outpaint_rgbd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import random
4 | import os
5 | import torchvision
6 | import torchvision.transforms as T
7 |
8 | # pytorch-lightning
9 | from pytorch_lightning import LightningModule, Trainer
10 |
11 | from dataloaders.single_img_data import SinImgDataset
12 | from diffusers import StableDiffusionInpaintPipeline
13 | from utils_for_train import tensor_to_depth
14 | from pytorch_lightning import seed_everything
15 |
16 | class SDOutpainter(LightningModule):
17 | def __init__(self, opt):
18 | super(SDOutpainter, self).__init__()
19 |
20 | self.opt = opt
21 | self.loss = []
22 |
23 | W, H = self.opt.width, self.opt.height
24 | self.save_base_dir = f'ckpts/{opt.ckpt_path}'
25 | if not os.path.exists(self.save_base_dir):
26 | os.makedirs(self.save_base_dir)
27 |
28 | self.dataset_type = 'SinImgDataset'
29 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=1)
30 | self.extrapolate_times = self.opt.extrapolate_times
31 |
32 | if self.extrapolate_times == 3: # extend w = 3 * w
33 | self.center_top_left = (self.opt.height, self.opt.width)
34 | elif self.extrapolate_times == 2: # extend w = 2 * w
35 | self.center_top_left = (self.opt.height//2, self.opt.width//2)
36 | elif self.extrapolate_times == 1:
37 | self.center_top_left = (0, 0)
38 |
39 | with torch.no_grad():
40 | self.sd = StableDiffusionInpaintPipeline.from_pretrained(
41 | "stabilityai/stable-diffusion-2-inpainting",
42 | torch_dtype=torch.float16,
43 | local_files_only=True,
44 | use_auth_token=""
45 | ).to("cuda:0")
46 | self.extrapolate_RGBDs = self.gen_extrapolate_RGBDs()
47 | torch.save(self.extrapolate_RGBDs, self.save_base_dir + "/" + "extrapolate_RGBDs.pkl")
48 |
49 | def gen_extrapolate_RGBDs(self):
50 |
51 | self.prompt = ["continuous sky, without animal, without text, without copy", #for up
52 | "continuous scene, without text, without copy", #for mid
53 | "continuous sea, without text" #for down
54 | ]
55 |
56 | ref_img = self.train_dataset.ref_img.cpu()
57 | depth = tensor_to_depth(ref_img.cuda()).cpu()
58 |
59 | ref_depth = depth
60 |
61 | rgbd = (ref_img.cuda(), ref_depth.cuda())
62 |
63 | _,_,h,w = ref_img.shape
64 |
65 | canvas = torch.zeros(1,3,h*self.extrapolate_times,w*self.extrapolate_times)
66 | mask = torch.zeros(1,1,h*self.extrapolate_times,w*self.extrapolate_times)
67 |
68 | if self.extrapolate_times == 3: # extend w = 3 * w
69 | top_left_points = [
70 | (h//2,w), (0,w), #top
71 | (h + h//2,w), (h + h,w), #down
72 | (h,w//2),(h,0), #left
73 | (h, w + w//2), (h, w + w), #right
74 |
75 | (h//2,w//2), (0,w//2), (h//2,0) ,(0,0), #top left
76 | (h + h//2,w//2), (h + h//2,0), (h + h,w//2), (h + h,0), #down left
77 | (h//2,w + w//2), (0,w + w//2), (h//2,w + w), (0 ,w + w), #top right
78 | (h + h//2, w + w//2), (h + h//2, w + w), (h + h, w + w//2), (h + h, w + w), #down right
79 | ]
80 | up = [0,1,8,9,10,11,16,17,18,19]
81 | mid = [4,5,6,7]
82 | down = [2,3,12,13,14,15,20,21,22,23]
83 | elif self.extrapolate_times == 2: # extend w = 2 * w
84 | top_left_points = [
85 | (0,w//2), #top
86 | (h,w//2), #down
87 | (h//2,0), #left
88 | (h//2,w), #right
89 |
90 | (0,0), #top left
91 | (h,0), #down left
92 | (0,w), #top right
93 | (h,w), #down right
94 | ]
95 | up = [0,4,6]
96 | mid = [2,3]
97 | down = [1,3,5,7]
98 | elif self.extrapolate_times == 1:
99 | top_left_points = []
100 | # return rgbd
101 |
102 | canvas[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w] = ref_img
103 | mask[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w] = torch.ones(1,1,h,w)
104 |
105 | for i, point in enumerate(top_left_points):
106 | canvas_part = canvas[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w]
107 | mask_part = mask[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w]
108 | if i in up:
109 | prompt = self.prompt[0]
110 | elif i in mid:
111 | prompt = self.prompt[1]
112 | else:
113 | prompt = self.prompt[2]
114 | canvas[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] = self.run_sd(canvas_part, mask_part, prompt, h, w)
115 | mask[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] = torch.ones(1,1,h,w)
116 |
117 | depth = tensor_to_depth(canvas.cuda())
118 |
119 | align_depth = True
120 | if align_depth:
121 | extrapolate_depth = depth
122 | extrapolate_center_depth = extrapolate_depth[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w]
123 | # align depth with ref_depth
124 | depth[:,:,:,:] = (depth - extrapolate_center_depth.min())/(extrapolate_center_depth.max() - extrapolate_center_depth.min()) * (ref_depth.max() - ref_depth.min()) + ref_depth.min()
125 |
126 | extrapolate_RGBDs = (canvas.cuda(), depth.cuda())
127 | torchvision.utils.save_image(canvas[0], self.save_base_dir + "/" + "canvas.png")
128 | torchvision.utils.save_image(depth[0], self.save_base_dir + "/" + "canvas_depth.png")
129 | return extrapolate_RGBDs
130 |
131 | def run_sd(self, canvas, mask, prompt, w, h):
132 | # Run sd
133 | # prompt = "room"
134 | transform = T.ToPILImage()
135 | warp_rgb_PIL = transform(canvas[0,...]).convert("RGB").resize((512, 512))
136 | warp_mask_PIL = transform(255 * (1 - mask[0,...].to(torch.int32))).convert("RGB").resize((512, 512))
137 | inpainted_warp_image = self.sd(prompt=prompt, image=warp_rgb_PIL, mask_image=warp_mask_PIL).images[0]
138 | inpainted_warp_image = inpainted_warp_image.resize((h,w))
139 | inpainted_warp_image = T.ToTensor()(inpainted_warp_image).unsqueeze(0)
140 | inpainted_warp_image = canvas * mask + inpainted_warp_image * (1 - mask)
141 | return inpainted_warp_image
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
146 | parser.add_argument('--img_path', type=str, default="images/0810.png")
147 | parser.add_argument('--disp_path', type=str, default="images/depth/0810.png")
148 | parser.add_argument('--width', type=int, default=384)
149 | parser.add_argument('--height', type=int, default=256)
150 | parser.add_argument('--ckpt_path', type=str, default="ExpX")
151 | parser.add_argument('--debugging', default=False, action="store_true")
152 | parser.add_argument('--extrapolate_times', type=int, default=1)
153 |
154 | opt, _ = parser.parse_known_args()
155 |
156 | seed = 50
157 | seed_everything(seed)
158 |
159 | sd_outpainter = SDOutpainter(opt)
--------------------------------------------------------------------------------
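
A minimal sketch of the depth-alignment step inside gen_extrapolate_RGBDs, which rescales the outpainted canvas depth so that its central crop spans the same range as the reference depth; the toy tensor sizes are illustrative:

```python
import torch

ref_depth = torch.rand(1, 1, 2, 2)        # depth of the original (center) view
canvas_depth = torch.rand(1, 1, 4, 4)     # DPT depth of the outpainted canvas
center = canvas_depth[:, :, 1:3, 1:3]     # crop at center_top_left with the original view's size
aligned = (canvas_depth - center.min()) / (center.max() - center.min()) \
          * (ref_depth.max() - ref_depth.min()) + ref_depth.min()
# The central crop of `aligned` now spans [ref_depth.min(), ref_depth.max()].
```
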
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.19.0
2 | diffusers==0.16.1
3 | easydict==1.10
4 | efficientnet-pytorch==0.7.1
5 | einops==0.6.1
6 | huggingface-hub==0.14.1
7 | imageio==2.28.1
8 | imageio-ffmpeg==0.4.8
9 | lightning-utilities==0.8.0
10 | moviepy==1.0.3
11 | numpy==1.21.0
12 | open-clip-torch==2.20.0
13 | open3d==0.17.0
14 | opencv-python==4.7.0.72
15 | Pillow==10.0.0
16 | pytorch-lightning==1.4.2
17 | pytorch3d==0.7.4
18 | timm==0.9.2
19 | torch==1.11.0
20 | torchmetrics==0.5.0
21 | torchvision==0.12.0
22 | tornado==6.2
23 | transformers==4.29.2
24 | trimesh==3.21.7
25 |
--------------------------------------------------------------------------------
/scripts/train_all.sh:
--------------------------------------------------------------------------------
1 | #conda activate SinMPI
2 | cuda=0
3 | width=512
4 | height=512
5 | data_name='Syndney'
6 | extrapolate_times=2
7 |
8 | ckpt_path='Exp-Syndney'
9 | img_path=$data_name'.jpg'
10 |
11 | CUDA_VISIBLE_DEVICES=$cuda python outpaint_rgbd.py \
12 | --width $width \
13 | --height $height \
14 | --ckpt_path $ckpt_path \
15 | --img_path $img_path \
16 | --extrapolate_times $extrapolate_times
17 |
18 | CUDA_VISIBLE_DEVICES=$cuda python train_inpainting.py \
19 | --width $width \
20 | --height $height \
21 | --ckpt_path $ckpt_path \
22 | --img_path $img_path \
23 | --num_epochs 10 \
24 | --extrapolate_times $extrapolate_times \
25 | --batch_size 1 #--load_warp_pairs --debugging
26 |
27 | CUDA_VISIBLE_DEVICES=$cuda python train_mpi.py \
28 | --width $width \
29 | --height $height \
30 | --ckpt_path $ckpt_path \
31 | --img_path $img_path \
32 | --num_epochs 10 \
33 | --extrapolate_times $extrapolate_times \
34 | --batch_size 1 #--debugging #--resume
--------------------------------------------------------------------------------
/test_images/Syndney.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/test_images/Syndney.jpg
--------------------------------------------------------------------------------
/train_inpainting.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import math
5 | import random
6 |
7 | from pytorch_lightning import LightningModule, Trainer
8 | from torch.utils.data import DataLoader
9 |
10 | from dataloaders.single_img_data import SinImgDataset
11 | from model.Inpainter import InpaintingModule
12 | from torchvision.utils import save_image
13 | from utils_for_train import VGGPerceptualLoss
14 | from model.VitExtractor import VitExtractor
15 | from utils_for_train import tensor_to_depth
16 | from warpback.utils import (
17 | RGBDRenderer,
18 | image_to_tensor,
19 | transformation_from_parameters,
20 | )
21 | from pytorch_lightning import seed_everything
22 |
23 | class TrainInpaintingModule(LightningModule):
24 | def __init__(self, opt):
25 | super(TrainInpaintingModule, self).__init__()
26 |
27 | self.opt = opt
28 | self.loss = []
29 |
30 | W, H = self.opt.width, self.opt.height
31 | self.save_base_dir = f'ckpts/{opt.ckpt_path}'
32 |
33 | if self.opt.resume:
34 | self.inpaint_module = InpaintingModule()
35 | self.inpaint_module.load_state_dict(torch.load(f'ckpts/{self.opt.ckpt_path}/inpaint_latest.pt'), strict=True)
36 | else:
37 | self.inpaint_module = InpaintingModule()
38 |
39 | self.models = [self.inpaint_module.cuda()]
40 |
41 | # for training
42 | self.extrapolate_times = self.opt.extrapolate_times
43 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=1)
44 |
45 | if self.extrapolate_times == 3: # extend w = 3 * w
46 | self.center_top_left = (self.opt.height, self.opt.width)
47 | elif self.extrapolate_times == 2: # extend w = 2 * w
48 | self.center_top_left = (self.opt.height//2, self.opt.width//2)
49 | elif self.extrapolate_times == 1:
50 | self.center_top_left = (0, 0)
51 |
52 | self.K = torch.tensor([
53 | [0.58, 0, 0.5],
54 | [0, 0.58, 0.5],
55 | [0, 0, 1]
56 | ])
57 |
58 | with torch.no_grad():
59 | if self.extrapolate_times == 1:
60 | ref_img = image_to_tensor(self.save_base_dir + "/" + "canvas.png", unsqueeze=False) # [3,h,w]
61 | ref_img = ref_img.unsqueeze(0).cuda()
62 | if ref_img.shape[1] == 4:
63 | ref_img = ref_img[:,:3,:,:]
64 |
65 | ref_depth = tensor_to_depth(ref_img)
66 | save_image(ref_depth[0,0,...], self.save_base_dir + "/" + "canvas_depth.png")
67 | else:
68 | ref_img, ref_depth = torch.load(self.save_base_dir + "/" + "extrapolate_RGBDs.pkl")
69 |
70 | ref_depth = (ref_depth - ref_depth.min())/(ref_depth.max() - ref_depth.min())
71 |
72 | self.extrapolate_RGBDs = (ref_img.cpu(), ref_depth.cpu())
73 |
74 | if self.opt.load_warp_pairs:
75 | self.inpaint_pairs = torch.load(self.save_base_dir + "/" + "inpaint_pairs.pkl")
76 | else:
77 | self.renderer = RGBDRenderer('cuda:0')
78 | self.inpaint_pairs = self.get_pairs()
79 | torch.save(self.inpaint_pairs,self.save_base_dir + "/" + "inpaint_pairs.pkl")
80 |
81 | self.perceptual_loss = VGGPerceptualLoss()
82 | self.VitExtractor = VitExtractor(
83 | model_name='dino_vits16', device='cuda:0')
84 |
85 | self.renderer_pair_saved = False
86 |
87 |
88 | def configure_optimizers(self):
89 | from torch.optim import SGD, Adam
90 |
91 | parameters = []
92 | for model in self.models:
93 | parameters += list(model.parameters())
94 | self.optimizer = Adam(parameters, lr=5e-4, eps=1e-8, weight_decay=0)
95 |
96 | return [self.optimizer], []
97 |
98 |
99 | def train_dataloader(self):
100 | return DataLoader(self.train_dataset,
101 | shuffle=True,
102 | num_workers=8,
103 | batch_size=self.opt.batch_size,
104 | pin_memory=True)
105 |
106 |
107 | def get_rand_ext(self, bs=1):
108 | def rand_tensor(r, l):
109 | if r < 0:
110 | return torch.zeros((l, 1, 1))
111 | rand = torch.rand((l, 1, 1))
112 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1
113 | return sign * (r / 2 + r / 2 * rand)
114 |
115 | trans_range={"x":0.2, "y":-0.2, "z":-0.2, "a":-0.2, "b":-0.2, "c":-0.2}
116 | x, y, z = trans_range['x'], trans_range['y'], trans_range['z']
117 | a, b, c = trans_range['a'], trans_range['b'], trans_range['c']
118 | cix = rand_tensor(x, bs)
119 | ciy = rand_tensor(y, bs)
120 | ciz = rand_tensor(z, bs)
121 | aix = rand_tensor(math.pi / a, bs)
122 | aiy = rand_tensor(math.pi / b, bs)
123 | aiz = rand_tensor(math.pi / c, bs)
124 |
125 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3]
126 | translation = torch.cat([cix, ciy, ciz], dim=-1)
127 |
128 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4]
129 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4]
130 | return cam_ext, cam_ext_inv
131 |
132 | def get_pairs(self):
133 | all_poses = self.train_dataset.all_poses
134 |
135 |         aug_pose_factor = 0  # set > 0 to augment poses for better results
136 | cnt = len(all_poses)
137 | if aug_pose_factor > 0:
138 | for i in range(cnt):
139 | cur_pose = torch.FloatTensor(all_poses[i])
140 | for _ in range(aug_pose_factor):
141 | cam_ext, cam_ext_inv = self.get_rand_ext() # [b,4,4]
142 | cur_aug_pose = torch.matmul(cam_ext, cur_pose)
143 | all_poses += [cur_aug_pose]
144 |
145 | ref_depth = self.extrapolate_RGBDs[1]
146 |
147 | ref_img = self.extrapolate_RGBDs[0]
148 | W, H = self.opt.width * self.extrapolate_times, self.opt.height * self.extrapolate_times
149 |
150 |         inpaint_pairs = []  # (ref_img, ref_depth, cur_pose, warp_image, warp_disp, warp_mask, warp_back_image, warp_back_disp, warp_back_mask)
151 |         val_pairs = []  # unused here; kept as a placeholder
152 |
153 |         print("all_poses len:", len(all_poses))
154 |         for i in range(len(all_poses)):
155 |             cur_pose = all_poses[i]
156 |             # camera-to-world pose of the current view
157 |             c2w = torch.FloatTensor(cur_pose)
158 |
159 | cam_int = self.K.repeat(1, 1, 1) # [b,3,3]
160 |
161 |             # camera extrinsics for warping; the renderer expects the top 3x4 block [b,3,4]
162 |             cam_ext = c2w
163 |             cam_ext_inv = torch.inverse(cam_ext)
164 |             cam_ext = cam_ext.repeat(1, 1, 1)[:, :-1, :]
165 |             cam_ext_inv = cam_ext_inv.repeat(1, 1, 1)[:, :-1, :]
166 |
167 | rgbd = torch.cat([ref_img, ref_depth], dim=1).cuda()
168 | cam_int = cam_int.cuda()
169 | cam_ext = cam_ext.cuda()
170 | cam_ext_inv = cam_ext_inv.cuda()
171 |
172 | # warp to a random novel view
173 | mesh = self.renderer.construct_mesh(rgbd, cam_int)
174 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext)
175 |
176 | # warp back to the original view
177 | warp_rgbd = torch.cat([warp_image, warp_disp], dim=1) # [b,4,h,w]
178 | warp_mesh = self.renderer.construct_mesh(warp_rgbd, cam_int)
179 | warp_back_image, warp_back_disp, warp_back_mask = self.renderer.render_mesh(warp_mesh, cam_int, cam_ext_inv)
180 |
181 |             ref_depth_2 = ref_depth
182 |             # all depths are already normalized to [0, 1]
183 | inpaint_pairs.append((ref_img, ref_depth_2, cur_pose,
184 | warp_image, warp_disp, warp_mask,
185 | warp_back_image, warp_back_disp, warp_back_mask))
186 | print("collecting inpaint_pairs:", len(inpaint_pairs))
187 |
188 | return inpaint_pairs
189 |
190 |
191 | def forward(self, renderer_pair):
192 | (ref_img, ref_depth, cur_pose,
193 | warp_rgb, warp_disp, warp_mask,
194 | warp_back_image, warp_back_disp, warp_back_mask) = renderer_pair
195 |
196 | inpainted_warp_image, inpainted_warp_disp = self.inpaint_module(warp_rgb.cuda(), warp_disp.cuda(), warp_mask.cuda())
197 | inpainted_warp_back_image, inpainted_warp_back_disp = self.inpaint_module(warp_back_image.cuda(), warp_back_disp.cuda(), warp_back_mask.cuda())
198 |
199 | return {
200 | "ref_img": ref_img,
201 | "ref_depth": ref_depth,
202 | "warp_image": warp_rgb,
203 | "warp_disp": warp_disp,
204 | "inpainted_warp_image":inpainted_warp_image,
205 | "inpainted_warp_disp":inpainted_warp_disp,
206 | "warp_back_image": warp_back_image,
207 | "warp_back_disp": warp_back_disp,
208 | "inpainted_warp_back_image":inpainted_warp_back_image,
209 | "inpainted_warp_back_disp":inpainted_warp_back_disp,
210 | }
211 |
212 |
213 | def training_step(self, batch, batch_idx, optimizer_idx=0):
214 | renderer_pair = random.choice(self.inpaint_pairs)
215 | batch = self(renderer_pair)
216 |
217 | ref_img = batch["ref_img"].cuda()
218 | ref_depth = batch["ref_depth"].cuda()
219 |
220 | warp_image = batch["warp_image"]
221 | warp_disp = batch["warp_disp"]
222 |
223 | inpainted_warp_image = batch["inpainted_warp_image"]
224 | inpainted_warp_disp = batch["inpainted_warp_disp"]
225 |
226 | warp_back_image = batch["warp_back_image"]
227 | warp_back_disp = batch["warp_back_disp"]
228 |
229 | inpainted_warp_back_image = batch["inpainted_warp_back_image"]
230 | inpainted_warp_back_disp = batch["inpainted_warp_back_disp"]
231 |
232 | # Losses
233 | loss_total = 0
234 | loss_L1 = 0
235 | lambda_loss_L1 = 10
236 |
237 | loss_perc = 0
238 | lambda_loss_perc = 5
239 |
240 | loss_L1 += F.l1_loss(ref_img, inpainted_warp_back_image)
241 | loss_total += loss_L1 * lambda_loss_L1
242 |
243 | loss_perc += self.perceptual_loss(inpainted_warp_image, ref_img) + self.perceptual_loss(inpainted_warp_back_image, ref_img)
244 | loss_total += loss_perc * lambda_loss_perc
245 |
246 |         loss_inpainted_vit = 0
247 |         lambda_loss_vit = 0  # the DINO-ViT feature loss is currently disabled; set to e.g. 1e-1 to enable it
248 | ref_vit_feature = self.get_vit_feature(ref_img)
249 | inpainted_vit_feature = self.get_vit_feature(inpainted_warp_image)
250 | inpainted_warp_back_vit_feature = self.get_vit_feature(inpainted_warp_back_image)
251 | loss_inpainted_vit += F.mse_loss(inpainted_vit_feature, ref_vit_feature) + F.mse_loss(inpainted_warp_back_vit_feature, ref_vit_feature)
252 | loss_total += loss_inpainted_vit * lambda_loss_vit
253 |
254 | loss_depth = 0
255 | lambda_loss_depth = 1
256 | loss_depth += F.l1_loss(ref_depth, inpainted_warp_back_disp)
257 | loss_total += lambda_loss_depth * loss_depth
258 |
259 | if self.opt.debugging:
260 | self.training_epoch_end(None)
261 |             assert 0  # deliberately stop after a single step in debug mode
262 |
263 | return {'loss': loss_total}
264 |
265 |
266 | def training_epoch_end(self, outputs):
267 | with torch.no_grad():
268 |
269 | # pred_frames = []
270 |
271 | self.renderer_pairs = []
272 | for i, inpaint_pair in enumerate(self.inpaint_pairs):
273 | (ref_img, ref_depth, cur_pose,
274 | warp_image, warp_disp, warp_mask,
275 | warp_back_image, warp_back_disp, warp_back_mask) = inpaint_pair
276 | inpainted_warp_image, inpainted_warp_disp = self.inpaint_module(warp_image.cuda(), warp_disp.cuda(), warp_mask.cuda())
277 |
278 | self.renderer_pairs += [(cur_pose,
279 | ref_img,
280 | inpainted_warp_image)]
281 |
282 | torch.save(self.inpaint_module.state_dict(), self.save_base_dir + "/" +"inpaint_latest.pt")
283 |             if not self.renderer_pair_saved:  # note: this flag is never set afterwards, so the pairs are re-saved after every epoch
284 | torch.save(self.renderer_pairs, self.save_base_dir + "/" + "renderer_pairs.pkl")
285 | if self.opt.debugging:
286 | assert 0, "no bug"
287 | return
288 |
289 | def get_vit_feature(self, x):
290 | mean = torch.tensor([0.485, 0.456, 0.406],
291 | device=x.device).reshape(1, 3, 1, 1)
292 | std = torch.tensor([0.229, 0.224, 0.225],
293 | device=x.device).reshape(1, 3, 1, 1)
294 | x = F.interpolate(x, size=(224, 224))
295 | x = (x - mean) / std
296 | return self.VitExtractor.get_feature_from_input(x)[-1][0, 0, :]
297 |
298 |
299 | if __name__ == '__main__':
300 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
301 | parser.add_argument('--img_path', type=str, default="test_images/Syndney.jpg")
302 | parser.add_argument('--width', type=int, default=512)
303 | parser.add_argument('--height', type=int, default=512)
304 | parser.add_argument('--ckpt_path', type=str, default="Exp-X")
305 | parser.add_argument('--num_epochs', type=int, default=1)
306 | parser.add_argument('--resume_path', type=str, default=None)
307 | parser.add_argument('--resume', default=False, action="store_true")
308 | parser.add_argument('--batch_size', type=int, default=1)
309 | parser.add_argument('--debugging', default=False, action="store_true")
310 | parser.add_argument('--extrapolate_times', type=int, default=1)
311 | parser.add_argument('--load_warp_pairs', default=False, action="store_true")
312 |
313 | opt, _ = parser.parse_known_args()
314 |
315 | seed = 50
316 | seed_everything(seed)
317 |
318 | system = TrainInpaintingModule(opt)
319 |
320 | trainer = Trainer(max_epochs=opt.num_epochs,
321 | progress_bar_refresh_rate=1,
322 | gpus=1,
323 | num_sanity_val_steps=1)
324 |
325 | trainer.fit(system)
--------------------------------------------------------------------------------
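
Between the two training scripts, the pseudo-multi-view supervision comes from a forward-warp / warp-back cycle built on `RGBDRenderer` and `transformation_from_parameters` (defined in `warpback/utils.py` later in this listing). The following is a minimal, self-contained sketch of that cycle, assuming an RGBD tensor `rgbd` of shape [1,4,h,w] with disparity normalized to [0,1]; the pose magnitudes are illustrative only, not the training values.

```
import math
import torch
from warpback.utils import RGBDRenderer, transformation_from_parameters

def warp_and_back(rgbd, device="cuda:0"):
    # rgbd: [1,4,h,w] = RGB plus normalized disparity in [0,1]
    renderer = RGBDRenderer(device)
    cam_int = torch.tensor([[0.58, 0.0, 0.5],
                            [0.0, 0.58, 0.5],
                            [0.0, 0.0, 1.0]]).unsqueeze(0).to(device)  # normalized intrinsics

    # small illustrative camera motion: rotate slightly about y, translate slightly along x
    axisangle = torch.zeros(1, 1, 3)
    axisangle[..., 1] = math.pi / 60
    translation = torch.zeros(1, 1, 3)
    translation[..., 0] = 0.05
    cam_ext = transformation_from_parameters(axisangle, translation)  # [1,4,4]
    cam_ext_inv = torch.inverse(cam_ext)
    cam_ext = cam_ext[:, :-1, :].to(device)          # [1,3,4], as render_mesh expects
    cam_ext_inv = cam_ext_inv[:, :-1, :].to(device)

    rgbd = rgbd.to(device)
    # forward-warp the RGBD image to the novel view
    mesh = renderer.construct_mesh(rgbd, cam_int)
    warp_image, warp_disp, warp_mask = renderer.render_mesh(mesh, cam_int, cam_ext)

    # warp back to the original view; regions where the returned mask is 0 are the holes
    # that the depth-aware inpainter learns to fill
    warp_rgbd = torch.cat([warp_image, warp_disp], dim=1)
    warp_mesh = renderer.construct_mesh(warp_rgbd, cam_int)
    return renderer.render_mesh(warp_mesh, cam_int, cam_ext_inv)
```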
/train_mpi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import random
5 | from moviepy.editor import ImageSequenceClip
6 | from utils_for_train import VGGPerceptualLoss
7 | from dataloaders.single_img_data import SinImgDataset
8 | import numpy as np
9 | from pytorch_lightning import LightningModule, Trainer
10 | from torch.utils.data import DataLoader
11 | from model.MPF import MultiPlaneField
12 | from model.VitExtractor import VitExtractor
13 | from mpi.homography_sampler import HomographySample
14 | from mpi import mpi_rendering
15 | from model.TrainableFilter import TrainableFilter
16 | from pytorch_lightning import seed_everything
17 |
18 | class System(LightningModule):
19 | def __init__(self, opt):
20 | super(System, self).__init__()
21 | self.save_base_dir = f'ckpts/{opt.ckpt_path}'
22 | self.opt = opt
23 | self.loss = []
24 | self.models = []
25 | W, H = self.opt.width, self.opt.height
26 |
27 | self.extrapolate_times = self.opt.extrapolate_times
28 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=10)
29 |
30 | W, H = self.opt.width * self.extrapolate_times, self.opt.height * self.extrapolate_times
31 |
32 | self.K = torch.tensor([
33 | [0.58, 0, 0.5],
34 | [0, 0.58, 0.5],
35 | [0, 0, 1]
36 | ])
37 |         self.K[0, :] *= W  # scale the normalized intrinsics to pixel units of the expanded canvas
38 |         self.K[1, :] *= H
39 | self.K = self.K.unsqueeze(0)
40 |
41 | if self.extrapolate_times == 3: # extend w = 3 * w
42 | self.center_top_left = (self.opt.height, self.opt.width)
43 | elif self.extrapolate_times == 2: # extend w = 2 * w
44 | self.center_top_left = (self.opt.height//2, self.opt.width//2)
45 | elif self.extrapolate_times == 1:
46 | self.center_top_left = (0, 0)
47 |
48 | with torch.no_grad():
49 | self.extrapolate_RGBDs = torch.load(self.save_base_dir + "/" + "extrapolate_RGBDs.pkl")
50 | img, depth = self.extrapolate_RGBDs
51 | depth = (depth - depth.min())/(depth.max() - depth.min())
52 | self.extrapolate_RGBDs = (img.cuda(), depth.cuda())
53 |
54 | # create MPI
55 | self.num_planes = 64
56 | self.MPF = MultiPlaneField(num_planes=self.num_planes,
57 | image_size=(H, W),
58 | assign_origin_planes=self.extrapolate_RGBDs,
59 | depth_range=[self.extrapolate_RGBDs[1].min()+1e-6, self.extrapolate_RGBDs[1].max()])
60 | self.trainable_filter = TrainableFilter(ksize=3).cuda()
61 |
62 | if self.opt.resume:
63 | self.MPF.load_state_dict(torch.load(f'ckpts/{self.opt.ckpt_path}/MPF_latest.pt'), strict=True)
64 |
65 | self.renderer_pairs = torch.load(self.save_base_dir + "/" + "renderer_pairs.pkl")
66 |
67 | self.models += [self.MPF]
68 | self.models += [self.trainable_filter]
69 |
70 | self.perceptual_loss = VGGPerceptualLoss()
71 | self.VitExtractor = VitExtractor(
72 | model_name='dino_vits16', device='cuda:0')
73 |
74 | def train_dataloader(self):
75 | return DataLoader(self.train_dataset,
76 | shuffle=True,
77 | num_workers=4,
78 | batch_size=self.opt.batch_size,
79 | pin_memory=True)
80 |
81 | def configure_optimizers(self):
82 |         from torch.optim import Adam
83 | from torch.optim.lr_scheduler import MultiStepLR
84 |
85 | parameters = []
86 | for model in self.models:
87 | parameters += list(model.parameters())
88 |
89 | self.optimizer = Adam(parameters, lr=5e-4, eps=1e-8,
90 | weight_decay=0)
91 |
92 | scheduler = MultiStepLR(self.optimizer, milestones=[20],
93 | gamma=0.1)
94 |
95 | return [self.optimizer], [scheduler]
96 |
97 |
98 | def training_step(self, batch, batch_idx, optimizer_idx=0):
99 |
100 | renderer_pair = random.choice(self.renderer_pairs)
101 | (cam_ext, ref_img, inpainted_warp_image) = renderer_pair
102 |         cam_ext, ref_img, inpainted_warp_image = cam_ext.cuda(), ref_img.cuda(), inpainted_warp_image.cuda()
103 | # Losses
104 | loss = 0
105 | ref_vit_feature = self.get_vit_feature(ref_img)
106 |
107 | # run MPI forward
108 | frames_tensor = self(cam_ext)
109 |
110 | # mpi losses
111 | loss_L1 = F.l1_loss(frames_tensor, inpainted_warp_image)
112 |
113 | lambda_loss_L1 = 10
114 | loss += loss_L1 * lambda_loss_L1
115 |
116 | lambda_loss_vit = 1e-1
117 | frames_vit_feature = self.get_vit_feature(frames_tensor)
118 | loss_vit = F.mse_loss(frames_vit_feature, ref_vit_feature)
119 | loss += loss_vit * lambda_loss_vit
120 |
121 |
122 | if self.opt.debugging:
123 | # loss *= 0
124 | self.training_epoch_end(None)
125 | assert 0
126 |
127 | return {'loss': loss}
128 |
129 |
130 | def forward(self, cam_ext, only_render_in_fov=False):
131 | # mpi_planes[b,s,4,h,w], mpi_disp[b,s]
132 | mpi = self.MPF()
133 | mpi_all_rgb_src = mpi[:, :, 0:3, :, :]
134 | mpi_all_sigma_src = mpi[:, :, 3:4, :, :]
135 | disparity_all_src = self.MPF.planes_disp
136 |
137 | k_tgt = k_src = self.K
138 | k_src_inv = torch.inverse(k_src).cuda()
139 | k_tgt = k_tgt.cuda()
140 | k_src = k_src.cuda()
141 | h, w = mpi.shape[-2:]
142 | homography_sampler = HomographySample(h, w, "cuda:0")
143 | G_tgt_src = cam_ext.cuda()
144 |
145 | xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
146 | homography_sampler.meshgrid,
147 | disparity_all_src,
148 | k_src_inv
149 | )
150 |
151 | xyz_tgt_BS3HW = mpi_rendering.get_tgt_xyz_from_plane_disparity(
152 | xyz_src_BS3HW,
153 | G_tgt_src
154 | )
155 |
156 | tgt_imgs_syn = mpi_rendering.render_tgt_rgb_depth(
157 | homography_sampler,
158 | mpi_all_rgb_src,
159 | mpi_all_sigma_src,
160 | disparity_all_src,
161 | xyz_tgt_BS3HW,
162 | G_tgt_src,
163 | k_src_inv,
164 | k_tgt,
165 | only_render_in_fov=only_render_in_fov,
166 | center_top_left=self.center_top_left
167 | )
168 |
169 | tgt_imgs_syn = self.trainable_filter(tgt_imgs_syn)
170 |
171 | return tgt_imgs_syn
172 |
173 | def training_epoch_end(self, outputs):
174 | with torch.no_grad():
175 |
176 | pred_frames = []
177 | for i, renderer_pair in enumerate(self.renderer_pairs):
178 | (cam_ext, ref_img, inpainted_warp_image) = renderer_pair
179 | # run MPI forward
180 | frames_tensor = self(cam_ext, only_render_in_fov=True)
181 | cam_ext, ref_img, inpainted_warp_image = cam_ext.cuda(), ref_img.cuda(), inpainted_warp_image.cuda()
182 |
183 |                 pred_frame_np = frames_tensor.squeeze(0).permute(1, 2, 0).contiguous().detach().cpu().numpy() # [h,w,3]
184 | pred_frame_np = np.clip(np.round(pred_frame_np * 255), a_min=0, a_max=255).astype(np.uint8)
185 | pred_frames += [pred_frame_np]
186 |
187 | rgb_clip = ImageSequenceClip(pred_frames, fps=10)
188 | save_path = f'ckpts/{self.opt.ckpt_path}/MPI_rendered_views.mp4'
189 | rgb_clip.write_videofile(save_path, verbose=False, codec='mpeg4', logger=None, bitrate='2000k')
190 |
191 | save_base_dir = f'ckpts/{self.opt.ckpt_path}'
192 |
193 | torch.save(self.MPF.state_dict(), save_base_dir + "/" +"MPF_latest.pt")
194 |
195 | if self.opt.debugging:
196 | assert 0, "no bug"
197 | return
198 |
199 | def get_vit_feature(self, x):
200 | mean = torch.tensor([0.485, 0.456, 0.406],
201 | device=x.device).reshape(1, 3, 1, 1)
202 | std = torch.tensor([0.229, 0.224, 0.225],
203 | device=x.device).reshape(1, 3, 1, 1)
204 | x = F.interpolate(x, size=(224, 224))
205 | x = (x - mean) / std
206 | return self.VitExtractor.get_feature_from_input(x)[-1][0, 0, :]
207 |
208 |
209 | if __name__ == '__main__':
210 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
211 | parser.add_argument('--img_path', type=str, default="test_images/Syndney.jpg")
212 | parser.add_argument('--width', type=int, default=512)
213 | parser.add_argument('--height', type=int, default=512)
214 | parser.add_argument('--ckpt_path', type=str, default="Exp-X")
215 | parser.add_argument('--num_epochs', type=int, default=1)
216 | parser.add_argument('--resume_path', type=str, default=None)
217 | parser.add_argument('--batch_size', type=int, default=1)
218 | parser.add_argument('--debugging', default=False, action="store_true")
219 | parser.add_argument('--resume', default=False, action="store_true")
220 | parser.add_argument('--extrapolate_times', type=int, default=1)
221 | opt, _ = parser.parse_known_args()
222 |
223 | seed = 50
224 | seed_everything(seed)
225 |
226 | system = System(opt)
227 |
228 | trainer = Trainer(max_epochs=opt.num_epochs,
229 | resume_from_checkpoint=opt.resume_path,
230 | progress_bar_refresh_rate=1,
231 | gpus=1,
232 | num_sanity_val_steps=1)
233 |
234 | trainer.fit(system)
--------------------------------------------------------------------------------
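
Note how the intrinsics are handled here compared to `train_inpainting.py`: the same normalized K is used, but it is scaled to pixel units (and batched) before homography-based MPI rendering. A minimal sketch of that conversion, using the 512x512, 3x-extrapolation demo values as an example:

```
import torch

def pixel_intrinsics(width, height):
    # the repo's fixed normalized pinhole intrinsics: fx = fy = 0.58, principal point at the center
    K = torch.tensor([[0.58, 0.0, 0.5],
                      [0.0, 0.58, 0.5],
                      [0.0, 0.0, 1.0]])
    K[0, :] *= width   # fx and cx in pixels
    K[1, :] *= height  # fy and cy in pixels
    return K.unsqueeze(0)  # [1,3,3], the shape passed to HomographySample / mpi_rendering

print(pixel_intrinsics(512 * 3, 512 * 3))  # a 512x512 input expanded 3x per side
```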
/utils_for_train.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | import cv2
3 | import torch
4 | import torchvision
5 | from torchvision import transforms
6 |
7 | class VGGPerceptualLoss(torch.nn.Module):
8 | def __init__(self, resize=True):
9 | super(VGGPerceptualLoss, self).__init__()
10 | blocks = []
11 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
12 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
13 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
14 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
15 |         for bl in blocks:
16 |             for p in bl.parameters():
17 |                 p.requires_grad = False
18 | self.blocks = torch.nn.ModuleList(blocks)
19 | self.transform = torch.nn.functional.interpolate
20 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
21 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
22 | self.mean.requires_grad = False
23 | self.std.requires_grad = False
24 | self.resize = resize
25 |
26 | def forward(self, syn_imgs, gt_imgs):
27 | syn_imgs = (syn_imgs - self.mean) / self.std
28 | gt_imgs = (gt_imgs - self.mean) / self.std
29 | if self.resize:
30 | syn_imgs = self.transform(syn_imgs, mode="bilinear", size=(224, 224),
31 | align_corners=False)
32 | gt_imgs = self.transform(gt_imgs, mode="bilinear", size=(224, 224),
33 | align_corners=False)
34 |
35 | loss = 0.0
36 | x = syn_imgs
37 | y = gt_imgs
38 |         for block in self.blocks:
39 |             x = block(x)  # keep the graph so the loss can backpropagate to syn_imgs
40 |             with torch.no_grad():
41 |                 y = block(y)  # the reference branch needs no gradients
42 |             loss += torch.nn.functional.l1_loss(x, y)
43 | return loss
44 |
45 |
46 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
47 | value = (image_pred-image_gt)**2
48 | if valid_mask is not None:
49 | value = value[valid_mask]
50 | if reduction == 'mean':
51 | return torch.mean(value)
52 | return value
53 |
54 | def image_to_tensor(img_path, unsqueeze=True):
55 | im = Image.open(img_path).convert('RGB')
56 |     if img_path.lower().endswith(('.jpg', '.jpeg')):
57 |         im = ImageOps.exif_transpose(im)  # respect EXIF orientation for JPEG photos
58 | rgb = transforms.ToTensor()(im)
59 | if unsqueeze:
60 | rgb = rgb.unsqueeze(0)
61 | return rgb
62 |
63 | def disparity_to_tensor(disp_path, unsqueeze=True):
64 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1)
65 | disp = torch.from_numpy(disp)[None, ...]
66 | if unsqueeze:
67 | disp = disp.unsqueeze(0)
68 | return disp.float()
69 |
70 | def tensor_to_depth(tensor): # BCHW
71 |     model_type = "DPT_Large"
72 |     # NOTE: MiDaS is loaded from a local torch.hub cache; change these paths for your machine, or load remotely via torch.hub.load("intel-isl/MiDaS", ...)
73 |     midas_model = torch.hub.load("/home/pug/.cache/torch/hub/MiDaS", model_type, source='local').cuda()
74 |     midas_transforms = torch.hub.load("/home/pug/.cache/torch/hub/MiDaS", "transforms", source='local')
75 | if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
76 | transform = midas_transforms.dpt_transform
77 | else:
78 | transform = midas_transforms.small_transform
79 |
80 |     input_batch = tensor  # the input is already a [B,3,H,W] tensor, so the MiDaS transform selected above is not applied
81 | with torch.no_grad():
82 | prediction = midas_model(input_batch)
83 | output = prediction
84 | output = (output - output.min()) / (output.max() - output.min())
85 | return output.unsqueeze(0)
86 |
--------------------------------------------------------------------------------
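
A short usage sketch for `VGGPerceptualLoss` above (the random tensors stand in for a rendered view and its reference; inputs are batched RGB in [0,1] and must live on the GPU, since the normalization constants are created with `.cuda()`):

```
import torch
from utils_for_train import VGGPerceptualLoss

perc = VGGPerceptualLoss().cuda()
syn = torch.rand(1, 3, 384, 384, device="cuda", requires_grad=True)  # stands in for a rendered view
ref = torch.rand(1, 3, 384, 384, device="cuda")                      # stands in for the reference image
loss = perc(syn, ref)  # ImageNet-normalized, resized to 224x224, L1 over four VGG16 feature stages
loss.backward()        # gradients reach `syn`, so a renderer can be optimized against it
```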
/warpback/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | def get_edge_connect(weight_dir):
6 | inpaint_model = InpaintGenerator()
7 | inpaint_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_gen.pth"))
8 | inpaint_model.load_state_dict(inpaint_model_weight["generator"])
9 | inpaint_model.eval()
10 |
11 | edge_model = EdgeGenerator()
12 | edge_model_weight = torch.load(os.path.join(weight_dir, "EdgeModel_gen.pth"))
13 | edge_model.load_state_dict(edge_model_weight["generator"])
14 | edge_model.eval()
15 |
16 | disp_model = InpaintGenerator(in_channels=2, out_channels=1)
17 | disp_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_disp.pth"))
18 | disp_model.load_state_dict(disp_model_weight["generator"])
19 | disp_model.eval()
20 | return edge_model, inpaint_model, disp_model
21 |
22 | class BaseNetwork(nn.Module):
23 | def __init__(self):
24 | super(BaseNetwork, self).__init__()
25 |
26 | def init_weights(self, init_type='normal', gain=0.02):
27 | '''
28 | initialize network's weights
29 | init_type: normal | xavier | kaiming | orthogonal
30 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
31 | '''
32 |
33 | def init_func(m):
34 | classname = m.__class__.__name__
35 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
36 | if init_type == 'normal':
37 | nn.init.normal_(m.weight.data, 0.0, gain)
38 | elif init_type == 'xavier':
39 | nn.init.xavier_normal_(m.weight.data, gain=gain)
40 | elif init_type == 'kaiming':
41 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
42 | elif init_type == 'orthogonal':
43 | nn.init.orthogonal_(m.weight.data, gain=gain)
44 |
45 | if hasattr(m, 'bias') and m.bias is not None:
46 | nn.init.constant_(m.bias.data, 0.0)
47 |
48 | elif classname.find('BatchNorm2d') != -1:
49 | nn.init.normal_(m.weight.data, 1.0, gain)
50 | nn.init.constant_(m.bias.data, 0.0)
51 |
52 | self.apply(init_func)
53 |
54 |
55 | class InpaintGenerator(BaseNetwork):
56 | def __init__(self, residual_blocks=8, init_weights=True, in_channels=4, out_channels=3):
57 | super(InpaintGenerator, self).__init__()
58 |
59 | self.encoder = nn.Sequential(
60 | nn.ReflectionPad2d(3),
61 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0),
62 | nn.InstanceNorm2d(64, track_running_stats=False),
63 | nn.ReLU(True),
64 |
65 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
66 | nn.InstanceNorm2d(128, track_running_stats=False),
67 | nn.ReLU(True),
68 |
69 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
70 | nn.InstanceNorm2d(256, track_running_stats=False),
71 | nn.ReLU(True)
72 | )
73 |
74 | blocks = []
75 | for _ in range(residual_blocks):
76 | block = ResnetBlock(256, 2)
77 | blocks.append(block)
78 |
79 | self.middle = nn.Sequential(*blocks)
80 |
81 | self.decoder = nn.Sequential(
82 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
83 | nn.InstanceNorm2d(128, track_running_stats=False),
84 | nn.ReLU(True),
85 |
86 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
87 | nn.InstanceNorm2d(64, track_running_stats=False),
88 | nn.ReLU(True),
89 |
90 | nn.ReflectionPad2d(3),
91 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0),
92 | )
93 |
94 | if init_weights:
95 | self.init_weights()
96 |
97 | def forward(self, x):
98 | x = self.encoder(x)
99 | x = self.middle(x)
100 | x = self.decoder(x)
101 | x = (torch.tanh(x) + 1) / 2
102 |
103 | return x
104 |
105 |
106 | class EdgeGenerator(BaseNetwork):
107 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
108 | super(EdgeGenerator, self).__init__()
109 |
110 | self.encoder = nn.Sequential(
111 | nn.ReflectionPad2d(3),
112 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
113 | nn.InstanceNorm2d(64, track_running_stats=False),
114 | nn.ReLU(True),
115 |
116 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
117 | nn.InstanceNorm2d(128, track_running_stats=False),
118 | nn.ReLU(True),
119 |
120 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
121 | nn.InstanceNorm2d(256, track_running_stats=False),
122 | nn.ReLU(True)
123 | )
124 |
125 | blocks = []
126 | for _ in range(residual_blocks):
127 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
128 | blocks.append(block)
129 |
130 | self.middle = nn.Sequential(*blocks)
131 |
132 | self.decoder = nn.Sequential(
133 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
134 | nn.InstanceNorm2d(128, track_running_stats=False),
135 | nn.ReLU(True),
136 |
137 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
138 | nn.InstanceNorm2d(64, track_running_stats=False),
139 | nn.ReLU(True),
140 |
141 | nn.ReflectionPad2d(3),
142 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
143 | )
144 |
145 | if init_weights:
146 | self.init_weights()
147 |
148 | def forward(self, x):
149 | x = self.encoder(x)
150 | x = self.middle(x)
151 | x = self.decoder(x)
152 | x = torch.sigmoid(x)
153 | return x
154 |
155 |
156 | class ResnetBlock(nn.Module):
157 | def __init__(self, dim, dilation=1, use_spectral_norm=False):
158 | super(ResnetBlock, self).__init__()
159 | self.conv_block = nn.Sequential(
160 | nn.ReflectionPad2d(dilation),
161 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
162 | nn.InstanceNorm2d(dim, track_running_stats=False),
163 | nn.ReLU(True),
164 |
165 | nn.ReflectionPad2d(1),
166 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
167 | nn.InstanceNorm2d(dim, track_running_stats=False),
168 | )
169 |
170 | def forward(self, x):
171 | out = x + self.conv_block(x)
172 | return out
173 |
174 |
175 | def spectral_norm(module, mode=True):
176 | if mode:
177 | return nn.utils.spectral_norm(module)
178 |
179 | return module
--------------------------------------------------------------------------------
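
For orientation, a quick shape check of the three EdgeConnect-style generators above. The channel counts come directly from the constructors; the weight directory is an assumption of this sketch, so point it at wherever the pretrained files (InpaintingModel_gen.pth, EdgeModel_gen.pth, InpaintingModel_disp.pth) actually live.

```
import torch
from warpback.networks import get_edge_connect, InpaintGenerator

edge_model, inpaint_model, disp_model = get_edge_connect("warpback/ecweights")

# random inputs with the channel counts the three generators expect
x_edge = torch.rand(1, 3, 256, 256)   # EdgeGenerator: 3 channels in -> 1 channel out
x_rgb  = torch.rand(1, 4, 256, 256)   # InpaintGenerator: 4 channels in -> 3 channels out
x_disp = torch.rand(1, 2, 256, 256)   # disparity InpaintGenerator: 2 channels in -> 1 channel out

with torch.no_grad():
    print(edge_model(x_edge).shape)     # torch.Size([1, 1, 256, 256])
    print(inpaint_model(x_rgb).shape)   # torch.Size([1, 3, 256, 256])
    print(disp_model(x_disp).shape)     # torch.Size([1, 1, 256, 256])
```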
/warpback/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import torch
4 | import torch.nn.functional as F
5 | from torchvision.utils import save_image
6 | from torchvision import transforms
7 | from pytorch3d.renderer.mesh import rasterize_meshes
8 | from pytorch3d.structures import Meshes
9 | from pytorch3d.ops import interpolate_face_attributes
10 |
11 |
12 | class RGBDRenderer:
13 | def __init__(self, device):
14 | self.device = device
15 | self.eps = 1e-4
16 | self.near_z = 1e-4
17 | self.far_z = 1e4
18 |
19 | def render_mesh(self, mesh_dict, cam_int, cam_ext):
20 | '''
21 | input:
22 | mesh: the output for construct_mesh function
23 | cam_int: [b,3,3] normalized camera intrinsic matrix
24 | cam_ext: [b,3,4] camera extrinsic matrix with the same scale as depth map
25 | camera coord: x to right, z to front, y to down
26 |
27 | output:
28 | render: [b,3,h,w]
29 | disparity: [b,1,h,w]
30 | '''
31 | vertice = mesh_dict["vertice"] # [b,h*w,3]
32 | faces = mesh_dict["faces"] # [b,nface,3]
33 | attributes = mesh_dict["attributes"] # [b,h*w,4]
34 | h, w = mesh_dict["size"]
35 |
36 | ############
37 | # to NDC space
38 | vertice_homo = self.lift_to_homo(vertice) # [b,h*w,4]
39 | # [b,1,3,4] x [b,h*w,4,1] = [b,h*w,3,1]
40 | vertice_world = torch.matmul(cam_ext.unsqueeze(1), vertice_homo[..., None]).squeeze(-1) # [b,h*w,3]
41 | vertice_depth = vertice_world[..., -1:] # [b,h*w,1]
42 | attributes = torch.cat([attributes, vertice_depth], dim=-1) # [b,h*w,5]
43 |         # lift world-space vertices to homogeneous coordinates, then project with the perspective matrix below
44 | vertice_world_homo = self.lift_to_homo(vertice_world)
45 | persp = self.get_perspective_from_intrinsic(cam_int) # [b,4,4]
46 |
47 | # [b,1,4,4] x [b,h*w,4,1] = [b,h*w,4,1]
48 | vertice_ndc = torch.matmul(persp.unsqueeze(1), vertice_world_homo[..., None]).squeeze(-1) # [b,h*w,4]
49 | vertice_ndc = vertice_ndc[..., :-1] / vertice_ndc[..., -1:]
50 | vertice_ndc[..., :-1] *= -1
51 | vertice_ndc[..., 0] *= w / h
52 |
53 | ############
54 | # render
55 | mesh = Meshes(vertice_ndc, faces)
56 | pix_to_face, _, bary_coords, _ = rasterize_meshes(mesh, (h, w), faces_per_pixel=1, blur_radius=1e-6) # [b,h,w,1] [b,h,w,1,3]
57 |
58 | b, nf, _ = faces.size()
59 | faces = faces.reshape(b, nf * 3, 1).repeat(1, 1, 5) # [b,3f,5]
60 | face_attributes = torch.gather(attributes, dim=1, index=faces) # [b,3f,5]
61 | face_attributes = face_attributes.reshape(b * nf, 3, 5)
62 | output = interpolate_face_attributes(pix_to_face, bary_coords, face_attributes)
63 | output = output.squeeze(-2).permute(0, 3, 1, 2)
64 |
65 | render = output[:, :3]
66 | mask = output[:, 3:4]
67 | disparity = torch.reciprocal(output[:, 4:] + self.eps)
68 | return render * mask, disparity * mask, mask
69 |
70 | def construct_mesh(self, rgbd, cam_int):
71 | '''
72 | input:
73 | rgbd: [b,4,h,w]
74 | the first 3 channels for RGB
75 | the last channel for normalized disparity, in range [0,1]
76 | cam_int: [b,3,3] normalized camera intrinsic matrix
77 |
78 | output:
79 | mesh_dict: define mesh in camera space, includes the following keys
80 | vertice: [b,h*w,3]
81 | faces: [b,nface,3]
82 | attributes: [b,h*w,c] include color and mask
83 | '''
84 | b, _, h, w = rgbd.size()
85 |
86 | ############
87 | # get pixel coordinates
88 | pixel_2d = self.get_screen_pixel_coord(h, w) # [1,h,w,2]
89 | pixel_2d_homo = self.lift_to_homo(pixel_2d) # [1,h,w,3]
90 |
91 | ############
92 | # project pixels to 3D space
93 | rgbd = rgbd.permute(0, 2, 3, 1) # [b,h,w,4]
94 | disparity = rgbd[..., -1:] # [b,h,w,1]
95 | depth = torch.reciprocal(disparity + self.eps) # [b,h,w,1]
96 | cam_int_inv = torch.inverse(cam_int) # [b,3,3]
97 | # [b,1,1,3,3] x [1,h,w,3,1] = [b,h,w,3,1]
98 | pixel_3d = torch.matmul(cam_int_inv[:, None, None, :, :], pixel_2d_homo[..., None]).squeeze(-1) # [b,h,w,3]
99 | pixel_3d = pixel_3d * depth # [b,h,w,3]
100 | vertice = pixel_3d.reshape(b, h * w, 3) # [b,h*w,3]
101 |
102 | ############
103 | # construct faces
104 | faces = self.get_faces(h, w) # [1,nface,3]
105 | faces = faces.repeat(b, 1, 1).long() # [b,nface,3]
106 |
107 | ############
108 | # compute attributes
109 | attr_color = rgbd[..., :-1].reshape(b, h * w, 3) # [b,h*w,3]
110 | attr_mask = self.get_visible_mask(disparity).reshape(b, h * w, 1) # [b,h*w,1]
111 | attr = torch.cat([attr_color, attr_mask], dim=-1) # [b,h*w,4]
112 |
113 | mesh_dict = {
114 | "vertice": vertice,
115 | "faces": faces,
116 | "attributes": attr,
117 | "size": [h, w],
118 | }
119 | return mesh_dict
120 |
121 | def get_screen_pixel_coord(self, h, w):
122 | '''
123 | get normalized pixel coordinates on the screen
124 |         x to right, y to down
125 |
126 | e.g.
127 | [0,0][1,0][2,0]
128 | [0,1][1,1][2,1]
129 | output:
130 | pixel_coord: [1,h,w,2]
131 | '''
132 | x = torch.arange(w).to(self.device) # [w]
133 | y = torch.arange(h).to(self.device) # [h]
134 | x = (x + 0.5) / w
135 | y = (y + 0.5) / h
136 | x = x[None, None, ..., None].repeat(1, h, 1, 1) # [1,h,w,1]
137 | y = y[None, ..., None, None].repeat(1, 1, w, 1) # [1,h,w,1]
138 | pixel_coord = torch.cat([x, y], dim=-1) # [1,h,w,2]
139 | return pixel_coord
140 |
141 | def lift_to_homo(self, coord):
142 | '''
143 | return the homo version of coord
144 | input: coord [..., k]
145 | output: homo_coord [...,k+1]
146 | '''
147 | ones = torch.ones_like(coord[..., -1:])
148 | return torch.cat([coord, ones], dim=-1)
149 |
150 | def get_faces(self, h, w):
151 | '''
152 | get face connect information
153 |         x to right, y to down
154 | e.g.
155 | [0,0][1,0][2,0]
156 | [0,1][1,1][2,1]
157 | faces: [1,nface,3]
158 | '''
159 | x = torch.arange(w - 1).to(self.device) # [w-1]
160 | y = torch.arange(h - 1).to(self.device) # [h-1]
161 | x = x[None, None, ..., None].repeat(1, h - 1, 1, 1) # [1,h-1,w-1,1]
162 | y = y[None, ..., None, None].repeat(1, 1, w - 1, 1) # [1,h-1,w-1,1]
163 |
164 | tl = y * w + x
165 | tr = y * w + x + 1
166 | bl = (y + 1) * w + x
167 | br = (y + 1) * w + x + 1
168 |
169 | faces_l = torch.cat([tl, bl, br], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3]
170 | faces_r = torch.cat([br, tr, tl], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3]
171 |
172 | return torch.cat([faces_l, faces_r], dim=1) # [1,nface,3]
173 |
174 | def get_visible_mask(self, disparity, beta=10, alpha_threshold=0.3):
175 | '''
176 | filter the disparity map using sobel kernel, then mask out the edge (depth discontinuity)
177 | input:
178 | disparity: [b,h,w,1]
179 |
180 | output:
181 | vis_mask: [b,h,w,1]
182 | '''
183 | b, h, w, _ = disparity.size()
184 | disparity = disparity.reshape(b, 1, h, w) # [b,1,h,w]
185 | kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device)
186 | kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device)
187 | sobel_x = F.conv2d(disparity, kernel_x, padding=(1, 1)) # [b,1,h,w]
188 | sobel_y = F.conv2d(disparity, kernel_y, padding=(1, 1)) # [b,1,h,w]
189 | sobel_mag = torch.sqrt(sobel_x ** 2 + sobel_y ** 2).reshape(b, h, w, 1) # [b,h,w,1]
190 | alpha = torch.exp(-1.0 * beta * sobel_mag) # [b,h,w,1]
191 | vis_mask = torch.greater(alpha, alpha_threshold).float()
192 | return vis_mask
193 |
194 | def get_perspective_from_intrinsic(self, cam_int):
195 | '''
196 | input:
197 | cam_int: [b,3,3]
198 |
199 | output:
200 | persp: [b,4,4]
201 | '''
202 | fx, fy = cam_int[:, 0, 0], cam_int[:, 1, 1] # [b]
203 | cx, cy = cam_int[:, 0, 2], cam_int[:, 1, 2] # [b]
204 |
205 | one = torch.ones_like(cx) # [b]
206 | zero = torch.zeros_like(cx) # [b]
207 |
208 | near_z, far_z = self.near_z * one, self.far_z * one
209 | a = (near_z + far_z) / (far_z - near_z)
210 | b = -2.0 * near_z * far_z / (far_z - near_z)
211 |
212 | matrix = [[2.0 * fx, zero, 2.0 * cx - 1.0, zero],
213 | [zero, 2.0 * fy, 2.0 * cy - 1.0, zero],
214 | [zero, zero, a, b],
215 | [zero, zero, one, zero]]
216 | # -> [[b,4],[b,4],[b,4],[b,4]] -> [b,4,4]
217 | persp = torch.stack([torch.stack(row, dim=-1) for row in matrix], dim=-2) # [b,4,4]
218 | return persp
219 |
220 |
221 | #######################
222 | # some helper I/O functions
223 | #######################
224 | def image_to_tensor(img_path, unsqueeze=True):
225 | rgb = transforms.ToTensor()(Image.open(img_path))
226 | if unsqueeze:
227 | rgb = rgb.unsqueeze(0)
228 | return rgb
229 |
230 |
231 | def disparity_to_tensor(disp_path, unsqueeze=True):
232 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1)
233 | disp = torch.from_numpy(disp)[None, ...]
234 | if unsqueeze:
235 | disp = disp.unsqueeze(0)
236 | return disp.float()
237 |
238 |
239 | #######################
240 | # some helper geometry functions
241 | # adapt from https://github.com/mattpoggi/depthstillation
242 | #######################
243 | def transformation_from_parameters(axisangle, translation, invert=False):
244 | R = rot_from_axisangle(axisangle)
245 | t = translation.clone()
246 |
247 | if invert:
248 | R = R.transpose(1, 2)
249 | t *= -1
250 |
251 | T = get_translation_matrix(t)
252 |
253 | if invert:
254 | M = torch.matmul(R, T)
255 | else:
256 | M = torch.matmul(T, R)
257 |
258 | return M
259 |
260 |
261 | def get_translation_matrix(translation_vector):
262 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
263 | t = translation_vector.contiguous().view(-1, 3, 1)
264 | T[:, 0, 0] = 1
265 | T[:, 1, 1] = 1
266 | T[:, 2, 2] = 1
267 | T[:, 3, 3] = 1
268 | T[:, :3, 3, None] = t
269 | return T
270 |
271 |
272 | def rot_from_axisangle(vec):
273 | angle = torch.norm(vec, 2, 2, True)
274 | axis = vec / (angle + 1e-7)
275 |
276 | ca = torch.cos(angle)
277 | sa = torch.sin(angle)
278 | C = 1 - ca
279 |
280 | x = axis[..., 0].unsqueeze(1)
281 | y = axis[..., 1].unsqueeze(1)
282 | z = axis[..., 2].unsqueeze(1)
283 |
284 | xs = x * sa
285 | ys = y * sa
286 | zs = z * sa
287 | xC = x * C
288 | yC = y * C
289 | zC = z * C
290 | xyC = x * yC
291 | yzC = y * zC
292 | zxC = z * xC
293 |
294 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
295 |
296 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
297 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
298 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
299 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
300 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
301 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
302 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
303 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
304 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
305 | rot[:, 3, 3] = 1
306 |
307 | return rot
308 |
309 |
--------------------------------------------------------------------------------
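
As a small sanity check of the pose helpers above: `transformation_from_parameters(..., invert=True)` builds the analytic inverse of the forward transform (the training scripts obtain the same warp-back pose with `torch.inverse(cam_ext)`). A minimal sketch with illustrative pose values:

```
import math
import torch
from warpback.utils import transformation_from_parameters

axisangle = torch.tensor([[[0.0, math.pi / 30, 0.0]]])   # [1,1,3] small rotation about y
translation = torch.tensor([[[0.05, 0.0, -0.02]]])       # [1,1,3] small translation

fwd = transformation_from_parameters(axisangle, translation)               # [1,4,4]
inv = transformation_from_parameters(axisangle, translation, invert=True)  # analytic inverse

# the product should be (numerically) the identity
print(torch.allclose(torch.matmul(inv, fwd), torch.eye(4).unsqueeze(0), atol=1e-4))  # True
```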