├── LICENSE
├── README.md
├── ckpts
    └── Exp-syndney
    │   ├── MPI_rendered_views.gif
    │   ├── MPI_rendered_views.mp4
    │   ├── canvas.png
    │   └── canvas_depth.png
├── dataloaders
    └── single_img_data.py
├── model
    ├── Inpainter.py
    ├── MPF.py
    ├── TrainableFilter.py
    └── VitExtractor.py
├── mpi
    ├── homography_sampler.py
    ├── mpi_rendering.py
    └── rendering_utils.py
├── outpaint_rgbd.py
├── requirements.txt
├── scripts
    └── train_all.sh
├── test_images
    └── Syndney.jpg
├── train_inpainting.py
├── train_mpi.py
├── utils_for_train.py
└── warpback
    ├── networks.py
    └── utils.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Tricky

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Welcome to SinMPI!

["SinMPI: Novel View Synthesis from a Single Image with Expanded Multiplane Images" (SIGGRAPH Asia 2023)](https://arxiv.org/abs/2312.11037).

## Quick demo

### 1. Prepare

(1) Create a new conda environment with the dependencies listed in requirements.txt.

(2) Download the pretrained weights of the depth-aware inpainter [ecweights](https://drive.google.com/drive/folders/1FZZ6laPuqEMSfrGvEWYaDZWEPaHvGm6r) and put them into 'warpback/ecweights/xxx.pth'.

### 2. Run demo
```
sh scripts/train_all.sh
```
This demo converts 'test_images/Syndney.jpg' into an expanded MPI and renders novel views, as in 'ckpts/Exp-Syndney-new/MPI_rendered_views.mp4'.

## What happens when running the demo?

### 1. Outpaint the input image

In the demo above, we use 'test_images/Syndney.jpg' as the input image and then repeatedly outpaint it:
```
CUDA_VISIBLE_DEVICES=$cuda python outpaint_rgbd.py \
    --width $width \
    --height $height \
    --ckpt_path $ckpt_path \
    --img_path $img_path \
    --extrapolate_times $extrapolate_times
```
This yields the outpainted image and its depth map, estimated by a monocular depth estimator (DPT).

### 2. Fine-tune the depth-aware inpainter and create pseudo multi-view images

```
CUDA_VISIBLE_DEVICES=$cuda python train_inpainting.py \
    --width $width \
    --height $height \
    --ckpt_path $ckpt_path \
    --img_path $img_path \
    --num_epochs 10 \
    --extrapolate_times $extrapolate_times \
    --batch_size 1 #--load_warp_pairs --debugging
```

### 3. Optimize the expanded MPI

```
CUDA_VISIBLE_DEVICES=$cuda python train_mpi.py \
    --width $width \
    --height $height \
    --ckpt_path $ckpt_path \
    --img_path $img_path \
    --num_epochs 10 \
    --extrapolate_times $extrapolate_times \
    --batch_size 1 #--debugging #--resume
```

After optimization, we render novel views.

## Toward better quality and robustness

Note that the demo above is designed for fast illustration (its FPS is low). For better quality:

#### Use more pseudo multi-views that cover a larger area.
Sampling more views over a wider area yields a higher-quality MPI. To add training and rendering view trajectories, modify 'dataloaders/single_img_data.py'.
#### Train for more epochs.

## Cite our paper

If you find our work helpful, please cite our paper. Thank you!

ACM Reference Format:
```
Guo Pu, Peng-Shuai Wang, and Zhouhui Lian. 2023. SinMPI: Novel View
Synthesis from a Single Image with Expanded Multiplane Images. In SIGGRAPH Asia 2023 Conference Papers (SA Conference Papers '23), December
12–15, 2023, Sydney, NSW, Australia. ACM, New York, NY, USA, 10 pages.
https://doi.org/10.1145/3610548.3618155
```

--------------------------------------------------------------------------------
/ckpts/Exp-syndney/MPI_rendered_views.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/MPI_rendered_views.gif

--------------------------------------------------------------------------------
/ckpts/Exp-syndney/MPI_rendered_views.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/MPI_rendered_views.mp4

--------------------------------------------------------------------------------
/ckpts/Exp-syndney/canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/canvas.png

--------------------------------------------------------------------------------
/ckpts/Exp-syndney/canvas_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/ckpts/Exp-syndney/canvas_depth.png

--------------------------------------------------------------------------------
/dataloaders/single_img_data.py:
--------------------------------------------------------------------------------
import sys
sys.path.append(".")
sys.path.append("..")
import math
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image

def gen_swing_path(init_pose=torch.eye(4), num_frames=10, r_x=0.2, r_y=0, r_z=-0.2):
    "Return a list of [4, 4] pose matrices."
    t = torch.arange(num_frames) / (num_frames - 1)
    poses = init_pose.repeat(num_frames, 1, 1)

    swing = torch.eye(4).repeat(num_frames, 1, 1)
    swing[:, 0, 3] = r_x * torch.sin(2. * math.pi * t)
    swing[:, 1, 3] = r_y * torch.cos(2. * math.pi * t)
    swing[:, 2, 3] = r_z * (torch.cos(2. * math.pi * t) - 1.)

    for i in range(num_frames):
        poses[i, :, :] = poses[i, :, :] @ swing[i, :, :]
    return list(poses.unbind())

def create_spheric_poses_along_y(n_poses=10):

    def spheric_pose_y(phi, radius=10):
        trans_t = lambda t : np.array([
            [1,0,0, math.sin(2. * math.pi * t) * radius],
            [0,1,0,0],
            [0,0,1,0],
            [0,0,0,1],
        ])

        # rotation along y
        rot_phi = lambda phi : np.array([
            [np.cos(phi),0, np.sin(phi),0],
            [0,1,0,0],
            [-np.sin(phi),0,np.cos(phi),0],
            [0,0,0,1],
        ])

        c2w = rot_phi(phi) @ trans_t(phi)
        c2w = np.array([[1,0,0,0],
                        [0,1,0,0],
                        [0,0,1,0],
                        [0,0,0,1]]) @ c2w
        c2w = torch.tensor(c2w).float()
        return c2w

    def spheric_pose_x(phi, radius=10):
        trans_t = lambda t : np.array([
            [1,0,0,0],
            [0,1,0,math.sin(2. * math.pi * t) * radius * -1],
            [0,0,1,0],
            [0,0,0,1],
        ])

        # rotation along x
        rot_theta = lambda th : np.array([
            [1,0,0,0],
            [0,np.cos(th),-np.sin(th),0],
            [0,np.sin(th), np.cos(th),0],
            [0,0,0,1],
        ])

        c2w = rot_theta(phi) @ trans_t(phi)
        c2w = np.array([[1,0,0,0],
                        [0,1,0,0],
                        [0,0,1,0],
                        [0,0,0,1]]) @ c2w
        c2w = torch.tensor(c2w).float()
        return c2w

    spheric_poses = []
    poses = gen_swing_path()
    spheric_poses += poses

    factor = 1
    y_angle = (1/16) * np.pi * factor
    x_angle = (1/16) * np.pi * factor
    x_radius = 0.1 * factor
    y_radius = 0.1 * factor

    # rotate left and right
    for th in np.linspace(0, -1 * y_angle, n_poses//2):
        spheric_poses += [spheric_pose_y(th, y_radius)]

    poses = gen_swing_path(spheric_poses[-1])
    spheric_poses += poses

    for th in np.linspace(-1 * y_angle, y_angle, n_poses)[:-1]:
        spheric_poses += [spheric_pose_y(th, y_radius)]

    poses = gen_swing_path(spheric_poses[-1])
    spheric_poses += poses

    for th in np.linspace(y_angle, 0, n_poses//2)[:-1]:
        spheric_poses += [spheric_pose_y(th, y_radius)]

    # rotate up and down
    for th in np.linspace(0, -1 * x_angle, n_poses//2):
        spheric_poses += [spheric_pose_x(th, x_radius)]

    poses = gen_swing_path(spheric_poses[-1])
    spheric_poses += poses

    for th in np.linspace(-1 * x_angle, x_angle, n_poses)[:-1]:
        spheric_poses += [spheric_pose_x(th, x_radius)]

    poses = gen_swing_path(spheric_poses[-1])
    spheric_poses += poses

    for th in np.linspace(x_angle, 0, n_poses//2)[:-1]:
        spheric_poses += [spheric_pose_x(th, x_radius)]

    return spheric_poses


def convert(c2w, phi=0):
    # rot_along_y
    c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)
    rot = np.array([
        [np.cos(phi),0, np.sin(phi),0],
        [0,1,0,0],
        [-np.sin(phi),0,np.cos(phi),0],
        [0,0,0,1],
    ])
    return rot @ c2w

class SinImgDataset(Dataset):
    def __init__(
        self,
        img_path,
        width=512,
        height=512,
        repeat_times = 1
    ):
        self.repeat_times = repeat_times
        self.img_wh = [width, height]
        self.transform = transforms.ToTensor()
        self.img_path = img_path
        self.read_meta()

    def read_meta(self):

        self.all_rgbs = []
        self.all_poses = []
        self.all_poses += create_spheric_poses_along_y()

        img_path = "test_images/" + self.img_path
        self.ref_img = self.load_img(img_path)


    def load_img(self, image_path):
        img = Image.open(image_path).convert('RGB')
        img = img.resize(self.img_wh, Image.LANCZOS)
        img = self.transform(img)  # (3, h, w)
        img = img.unsqueeze(0).cuda()
        return img  # [img:1*3*h*w]


    def __len__(self):
        return len(self.all_poses) * self.repeat_times

    def __getitem__(self, idx):
        sample = {
            'cur_pose':self.all_poses[idx % len(self.all_poses)]
        }
        return sample

--------------------------------------------------------------------------------
/model/Inpainter.py:
--------------------------------------------------------------------------------
import sys
sys.path.append(".")
sys.path.append("..")
import os
import glob
import numpy as np
from skimage.feature import canny
import torch
import torch.nn.functional as F
from torchvision import transforms

from warpback.networks import get_edge_connect

class InpaintingModule(torch.nn.Module):
    def __init__(
        self,
        data_root='',
        width=512,
        height=512,
        depth_dir_name="dpt_depth",
        device="cuda:0",
        trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1},  # xyz for translation, abc for Euler angles
        ec_weight_dir="warpback/ecweight",
    ):
        super(InpaintingModule, self).__init__()
        self.data_root = data_root
        self.depth_dir_name = depth_dir_name
        self.width = width
        self.height = height
        self.device = device
        self.trans_range = trans_range
        self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg"))
        self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png"))

        self.edge_model, self.inpaint_model, self.disp_model = get_edge_connect(ec_weight_dir)
        self.edge_model = self.edge_model.to(self.device)
        self.inpaint_model = self.inpaint_model.to(self.device)
        self.disp_model = self.disp_model.to(self.device)

    def preprocess_rgbd(self, image, disp):
        image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
        disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
        return image, disp

    def forward(self, image, disp, mask):

        image_gray = transforms.Grayscale()(image)
        edge = self.get_edge(image_gray, mask)

        mask_hole = 1 - mask

        # inpaint edge
        edge_model_input = torch.cat([image_gray, edge, mask_hole], dim=1)  # [b,3,h,w]: gray + edge + hole mask
        edge_inpaint = self.edge_model(edge_model_input)  # [b,1,h,w]

        # inpaint RGB
        inpaint_model_input = torch.cat([image + mask_hole, edge_inpaint], dim=1)
        image_inpaint = self.inpaint_model(inpaint_model_input)
        image_merged = image * (1 - mask_hole) + image_inpaint * mask_hole

        # inpaint disparity
        disp_model_input = torch.cat([disp + mask_hole, edge_inpaint], dim=1)
        disp_inpaint = self.disp_model(disp_model_input)
        disp_merged = disp * (1 - mask_hole) + disp_inpaint * mask_hole

        return image_merged, disp_merged

    def get_edge(self, image_gray, mask):
        image_gray_np = image_gray.squeeze(1).cpu().numpy()  # [b,h,w]
        mask_bool_np = np.array(mask.squeeze(1).cpu(), dtype=np.bool_)  # [b,h,w]
        edges = []
        for i in range(mask.shape[0]):
            cur_edge = canny(image_gray_np[i], sigma=2, mask=mask_bool_np[i])
            edges.append(torch.from_numpy(cur_edge).unsqueeze(0))  # [1,h,w]
        edge = torch.cat(edges, dim=0).unsqueeze(1).float()  # [b,1,h,w]
        return edge.to(self.device)


--------------------------------------------------------------------------------
/model/MPF.py:
--------------------------------------------------------------------------------
import torch

class MultiPlaneField(torch.nn.Module):
    def __init__(
        self,
        image_size=(256, 384),
        num_planes=12,
        assign_origin_planes=None,
        depth_range=None
    ):
        super(MultiPlaneField, self).__init__()
        # x: horizontal axis, y: vertical axis, z: inward axis

        self.num_planes = num_planes
        self.near, self.far = depth_range

        (self.plane_h, self.plane_w) = image_size

        self.planes_disp = torch.linspace(self.near, self.far, num_planes, requires_grad=False).unsqueeze(0).cuda()  # [b,s]
        # planes are stored as [1, S:num_planes, 4:rgb+alpha, H:plane_h, W:plane_w]

        self.extrapolate_RGBDs = assign_origin_planes

        init_planes = self.assign_image_to_planes(self.extrapolate_RGBDs[0], self.extrapolate_RGBDs[1])
        self.planes_mid2 = init_planes

        init_val = torch.ones(1, num_planes, 4, self.plane_h, self.plane_w) * 0.1
        init_val[:,:,:4,:,:] *= 0.01

        self.planes_residual = torch.nn.Parameter(init_val)


    def assign_image_to_planes(self, ref_img, ref_disp):
        planes = torch.zeros(1, self.num_planes, 4, self.plane_h, self.plane_w, requires_grad=False).cuda()
        # append an alpha channel (filled with self.far) to ref_img
        ref_img = torch.cat([ref_img, torch.ones_like(ref_disp) * self.far], dim=1)  # [1,3+1,h,w]
        depth_levels = torch.linspace(self.near, self.far, self.num_planes).cuda()

        planes_masks = []
        for i in range(len(depth_levels)):
            cur_depth_mask = torch.where(ref_disp < depth_levels[i],
                                         torch.ones_like(ref_disp).cuda(),
                                         torch.zeros_like(ref_disp).cuda())
            planes_masks.append(cur_depth_mask)
            cur_depth_pixels = ref_img * cur_depth_mask.repeat(1,4,1,1)
            cur_depth_pixels = cur_depth_pixels.unsqueeze(0)
            planes[:,i:i+1,:,:,:] = cur_depth_pixels  # [1,1,4,h,w]

            # push pixels already assigned to this plane beyond far, so later planes skip them
            ref_disp = ref_disp + cur_depth_mask * (self.far + 1)

        return planes

    def forward(self):
        planes = self.planes_mid2
        pred_planes = self.planes_residual

        residual_mask = torch.where(planes > 0,
                                    torch.zeros_like(planes).cuda(),
                                    torch.ones_like(planes).cuda())

        x = pred_planes * residual_mask + planes
        return x

--------------------------------------------------------------------------------
/model/TrainableFilter.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F

class GaussianFilter(nn.Module):
    def __init__(self, ksize=5, sigma=None):
        super(GaussianFilter, self).__init__()
        # initialize the Gaussian kernel (default sigma follows OpenCV's formula for a given ksize)
        if sigma is None:
            sigma = 0.3 * ((ksize-1) / 2.0 - 1) + 0.8
        # Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2)
        x_coord = torch.arange(ksize)
        x_grid = x_coord.repeat(ksize).view(ksize, ksize)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        # Calculate the 2-dimensional Gaussian kernel
        center = ksize // 2
        weight = torch.exp(-torch.sum((xy_grid - center)**2., dim=-1) / (2*sigma**2))
        # Make sure the values in the Gaussian kernel sum to 1.
        weight /= torch.sum(weight)
        self.gaussian_weight = weight

    def forward(self, x):
        # NOTE: self.filter is never defined, so calling forward() raises AttributeError;
        # TrainableFilter only uses this class for its gaussian_weight.
        return self.filter(x)

class TrainableFilter(nn.Module):
    def __init__(self, ksize=5, sigma_space=None, sigma_density=1):
        super(TrainableFilter, self).__init__()
        # initialization
        if sigma_space is None:
            self.sigma_space = 0.3 * ((ksize-1) * 0.5 - 1) + 0.8
        else:
            self.sigma_space = sigma_space
        if sigma_density is None:
            self.sigma_density = self.sigma_space
        else:
            self.sigma_density = sigma_density

        self.pad = (ksize-1) // 2
        self.ksize = ksize
        # get the spatial Gaussian weight
        self.weight_space = GaussianFilter(ksize=self.ksize, sigma=self.sigma_space).gaussian_weight.cuda()
        # make the spatial kernel a trainable parameter
        self.weight_space = torch.nn.Parameter(self.weight_space)

    def forward(self, x):
        # Extract sliding ksize x ksize patches from the reflect-padded input.
        x_pad = F.pad(x, pad=[self.pad, self.pad, self.pad, self.pad], mode='reflect')
        x_patches = x_pad.unfold(2, self.ksize, 1).unfold(3, self.ksize, 1)
        patch_dim = x_patches.dim()

        # Range (density) kernel: Gaussian of the difference between each neighbor and the center value
        diff_density = x_patches - x.unsqueeze(-1).unsqueeze(-1)
        weight_density = torch.exp(-(diff_density ** 2) / (2 * self.sigma_density ** 2))
        # Normalization
        # weight_density /= weight_density.sum(dim=(-1, -2), keepdim=True)
        weight_density = weight_density / weight_density.sum(dim=(-1, -2), keepdim=True)
        # print(weight_density.shape)

        # Reshape the spatial kernel so it broadcasts to the shape of weight_density
        weight_space_dim = (patch_dim - 2) * (1, ) + (self.ksize, self.ksize)
        weight_space = self.weight_space.view(*weight_space_dim).expand_as(weight_density)

        # Final (bilateral) weight = range weight * spatial weight, normalized per pixel
        weight = weight_density * weight_space
        weight_sum = weight.sum(dim=(-1, -2))
        x = (weight * x_patches).sum(dim=(-1, -2)) / weight_sum
        return x

--------------------------------------------------------------------------------
/model/VitExtractor.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def attn_cosine_sim(x, eps=1e-08):
    x = x[0]  # TEMP: getting rid of redundant dimension, TBF
    norm1 = x.norm(dim=2, keepdim=True)
    factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
    sim_matrix = (x @ x.permute(0, 2, 1)) / factor
    return sim_matrix


class VitExtractor(nn.Module):
    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        super().__init__()
        self.model = torch.hub.load(
            'facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        self.layers_dict[VitExtractor.BLOCK_KEY] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.ATTN_KEY] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.QKV_KEY] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        for key in VitExtractor.KEY_LIST:
            # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
            self.outputs_dict[key] = []

    def _register_hooks(self, **kwargs):
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(
                    block.register_forward_hook(self._get_block_hook()))
            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(
                    block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(
                    block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(
                    block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        def _get_block_output(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)

        return _get_block_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _get_attn_output

    def _get_qkv_hook(self):
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)

        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _get_attn_output

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_qkv_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.QKV_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_attn_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.ATTN_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_patch_size(self):
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        patch_num = 1 + (self.get_height_patch_num(input_img_shape)
                         * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def get_queries_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        q = qkv.reshape(patch_num, 3, head_num, embedding_dim //
                        head_num).permute(1, 2, 0, 3)[0]
        return q

    def get_keys_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        k = qkv.reshape(patch_num, 3, head_num, embedding_dim //
                        head_num).permute(1, 2, 0, 3)[1]
        return k

    def get_values_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        v = qkv.reshape(patch_num, 3, head_num, embedding_dim //
                        head_num).permute(1, 2, 0, 3)[2]
        return v

    def get_keys_from_input(self, input_img, layer_num):
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map

--------------------------------------------------------------------------------
/mpi/homography_sampler.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from scipy.spatial.transform import Rotation


def inverse(matrices):
    """
    torch.inverse() sometimes produces outputs with NaN when the batch size is 2.
    Ref https://github.com/pytorch/pytorch/issues/47272
    This function keeps inverting the matrices until the result is NaN-free or the maximum number of tries is reached.
    :param matrices Bx3x3
    """
    inverse = None
    max_tries = 5
    while (inverse is None) or (torch.isnan(inverse)).any():
        # synchronizing only makes sense (and only works) for CUDA tensors
        if matrices.is_cuda:
            torch.cuda.synchronize()
        inverse = torch.inverse(matrices)

        # Break out of the loop when the inverse is successful or there are no more tries
        max_tries -= 1
        if max_tries == 0:
            break

    # Raise an Exception if the inverse contains NaN
    if (torch.isnan(inverse)).any():
        raise Exception("Matrix inverse contains nan!")
    return inverse


class HomographySample:
    def __init__(self, H_tgt, W_tgt, device=None):
        if device is None:
            self.device = torch.device("cpu")
        else:
            self.device = device

        self.Height_tgt = H_tgt
        self.Width_tgt = W_tgt
        self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
        self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous()  # 3xHxW

        self.n = self.plane_normal_generation(self.device)

    @staticmethod
    def grid_generation(H, W, device):
        x = np.linspace(0, W-1, W)
        y = np.linspace(0, H-1, H)
        xv, yv = np.meshgrid(x, y)  # HxW
        xv = torch.from_numpy(xv.astype(np.float32)).to(dtype=torch.float32, device=device)
        yv = torch.from_numpy(yv.astype(np.float32)).to(dtype=torch.float32, device=device)
        ones = torch.ones_like(xv)
        meshgrid = torch.stack((xv, yv, ones), dim=2)  # HxWx3
        return meshgrid

    @staticmethod
    def plane_normal_generation(device):
        n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
        return n

    @staticmethod
    def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
        """
        Note that here we want to return a rotation matrix rot_mtx, which transforms the tgt points into the src frame,
        i.e., rot_mtx * p_tgt = p_src
        Therefore we need to negate x/y/z_angle
        :param x_angle, y_angle, z_angle: Euler angles (about x, y, z for the default seq='xyz')
        :param seq: Euler axis sequence passed to scipy's Rotation.from_euler
        :param degrees: whether the angles are given in degrees
        :return: 3x3 float32 rotation matrix
        """
        r = Rotation.from_euler(seq,
                                [-x_angle, -y_angle, -z_angle],
                                degrees=degrees)
        rot_mtx = r.as_matrix().astype(np.float32)
        return rot_mtx


    def sample(self, src_BCHW, d_src_B,
               G_tgt_src,
               K_src_inv, K_tgt):
        """
        Coordinate system: x, y are the image directions, z points in the depth direction
        :param src_BCHW: torch tensor float, 0-1, rgb/rgba. BxCxHxW
                         Assumed to be at position P=[I|0]
        :param d_src_B: distance of the plane to the src camera origin, B
        :param G_tgt_src: Bx4x4
        :param K_src_inv: Bx3x3
        :param K_tgt: Bx3x3
        :return: tgt_BCHW (BxCxHxW) and valid_mask (BxHxW)
        """
        # parameter processing ------ begin ------
        B, channels, Height_src, Width_src = src_BCHW.size(0), src_BCHW.size(1), src_BCHW.size(2), src_BCHW.size(3)
        R_tgt_src = G_tgt_src[:, 0:3, 0:3]
        t_tgt_src = G_tgt_src[:, 0:3, 3]

        Height_tgt = self.Height_tgt
        Width_tgt = self.Width_tgt

        R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
        t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
        K_src_inv = K_src_inv.to(device=src_BCHW.device)
        K_tgt = K_tgt.to(device=src_BCHW.device)
        # parameter processing ------ end ------

        # the goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel,
        # so we compute H_tgt_src first and then invert it
        n = self.n.to(device=src_BCHW.device)
        n = n.unsqueeze(0).repeat(B, 1)  # Bx3
        # Bx3x3 - (Bx3x1 * Bx1x3)
        # note here we use -d_src, because the plane function is n^T * X - d_src = 0
        d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)  # B -> Bx3x3
        R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
        H_tgt_src = torch.matmul(K_tgt,
                                 torch.matmul(R_tnd, K_src_inv))

        # TODO: fix cuda inverse
        with torch.no_grad():
            H_src_tgt = inverse(H_tgt_src)

        # create the tgt image grid and map it to src
        meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
        # 3xHxW -> Bx3xHxW
        meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)

        # warp meshgrid_tgt_homo to meshgrid_src
        meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1)  # Bx3xHW
        meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N)  # Bx3x3 * Bx3xHW -> Bx3xHW
        # Bx3xHW -> Bx3xHxW -> BxHxWx3
        meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
        meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:]  # BxHxWx2

        valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
                                         meshgrid_src[:, :, :, 0] > -1)
        valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
                                         meshgrid_src[:, :, :, 1] > -1)
        valid_mask = torch.logical_and(valid_mask_x, valid_mask_y)  # BxHxW

        # sample from src_BCHW
        # normalize meshgrid_src to [-1,1]
        meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0]+0.5) / (Width_src * 0.5) - 1
        meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1]+0.5) / (Height_src * 0.5) - 1
        tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
                                                   align_corners=False)
        # BxCxHxW, BxHxW
        return tgt_BCHW, valid_mask

--------------------------------------------------------------------------------
/mpi/mpi_rendering.py:
--------------------------------------------------------------------------------
import torch

from mpi.homography_sampler import HomographySample
from mpi.rendering_utils import transform_G_xyz


def render(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW):
    imgs_syn, weights = plane_volume_rendering(sigma_BS1HW, rgb_BS3HW)
    return imgs_syn

def plane_volume_rendering(alpha_BK1HW,
value_BKCHW): 12 | B, K, _, H, W = alpha_BK1HW.size() 13 | alpha_comp_cumprod = torch.cumprod(1 - alpha_BK1HW, dim=1) # BxKx1xHxW 14 | preserve_ratio = torch.cat((torch.ones((B, 1, 1, H, W), dtype=alpha_BK1HW.dtype, device=alpha_BK1HW.device), 15 | alpha_comp_cumprod[:, 0:K-1, :, :, :]), dim=1) # BxKx1xHxW 16 | weights = alpha_BK1HW * preserve_ratio # BxKx1xHxW 17 | value_composed = torch.sum(value_BKCHW * weights, dim=1, keepdim=False) # BxCxHxW 18 | return value_composed, weights 19 | 20 | def get_src_xyz_from_plane_disparity(meshgrid_src_homo, 21 | mpi_disparity_src, 22 | K_src_inv): 23 | """ 24 | 25 | :param meshgrid_src_homo: 3xHxW 26 | :param mpi_disparity_src: BxS 27 | :param K_src_inv: Bx3x3 28 | :return: xyz_src_BS3HW, BxSx3xHxW 3D points of every plane in the src camera frame 29 | """ 30 | B, S = mpi_disparity_src.size() 31 | H, W = meshgrid_src_homo.size(1), meshgrid_src_homo.size(2) 32 | mpi_depth_src = torch.reciprocal(mpi_disparity_src) # BxS 33 | 34 | K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).reshape(B * S, 3, 3) 35 | 36 | # 3xHxW -> BxSx3xHxW 37 | meshgrid_src_homo = meshgrid_src_homo.unsqueeze(0).unsqueeze(1).repeat(B, S, 1, 1, 1) 38 | meshgrid_src_homo_Bs3N = meshgrid_src_homo.reshape(B * S, 3, -1) 39 | xyz_src = torch.matmul(K_src_inv_Bs33, meshgrid_src_homo_Bs3N) # BSx3xHW 40 | xyz_src = xyz_src.reshape(B, S, 3, H * W) * mpi_depth_src.unsqueeze(2).unsqueeze(3) # BxSx3xHW 41 | xyz_src_BS3HW = xyz_src.reshape(B, S, 3, H, W) 42 | 43 | return xyz_src_BS3HW 44 | 45 | 46 | def get_tgt_xyz_from_plane_disparity(xyz_src_BS3HW, 47 | G_tgt_src): 48 | """ 49 | 50 | :param xyz_src_BS3HW: BxSx3xHxW 51 | :param G_tgt_src: Bx4x4 52 | :return: xyz_tgt_BS3HW, BxSx3xHxW points transformed into the tgt camera frame 53 | """ 54 | B, S, _, H, W = xyz_src_BS3HW.size() 55 | G_tgt_src_Bs44 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).reshape(B*S, 4, 4) 56 | xyz_tgt = transform_G_xyz(G_tgt_src_Bs44, xyz_src_BS3HW.reshape(B*S, 3, H*W)) # Bsx3xHW 57 | xyz_tgt_BS3HW = xyz_tgt.reshape(B, S, 3, H, W) # BxSx3xHxW 58 | return xyz_tgt_BS3HW 59 | 60 | 61 | def render_tgt_rgb_depth(H_sampler: 
HomographySample, 62 | mpi_rgb_src, 63 | mpi_sigma_src, 64 | mpi_disparity_src, 65 | xyz_tgt_BS3HW, 66 | G_tgt_src, 67 | K_src_inv, K_tgt, 68 | only_render_in_fov=False, 69 | center_top_left=None, 70 | infov_size=512): 71 | """ 72 | :param H_sampler: HomographySample used to warp every plane into the tgt view 73 | :param mpi_rgb_src: BxSx3xHxW 74 | :param mpi_sigma_src: BxSx1xHxW 75 | :param mpi_disparity_src: BxS 76 | :param xyz_tgt_BS3HW: BxSx3xHxW 77 | :param G_tgt_src: Bx4x4 78 | :param K_src_inv: Bx3x3 79 | :param K_tgt: Bx3x3 80 | :return: tgt_rgb_syn, Bx3xHxW (Bx3xinfov_sizexinfov_size if only_render_in_fov) 81 | """ 82 | B, S, _, H, W = mpi_rgb_src.size() 83 | mpi_depth_src = torch.reciprocal(mpi_disparity_src) # BxS 84 | 85 | # note that here we concat the mpi_src with xyz_tgt, because H_sampler will sample them into the tgt frame 86 | # the rgb/sigma of mpi_src do not depend on the frame, but xyz has to be expressed in the tgt frame 87 | mpi_xyz_src = torch.cat((mpi_rgb_src, mpi_sigma_src, xyz_tgt_BS3HW), dim=2) # BxSx(3+1+3)xHxW 88 | 89 | # homography warping of mpi_src into tgt frame 90 | G_tgt_src_Bs44 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 4, 4) # Bsx4x4 91 | K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3) # Bsx3x3 92 | K_tgt_Bs33 = K_tgt.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3) # Bsx3x3 93 | 94 | # BsxCxHxW, BsxHxW 95 | tgt_mpi_xyz_BsCHW, tgt_mask_BsHW = H_sampler.sample(mpi_xyz_src.view(B*S, 7, H, W), 96 | mpi_depth_src.view(B*S), 97 | G_tgt_src_Bs44, 98 | K_src_inv_Bs33, 99 | K_tgt_Bs33) 100 | 101 | # mpi composition 102 | if only_render_in_fov: 103 | tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W) 104 | tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, center_top_left[0]:center_top_left[0] + infov_size, center_top_left[1]:center_top_left[1] + infov_size] 105 | tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, center_top_left[0]:center_top_left[0] + infov_size, center_top_left[1]:center_top_left[1] + infov_size] 106 | tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, center_top_left[0]:center_top_left[0] + infov_size, 
center_top_left[1]:center_top_left[1] + infov_size] 107 | H = W = infov_size 108 | else: 109 | tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W) 110 | tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, :, :] 111 | tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, :, :] 112 | tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, :, :] 113 | 114 | # zero out the alpha of plane points that lie behind the tgt camera (z < 0) 115 | tgt_z_BS1HW = tgt_xyz_BS3HW[:, :, -1:] 116 | tgt_sigma_BS1HW = torch.where(tgt_z_BS1HW >= 0, 117 | tgt_sigma_BS1HW, 118 | torch.zeros_like(tgt_sigma_BS1HW, device=tgt_sigma_BS1HW.device)) 119 | tgt_rgb_syn = render(tgt_rgb_BS3HW, tgt_sigma_BS1HW, tgt_xyz_BS3HW) 120 | 121 | return tgt_rgb_syn 122 | -------------------------------------------------------------------------------- /mpi/rendering_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def transform_G_xyz(G, xyz, is_return_homo=False): 5 | """ 6 | 7 | :param G: Bx4x4 8 | :param xyz: Bx3xN 9 | :return: Bx3xN transformed points (Bx4xN homogeneous if is_return_homo) 10 | """ 11 | assert len(G.size()) == len(xyz.size()) 12 | if len(G.size()) == 2: 13 | G_B44 = G.unsqueeze(0) 14 | xyz_B3N = xyz.unsqueeze(0) 15 | else: 16 | G_B44 = G 17 | xyz_B3N = xyz 18 | xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1) 19 | G_xyz_B4N = torch.matmul(G_B44, xyz_B4N) 20 | if is_return_homo: 21 | return G_xyz_B4N 22 | else: 23 | return G_xyz_B4N[:, 0:3, :] 24 | 25 | 26 | def gather_pixel_by_pxpy(img, pxpy): 27 | """ 28 | 29 | :param img: BxCxHxW 30 | :param pxpy: Bx2xN 31 | :return: BxCxN pixel values gathered at the rounded, clamped pxpy locations 32 | """ 33 | with torch.no_grad(): 34 | B, C, H, W = img.size() 35 | if pxpy.is_floating_point(): 36 | pxpy = torch.round(pxpy) 37 | pxpy_int = pxpy.to(torch.int64, copy=True) # copy so the in-place clamps below never modify the caller's tensor 38 | pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W-1) 39 | pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H-1) 40 | pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :] # Bx1xN_pt 41 | rgb = torch.gather(img.view(B, C, H * W), dim=2, 42 | 
index=pxpy_idx.repeat(1, C, 1)) # BxCxN_pt 43 | return rgb 44 | 45 | 46 | def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device): 47 | """ 48 | Disparities must be ordered from large to small, i.e., depth from small (near) to large (far) 49 | :param batch_size: B 50 | :param disparity_np: numpy array of S+1 bin edges, in decreasing order 51 | :param device: torch device of the output 52 | :return: BxS disparities, one sampled uniformly within each bin 53 | """ 54 | assert disparity_np[0] > disparity_np[-1] 55 | S = disparity_np.shape[0] - 1 56 | 57 | B = batch_size 58 | bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device) # S+1 59 | interval = bin_edges[1:] - bin_edges[0:-1] # S 60 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS 61 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS 62 | interval = interval.unsqueeze(0).repeat(B, 1) # S -> BxS 63 | 64 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS 65 | disparity_array = bin_edges_start + interval * random_float 66 | return disparity_array # BxS 67 | 68 | 69 | def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device): 70 | """ 71 | Disparities must be ordered from large to small, i.e., depth from small (near) to large (far) 72 | :param start: largest disparity (nearest plane) 73 | :param end: smallest disparity (farthest plane) 74 | :param num_bins: S 75 | :return: BxS disparities, one sampled uniformly within each bin 76 | """ 77 | assert start > end 78 | 79 | B, S = batch_size, num_bins 80 | bin_edges = torch.linspace(start, end, num_bins+1, dtype=torch.float32, device=device) # S+1 81 | interval = bin_edges[1] - bin_edges[0] # scalar 82 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS 83 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS 84 | 85 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS 86 | disparity_array = bin_edges_start + interval * random_float 87 | return disparity_array # BxS 88 | 89 | 90 | def sample_pdf(values, weights, N_samples): 91 | """ 92 | Draw samples from the distribution approximated by values and weights. 
93 | i.e., the probability distribution is given by weights = p(values) 94 | :param values: Bx1xNxS 95 | :param weights: Bx1xNxS 96 | :param N_samples: number of samples to draw for each of the N rows 97 | :return: Bx1xNxN_samples 98 | """ 99 | B, N, S = weights.size(0), weights.size(2), weights.size(3) 100 | assert values.size() == (B, 1, N, S) 101 | 102 | # convert values to bin edges 103 | bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5 # Bx1xNxS-1 104 | bin_edges = torch.cat((values[:, :, :, 0:1], 105 | bin_edges, 106 | values[:, :, :, -1:]), dim=3) # Bx1xNxS+1 107 | 108 | pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5) # Bx1xNxS 109 | cdf = torch.cumsum(pdf, dim=3) # Bx1xNxS 110 | cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device), 111 | cdf), dim=3) # Bx1xNxS+1 112 | 113 | # uniform sample over the cdf values 114 | u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device) # Bx1xNxN_samples 115 | 116 | # get the index on the cdf array 117 | cdf_idx = torch.searchsorted(cdf, u, right=True) # Bx1xNxN_samples 118 | cdf_idx_lower = torch.clamp(cdf_idx-1, min=0) # Bx1xNxN_samples 119 | cdf_idx_upper = torch.clamp(cdf_idx, max=S) # Bx1xNxN_samples 120 | 121 | # linear approximation for each bin 122 | cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3) # Bx1xNx(N_samplesx2) 123 | cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2) 124 | cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4) 125 | bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2) 126 | bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4) 127 | 128 | # avoid zero cdf_intervals 129 | cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples 130 | bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0] # Bx1xNxN_samples 131 | 
u_cdf_lower = u - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples 132 | # cdf_interval can be 0 because of the cdf_idx_lower/upper clamp above, so it needs special handling 133 | t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5) 134 | t = torch.where(cdf_intervals <= 1e-4, 135 | torch.full_like(u_cdf_lower, 0.5), 136 | t) 137 | 138 | samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals 139 | return samples 140 | -------------------------------------------------------------------------------- /outpaint_rgbd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import random 4 | import os 5 | import torchvision 6 | import torchvision.transforms as T 7 | 8 | # pytorch-lightning 9 | from pytorch_lightning import LightningModule, Trainer 10 | 11 | from dataloaders.single_img_data import SinImgDataset 12 | from diffusers import StableDiffusionInpaintPipeline 13 | from utils_for_train import tensor_to_depth 14 | from pytorch_lightning import seed_everything 15 | 16 | class SDOutpainter(LightningModule): 17 | def __init__(self, opt): 18 | super(SDOutpainter, self).__init__() 19 | 20 | self.opt = opt 21 | self.loss = [] 22 | 23 | W, H = self.opt.width, self.opt.height 24 | self.save_base_dir = f'ckpts/{opt.ckpt_path}' 25 | if not os.path.exists(self.save_base_dir): 26 | os.makedirs(self.save_base_dir) 27 | 28 | self.dataset_type = 'SinImgDataset' 29 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=1) 30 | self.extrapolate_times = self.opt.extrapolate_times 31 | 32 | if self.extrapolate_times == 3: # extend w = 3 * w 33 | self.center_top_left = (self.opt.height, self.opt.width) 34 | elif self.extrapolate_times == 2: # extend w = 2 * w 35 | self.center_top_left = (self.opt.height//2, self.opt.width//2) 36 | elif self.extrapolate_times == 1: 37 | self.center_top_left = (0, 0) 38 | 39 | with torch.no_grad(): 40 | self.sd = 
StableDiffusionInpaintPipeline.from_pretrained( 41 | "stabilityai/stable-diffusion-2-inpainting", 42 | torch_dtype=torch.float16, 43 | local_files_only=True, 44 | use_auth_token="" 45 | ).to("cuda:0") 46 | self.extrapolate_RGBDs = self.gen_extrapolate_RGBDs() 47 | torch.save(self.extrapolate_RGBDs, self.save_base_dir + "/" + "extrapolate_RGBDs.pkl") 48 | 49 | def gen_extrapolate_RGBDs(self): 50 | 51 | self.prompt = ["continuous sky, without animal, without text, without copy", # for tiles above the input 52 | "continuous scene, without text, without copy", # for tiles left/right of the input 53 | "continuous sea, without text" # for tiles below the input 54 | ] 55 | 56 | ref_img = self.train_dataset.ref_img.cpu() 57 | depth = tensor_to_depth(ref_img.cuda()).cpu() 58 | 59 | ref_depth = depth 60 | 61 | rgbd = (ref_img.cuda(), ref_depth.cuda()) 62 | 63 | _,_,h,w = ref_img.shape 64 | 65 | canvas = torch.zeros(1,3,h*self.extrapolate_times,w*self.extrapolate_times) 66 | mask = torch.zeros(1,1,h*self.extrapolate_times,w*self.extrapolate_times) 67 | 68 | if self.extrapolate_times == 3: # extend w = 3 * w 69 | top_left_points = [ 70 | (h//2,w), (0,w), #top 71 | (h + h//2,w), (h + h,w), #down 72 | (h,w//2),(h,0), #left 73 | (h, w + w//2), (h, w + w), #right 74 | 75 | (h//2,w//2), (0,w//2), (h//2,0) ,(0,0), #top left 76 | (h + h//2,w//2), (h + h//2,0), (h + h,w//2), (h + h,0), #down left 77 | (h//2,w + w//2), (0,w + w//2), (h//2,w + w), (0 ,w + w), #top right 78 | (h + h//2, w + w//2), (h + h//2, w + w), (h + h, w + w//2), (h + h, w + w), #down right 79 | ] 80 | up = [0,1,8,9,10,11,16,17,18,19] 81 | mid = [4,5,6,7] 82 | down = [2,3,12,13,14,15,20,21,22,23] 83 | elif self.extrapolate_times == 2: # extend w = 2 * w 84 | top_left_points = [ 85 | (0,w//2), #top 86 | (h,w//2), #down 87 | (h//2,0), #left 88 | (h//2,w), #right 89 | 90 | (0,0), #top left 91 | (h,0), #down left 92 | (0,w), #top right 93 | (h,w), #down right 94 | ] 95 | up = [0,4,6] 96 | mid = [2,3] 97 | down = [1,5,7] 98 | elif self.extrapolate_times == 1: 99 | top_left_points = 
[] 100 | # return rgbd 101 | 102 | canvas[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w] = ref_img 103 | mask[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w] = torch.ones(1,1,h,w) 104 | 105 | for i, point in enumerate(top_left_points): 106 | canvas_part = canvas[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] 107 | mask_part = mask[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] 108 | if i in up: 109 | prompt = self.prompt[0] 110 | elif i in mid: 111 | prompt = self.prompt[1] 112 | else: 113 | prompt = self.prompt[2] 114 | canvas[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] = self.run_sd(canvas_part, mask_part, prompt, h, w) 115 | mask[:,:,point[0]:point[0]+ h, point[1]:point[1]+ w] = torch.ones(1,1,h,w) 116 | 117 | depth = tensor_to_depth(canvas.cuda()) 118 | 119 | align_depth = True 120 | if align_depth: 121 | extrapolate_depth = depth 122 | extrapolate_center_depth = extrapolate_depth[:,:,self.center_top_left[0]:self.center_top_left[0] + h, self.center_top_left[1]:self.center_top_left[1] + w] 123 | # rescale the canvas depth so that its center crop spans the same range as ref_depth 124 | depth[:,:,:,:] = (depth - extrapolate_center_depth.min())/(extrapolate_center_depth.max() - extrapolate_center_depth.min()) * (ref_depth.max() - ref_depth.min()) + ref_depth.min() 125 | 126 | extrapolate_RGBDs = (canvas.cuda(), depth.cuda()) 127 | torchvision.utils.save_image(canvas[0], self.save_base_dir + "/" + "canvas.png") 128 | torchvision.utils.save_image(depth[0], self.save_base_dir + "/" + "canvas_depth.png") 129 | return extrapolate_RGBDs 130 | 131 | def run_sd(self, canvas, mask, prompt, w, h): 132 | # Run SD inpainting. NOTE: the caller passes (..., h, w), so `w` holds the tile height and `h` the tile width here; resize((h,w)) below therefore receives PIL's (width, height) order 133 | # prompt = "room" 134 | transform = T.ToPILImage() 135 | warp_rgb_PIL = transform(canvas[0,...]).convert("RGB").resize((512, 512)) 136 | warp_mask_PIL = transform(255 * (1 - mask[0,...].to(torch.int32))).convert("RGB").resize((512, 512)) 137 | inpainted_warp_image = 
self.sd(prompt=prompt, image=warp_rgb_PIL, mask_image=warp_mask_PIL).images[0] 138 | inpainted_warp_image = inpainted_warp_image.resize((h,w)) 139 | inpainted_warp_image = T.ToTensor()(inpainted_warp_image).unsqueeze(0) 140 | inpainted_warp_image = canvas * mask + inpainted_warp_image * (1 - mask) 141 | return inpainted_warp_image 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 146 | parser.add_argument('--img_path', type=str, default="images/0810.png") 147 | parser.add_argument('--disp_path', type=str, default="images/depth/0810.png") 148 | parser.add_argument('--width', type=int, default=384) 149 | parser.add_argument('--height', type=int, default=256) 150 | parser.add_argument('--ckpt_path', type=str, default="ExpX") 151 | parser.add_argument('--debugging', default=False, action="store_true") 152 | parser.add_argument('--extrapolate_times', type=int, default=1) 153 | 154 | opt, _ = parser.parse_known_args() 155 | 156 | seed = 50 157 | seed_everything(seed) 158 | 159 | sd_outpainter = SDOutpainter(opt) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.19.0 2 | diffusers==0.16.1 3 | easydict==1.10 4 | efficientnet-pytorch==0.7.1 5 | einops==0.6.1 6 | huggingface-hub==0.14.1 7 | imageio==2.28.1 8 | imageio-ffmpeg==0.4.8 9 | lightning-utilities==0.8.0 10 | moviepy==1.0.3 11 | numpy==1.21.0 12 | open-clip-torch==2.20.0 13 | open3d==0.17.0 14 | opencv-python==4.7.0.72 15 | Pillow==10.0.0 16 | pytorch-lightning==1.4.2 17 | pytorch3d==0.7.4 18 | timm==0.9.2 19 | torch==1.11.0 20 | torchmetrics==0.5.0 21 | torchvision==0.12.0 22 | tornado==6.2 23 | transformers==4.29.2 24 | trimesh==3.21.7 25 | -------------------------------------------------------------------------------- /scripts/train_all.sh: 
-------------------------------------------------------------------------------- 1 | #conda activate SinMPI 2 | cuda=0 3 | width=512 4 | height=512 5 | data_name='Syndney' 6 | extrapolate_times=2 7 | 8 | ckpt_path='Exp-Syndney' 9 | img_path=$data_name'.jpg' 10 | 11 | CUDA_VISIBLE_DEVICES=$cuda python outpaint_rgbd.py \ 12 | --width $width \ 13 | --height $height \ 14 | --ckpt_path $ckpt_path \ 15 | --img_path $img_path \ 16 | --extrapolate_times $extrapolate_times 17 | 18 | CUDA_VISIBLE_DEVICES=$cuda python train_inpainting.py \ 19 | --width $width \ 20 | --height $height \ 21 | --ckpt_path $ckpt_path \ 22 | --img_path $img_path \ 23 | --num_epochs 10 \ 24 | --extrapolate_times $extrapolate_times \ 25 | --batch_size 1 #--load_warp_pairs --debugging 26 | 27 | CUDA_VISIBLE_DEVICES=$cuda python train_mpi.py \ 28 | --width $width \ 29 | --height $height \ 30 | --ckpt_path $ckpt_path \ 31 | --img_path $img_path \ 32 | --num_epochs 10 \ 33 | --extrapolate_times $extrapolate_times \ 34 | --batch_size 1 #--debugging #--resume -------------------------------------------------------------------------------- /test_images/Syndney.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/SinMPI/e14a8232ea1c0e6b6680a49a89cc0ce6f6e2409f/test_images/Syndney.jpg -------------------------------------------------------------------------------- /train_inpainting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | import random 6 | 7 | from pytorch_lightning import LightningModule, Trainer 8 | from torch.utils.data import DataLoader 9 | 10 | from dataloaders.single_img_data import SinImgDataset 11 | from model.Inpainter import InpaintingModule 12 | from torchvision.utils import save_image 13 | from utils_for_train import VGGPerceptualLoss 14 | from model.VitExtractor import VitExtractor 15 | 
from utils_for_train import tensor_to_depth 16 | from warpback.utils import ( 17 | RGBDRenderer, 18 | image_to_tensor, 19 | transformation_from_parameters, 20 | ) 21 | from pytorch_lightning import seed_everything 22 | 23 | class TrainInpaintingModule(LightningModule): 24 | def __init__(self, opt): 25 | super(TrainInpaintingModule, self).__init__() 26 | 27 | self.opt = opt 28 | self.loss = [] 29 | 30 | W, H = self.opt.width, self.opt.height 31 | self.save_base_dir = f'ckpts/{opt.ckpt_path}' 32 | 33 | if self.opt.resume: 34 | self.inpaint_module = InpaintingModule() 35 | self.inpaint_module.load_state_dict(torch.load(f'ckpts/{self.opt.ckpt_path}/inpaint_latest.pt'), strict=True) 36 | else: 37 | self.inpaint_module = InpaintingModule() 38 | 39 | self.models = [self.inpaint_module.cuda()] 40 | 41 | # for training 42 | self.extrapolate_times = self.opt.extrapolate_times 43 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=1) 44 | 45 | if self.extrapolate_times == 3: # extend w = 3 * w 46 | self.center_top_left = (self.opt.height, self.opt.width) 47 | elif self.extrapolate_times == 2: # extend w = 2 * w 48 | self.center_top_left = (self.opt.height//2, self.opt.width//2) 49 | elif self.extrapolate_times == 1: 50 | self.center_top_left = (0, 0) 51 | 52 | self.K = torch.tensor([ 53 | [0.58, 0, 0.5], 54 | [0, 0.58, 0.5], 55 | [0, 0, 1] 56 | ]) 57 | 58 | with torch.no_grad(): 59 | if self.extrapolate_times == 1: 60 | ref_img = image_to_tensor(self.save_base_dir + "/" + "canvas.png", unsqueeze=False) # [3,h,w] 61 | ref_img = ref_img.unsqueeze(0).cuda() 62 | if ref_img.shape[1] == 4: 63 | ref_img = ref_img[:,:3,:,:] 64 | 65 | ref_depth = tensor_to_depth(ref_img) 66 | save_image(ref_depth[0,0,...], self.save_base_dir + "/" + "canvas_depth.png") 67 | else: 68 | ref_img, ref_depth = torch.load(self.save_base_dir + "/" + "extrapolate_RGBDs.pkl") 69 | 70 | ref_depth = (ref_depth - ref_depth.min())/(ref_depth.max() - ref_depth.min()) 71 
| 72 | self.extrapolate_RGBDs = (ref_img.cpu(), ref_depth.cpu()) 73 | 74 | if self.opt.load_warp_pairs: 75 | self.inpaint_pairs = torch.load(self.save_base_dir + "/" + "inpaint_pairs.pkl") 76 | else: 77 | self.renderer = RGBDRenderer('cuda:0') 78 | self.inpaint_pairs = self.get_pairs() 79 | torch.save(self.inpaint_pairs,self.save_base_dir + "/" + "inpaint_pairs.pkl") 80 | 81 | self.perceptual_loss = VGGPerceptualLoss() 82 | self.VitExtractor = VitExtractor( 83 | model_name='dino_vits16', device='cuda:0') 84 | 85 | self.renderer_pair_saved = False 86 | 87 | 88 | def configure_optimizers(self): 89 | from torch.optim import SGD, Adam 90 | 91 | parameters = [] 92 | for model in self.models: 93 | parameters += list(model.parameters()) 94 | self.optimizer = Adam(parameters, lr=5e-4, eps=1e-8, weight_decay=0) 95 | 96 | return [self.optimizer], [] 97 | 98 | 99 | def train_dataloader(self): 100 | return DataLoader(self.train_dataset, 101 | shuffle=True, 102 | num_workers=8, 103 | batch_size=self.opt.batch_size, 104 | pin_memory=True) 105 | 106 | 107 | def get_rand_ext(self, bs=1): 108 | def rand_tensor(r, l): 109 | if r < 0: 110 | return torch.zeros((l, 1, 1)) 111 | rand = torch.rand((l, 1, 1)) 112 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1 113 | return sign * (r / 2 + r / 2 * rand) 114 | 115 | trans_range={"x":0.2, "y":-0.2, "z":-0.2, "a":-0.2, "b":-0.2, "c":-0.2} 116 | x, y, z = trans_range['x'], trans_range['y'], trans_range['z'] 117 | a, b, c = trans_range['a'], trans_range['b'], trans_range['c'] 118 | cix = rand_tensor(x, bs) 119 | ciy = rand_tensor(y, bs) 120 | ciz = rand_tensor(z, bs) 121 | aix = rand_tensor(math.pi / a, bs) 122 | aiy = rand_tensor(math.pi / b, bs) 123 | aiz = rand_tensor(math.pi / c, bs) 124 | 125 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3] 126 | translation = torch.cat([cix, ciy, ciz], dim=-1) 127 | 128 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4] 129 | cam_ext_inv = torch.inverse(cam_ext) # 
[b,4,4] 130 | return cam_ext, cam_ext_inv 131 | 132 | def get_pairs(self): 133 | all_poses = self.train_dataset.all_poses 134 | 135 | aug_pose_factor = 0 # set > 0 to add randomly perturbed copies of each pose (pose augmentation) 136 | cnt = len(all_poses) 137 | if aug_pose_factor > 0: 138 | for i in range(cnt): 139 | cur_pose = torch.FloatTensor(all_poses[i]) 140 | for _ in range(aug_pose_factor): 141 | cam_ext, cam_ext_inv = self.get_rand_ext() # [b,4,4] 142 | cur_aug_pose = torch.matmul(cam_ext, cur_pose) 143 | all_poses += [cur_aug_pose] 144 | 145 | ref_depth = self.extrapolate_RGBDs[1] 146 | 147 | ref_img = self.extrapolate_RGBDs[0] 148 | W, H = self.opt.width * self.extrapolate_times, self.opt.height * self.extrapolate_times 149 | 150 | inpaint_pairs = [] #(ref_img, ref_depth, cur_pose, warp_image, warp_disp, warp_mask, warp_back_image, warp_back_disp, warp_back_mask) 151 | val_pairs = [] #(cam_ext, ref_img, warp_image, warp_disp, warp_mask, gt_img) 152 | 153 | print("all_poses len:",len(all_poses)) 154 | for i, cur_pose in enumerate(all_poses[:]): 155 | cur_pose = all_poses[i] 156 | c2w = cur_pose 157 | c2w = torch.FloatTensor(c2w) 158 | 159 | cam_int = self.K.repeat(1, 1, 1) # [b,3,3] 160 | 161 | #load cam_ext 162 | cam_ext = c2w 163 | cam_ext_inv = torch.inverse(cam_ext) 164 | cam_ext = cam_ext.repeat(1, 1, 1)[:,:-1,:] 165 | cam_ext_inv = cam_ext_inv.repeat(1, 1, 1)[:,:-1,:] 166 | 167 | rgbd = torch.cat([ref_img, ref_depth], dim=1).cuda() 168 | cam_int = cam_int.cuda() 169 | cam_ext = cam_ext.cuda() 170 | cam_ext_inv = cam_ext_inv.cuda() 171 | 172 | # warp the reference RGBD to the novel view given by cur_pose 173 | mesh = self.renderer.construct_mesh(rgbd, cam_int) 174 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext) 175 | 176 | # warp back to the original view 177 | warp_rgbd = torch.cat([warp_image, warp_disp], dim=1) # [b,4,h,w] 178 | warp_mesh = self.renderer.construct_mesh(warp_rgbd, cam_int) 179 | warp_back_image, warp_back_disp, warp_back_mask = self.renderer.render_mesh(warp_mesh, cam_int, cam_ext_inv) 180 | 181 | 
ref_depth_2 = ref_depth 182 | # all depth should be in [0~1] 183 | inpaint_pairs.append((ref_img, ref_depth_2, cur_pose, 184 | warp_image, warp_disp, warp_mask, 185 | warp_back_image, warp_back_disp, warp_back_mask)) 186 | print("collecting inpaint_pairs:", len(inpaint_pairs)) 187 | 188 | return inpaint_pairs 189 | 190 | 191 | def forward(self, renderer_pair): 192 | (ref_img, ref_depth, cur_pose, 193 | warp_rgb, warp_disp, warp_mask, 194 | warp_back_image, warp_back_disp, warp_back_mask) = renderer_pair 195 | 196 | inpainted_warp_image, inpainted_warp_disp = self.inpaint_module(warp_rgb.cuda(), warp_disp.cuda(), warp_mask.cuda()) 197 | inpainted_warp_back_image, inpainted_warp_back_disp = self.inpaint_module(warp_back_image.cuda(), warp_back_disp.cuda(), warp_back_mask.cuda()) 198 | 199 | return { 200 | "ref_img": ref_img, 201 | "ref_depth": ref_depth, 202 | "warp_image": warp_rgb, 203 | "warp_disp": warp_disp, 204 | "inpainted_warp_image":inpainted_warp_image, 205 | "inpainted_warp_disp":inpainted_warp_disp, 206 | "warp_back_image": warp_back_image, 207 | "warp_back_disp": warp_back_disp, 208 | "inpainted_warp_back_image":inpainted_warp_back_image, 209 | "inpainted_warp_back_disp":inpainted_warp_back_disp, 210 | } 211 | 212 | 213 | def training_step(self, batch, batch_idx, optimizer_idx=0): 214 | renderer_pair = random.choice(self.inpaint_pairs) 215 | batch = self(renderer_pair) 216 | 217 | ref_img = batch["ref_img"].cuda() 218 | ref_depth = batch["ref_depth"].cuda() 219 | 220 | warp_image = batch["warp_image"] 221 | warp_disp = batch["warp_disp"] 222 | 223 | inpainted_warp_image = batch["inpainted_warp_image"] 224 | inpainted_warp_disp = batch["inpainted_warp_disp"] 225 | 226 | warp_back_image = batch["warp_back_image"] 227 | warp_back_disp = batch["warp_back_disp"] 228 | 229 | inpainted_warp_back_image = batch["inpainted_warp_back_image"] 230 | inpainted_warp_back_disp = batch["inpainted_warp_back_disp"] 231 | 232 | # Losses 233 | loss_total = 0 234 | loss_L1 = 
0 235 | lambda_loss_L1 = 10 236 | 237 | loss_perc = 0 238 | lambda_loss_perc = 5 239 | 240 | loss_L1 += F.l1_loss(ref_img, inpainted_warp_back_image) 241 | loss_total += loss_L1 * lambda_loss_L1 242 | 243 | loss_perc += self.perceptual_loss(inpainted_warp_image, ref_img) + self.perceptual_loss(inpainted_warp_back_image, ref_img) 244 | loss_total += loss_perc * lambda_loss_perc 245 | 246 | loss_inpainted_vit = 0 247 | lambda_loss_vit = 0 # ViT feature loss disabled; set > 0 (train_mpi.py uses 1e-1) to enable it 248 | ref_vit_feature = self.get_vit_feature(ref_img) 249 | inpainted_vit_feature = self.get_vit_feature(inpainted_warp_image) 250 | inpainted_warp_back_vit_feature = self.get_vit_feature(inpainted_warp_back_image) 251 | loss_inpainted_vit += F.mse_loss(inpainted_vit_feature, ref_vit_feature) + F.mse_loss(inpainted_warp_back_vit_feature, ref_vit_feature) 252 | loss_total += loss_inpainted_vit * lambda_loss_vit 253 | 254 | loss_depth = 0 255 | lambda_loss_depth = 1 256 | loss_depth += F.l1_loss(ref_depth, inpainted_warp_back_disp) 257 | loss_total += lambda_loss_depth * loss_depth 258 | 259 | if self.opt.debugging: 260 | self.training_epoch_end(None) 261 | assert 0 262 | 263 | return {'loss': loss_total} 264 | 265 | 266 | def training_epoch_end(self, outputs): 267 | with torch.no_grad(): 268 | 269 | # pred_frames = [] 270 | 271 | self.renderer_pairs = [] 272 | for i, inpaint_pair in enumerate(self.inpaint_pairs): 273 | (ref_img, ref_depth, cur_pose, 274 | warp_image, warp_disp, warp_mask, 275 | warp_back_image, warp_back_disp, warp_back_mask) = inpaint_pair 276 | inpainted_warp_image, inpainted_warp_disp = self.inpaint_module(warp_image.cuda(), warp_disp.cuda(), warp_mask.cuda()) 277 | 278 | self.renderer_pairs += [(cur_pose, 279 | ref_img, 280 | inpainted_warp_image)] 281 | 282 | torch.save(self.inpaint_module.state_dict(), self.save_base_dir + "/" +"inpaint_latest.pt") 283 | if not self.renderer_pair_saved: 284 | torch.save(self.renderer_pairs, self.save_base_dir + "/" + "renderer_pairs.pkl") 285 | if self.opt.debugging: 286 | 
assert 0, "no bug" 287 | return 288 | 289 | def get_vit_feature(self, x): 290 | mean = torch.tensor([0.485, 0.456, 0.406], 291 | device=x.device).reshape(1, 3, 1, 1) 292 | std = torch.tensor([0.229, 0.224, 0.225], 293 | device=x.device).reshape(1, 3, 1, 1) 294 | x = F.interpolate(x, size=(224, 224)) 295 | x = (x - mean) / std 296 | return self.VitExtractor.get_feature_from_input(x)[-1][0, 0, :] 297 | 298 | 299 | if __name__ == '__main__': 300 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 301 | parser.add_argument('--img_path', type=str, default="test_images/Syndney.jpg") 302 | parser.add_argument('--width', type=int, default=512) 303 | parser.add_argument('--height', type=int, default=512) 304 | parser.add_argument('--ckpt_path', type=str, default="Exp-X") 305 | parser.add_argument('--num_epochs', type=int, default=1) 306 | parser.add_argument('--resume_path', type=str, default=None) 307 | parser.add_argument('--resume', default=False, action="store_true") 308 | parser.add_argument('--batch_size', type=int, default=1) 309 | parser.add_argument('--debugging', default=False, action="store_true") 310 | parser.add_argument('--extrapolate_times', type=int, default=1) 311 | parser.add_argument('--load_warp_pairs', default=False, action="store_true") 312 | 313 | opt, _ = parser.parse_known_args() 314 | 315 | seed = 50 316 | seed_everything(seed) 317 | 318 | system = TrainInpaintingModule(opt) 319 | 320 | trainer = Trainer(max_epochs=opt.num_epochs, 321 | progress_bar_refresh_rate=1, 322 | gpus=1, 323 | num_sanity_val_steps=1) 324 | 325 | trainer.fit(system) -------------------------------------------------------------------------------- /train_mpi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | from moviepy.editor import ImageSequenceClip 6 | from utils_for_train import VGGPerceptualLoss 7 | from 
dataloaders.single_img_data import SinImgDataset 8 | import numpy as np 9 | from pytorch_lightning import LightningModule, Trainer 10 | from torch.utils.data import DataLoader 11 | from model.MPF import MultiPlaneField 12 | from model.VitExtractor import VitExtractor 13 | from mpi.homography_sampler import HomographySample 14 | from mpi import mpi_rendering 15 | from model.TrainableFilter import TrainableFilter 16 | from pytorch_lightning import seed_everything 17 | 18 | class System(LightningModule): 19 | def __init__(self, opt): 20 | super(System, self).__init__() 21 | self.save_base_dir = f'ckpts/{opt.ckpt_path}' 22 | self.opt = opt 23 | self.loss = [] 24 | self.models = [] 25 | W, H = self.opt.width, self.opt.height 26 | 27 | self.extrapolate_times = self.opt.extrapolate_times 28 | self.train_dataset = SinImgDataset(img_path=self.opt.img_path, width=W, height=H, repeat_times=10) 29 | 30 | W, H = self.opt.width * self.extrapolate_times, self.opt.height * self.extrapolate_times 31 | 32 | self.K = torch.tensor([ 33 | [0.58, 0, 0.5], 34 | [0, 0.58, 0.5], 35 | [0, 0, 1] 36 | ]) 37 | self.K[0, :] *= W 38 | self.K[1, :] *= H 39 | self.K = self.K.unsqueeze(0) 40 | 41 | if self.extrapolate_times == 3: # extend w = 3 * w 42 | self.center_top_left = (self.opt.height, self.opt.width) 43 | elif self.extrapolate_times == 2: # extend w = 2 * w 44 | self.center_top_left = (self.opt.height//2, self.opt.width//2) 45 | elif self.extrapolate_times == 1: 46 | self.center_top_left = (0, 0) 47 | 48 | with torch.no_grad(): 49 | self.extrapolate_RGBDs = torch.load(self.save_base_dir + "/" + "extrapolate_RGBDs.pkl") 50 | img, depth = self.extrapolate_RGBDs 51 | depth = (depth - depth.min())/(depth.max() - depth.min()) 52 | self.extrapolate_RGBDs = (img.cuda(), depth.cuda()) 53 | 54 | # create MPI 55 | self.num_planes = 64 56 | self.MPF = MultiPlaneField(num_planes=self.num_planes, 57 | image_size=(H, W), 58 | assign_origin_planes=self.extrapolate_RGBDs, 59 | 
depth_range=[self.extrapolate_RGBDs[1].min()+1e-6, self.extrapolate_RGBDs[1].max()]) 60 | self.trainable_filter = TrainableFilter(ksize=3).cuda() 61 | 62 | if self.opt.resume: 63 | self.MPF.load_state_dict(torch.load(f'ckpts/{self.opt.ckpt_path}/MPF_latest.pt'), strict=True) 64 | 65 | self.renderer_pairs = torch.load(self.save_base_dir + "/" + "renderer_pairs.pkl") 66 | 67 | self.models += [self.MPF] 68 | self.models += [self.trainable_filter] 69 | 70 | self.perceptual_loss = VGGPerceptualLoss() 71 | self.VitExtractor = VitExtractor( 72 | model_name='dino_vits16', device='cuda:0') 73 | 74 | def train_dataloader(self): 75 | return DataLoader(self.train_dataset, 76 | shuffle=True, 77 | num_workers=4, 78 | batch_size=self.opt.batch_size, 79 | pin_memory=True) 80 | 81 | def configure_optimizers(self): 82 | from torch.optim import SGD, Adam 83 | from torch.optim.lr_scheduler import MultiStepLR 84 | 85 | parameters = [] 86 | for model in self.models: 87 | parameters += list(model.parameters()) 88 | 89 | self.optimizer = Adam(parameters, lr=5e-4, eps=1e-8, 90 | weight_decay=0) 91 | 92 | scheduler = MultiStepLR(self.optimizer, milestones=[20], 93 | gamma=0.1) 94 | 95 | return [self.optimizer], [scheduler] 96 | 97 | 98 | def training_step(self, batch, batch_idx, optimizer_idx=0): 99 | 100 | renderer_pair = random.choice(self.renderer_pairs) 101 | (cam_ext, ref_img, inpainted_warp_image) = renderer_pair 102 | cam_ext, ref_img, inpainted_warp_image= cam_ext.cuda(), ref_img.cuda(), inpainted_warp_image.cuda() 103 | # Losses 104 | loss = 0 105 | ref_vit_feature = self.get_vit_feature(ref_img) 106 | 107 | # run MPI forward 108 | frames_tensor = self(cam_ext) 109 | 110 | # mpi losses 111 | loss_L1 = F.l1_loss(frames_tensor, inpainted_warp_image) 112 | 113 | lambda_loss_L1 = 10 114 | loss += loss_L1 * lambda_loss_L1 115 | 116 | lambda_loss_vit = 1e-1 117 | frames_vit_feature = self.get_vit_feature(frames_tensor) 118 | loss_vit = F.mse_loss(frames_vit_feature, ref_vit_feature) 119 
| loss += loss_vit * lambda_loss_vit 120 | 121 | 122 | if self.opt.debugging: 123 | # loss *= 0 124 | self.training_epoch_end(None) 125 | assert 0 126 | 127 | return {'loss': loss} 128 | 129 | 130 | def forward(self, cam_ext, only_render_in_fov=False): 131 | # mpi_planes[b,s,4,h,w], mpi_disp[b,s] 132 | mpi = self.MPF() 133 | mpi_all_rgb_src = mpi[:, :, 0:3, :, :] 134 | mpi_all_sigma_src = mpi[:, :, 3:4, :, :] 135 | disparity_all_src = self.MPF.planes_disp 136 | 137 | k_tgt = k_src = self.K 138 | k_src_inv = torch.inverse(k_src).cuda() 139 | k_tgt = k_tgt.cuda() 140 | k_src = k_src.cuda() 141 | h, w = mpi.shape[-2:] 142 | homography_sampler = HomographySample(h, w, "cuda:0") 143 | G_tgt_src = cam_ext.cuda() 144 | 145 | xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity( 146 | homography_sampler.meshgrid, 147 | disparity_all_src, 148 | k_src_inv 149 | ) 150 | 151 | xyz_tgt_BS3HW = mpi_rendering.get_tgt_xyz_from_plane_disparity( 152 | xyz_src_BS3HW, 153 | G_tgt_src 154 | ) 155 | 156 | tgt_imgs_syn = mpi_rendering.render_tgt_rgb_depth( 157 | homography_sampler, 158 | mpi_all_rgb_src, 159 | mpi_all_sigma_src, 160 | disparity_all_src, 161 | xyz_tgt_BS3HW, 162 | G_tgt_src, 163 | k_src_inv, 164 | k_tgt, 165 | only_render_in_fov=only_render_in_fov, 166 | center_top_left=self.center_top_left 167 | ) 168 | 169 | tgt_imgs_syn = self.trainable_filter(tgt_imgs_syn) 170 | 171 | return tgt_imgs_syn 172 | 173 | def training_epoch_end(self, outputs): 174 | with torch.no_grad(): 175 | 176 | pred_frames = [] 177 | for i, renderer_pair in enumerate(self.renderer_pairs): 178 | (cam_ext, ref_img, inpainted_warp_image) = renderer_pair 179 | # run MPI forward 180 | frames_tensor = self(cam_ext, only_render_in_fov=True) 181 | cam_ext, ref_img, inpainted_warp_image = cam_ext.cuda(), ref_img.cuda(), inpainted_warp_image.cuda() 182 | 183 | pred_frame_np = frames_tensor.squeeze(0).permute(1, 2, 0).contiguous().detach().cpu().numpy() # [b,h,w,3] 184 | pred_frame_np = 
np.clip(np.round(pred_frame_np * 255), a_min=0, a_max=255).astype(np.uint8) 185 | pred_frames += [pred_frame_np] 186 | 187 | rgb_clip = ImageSequenceClip(pred_frames, fps=10) 188 | save_path = f'ckpts/{self.opt.ckpt_path}/MPI_rendered_views.mp4' 189 | rgb_clip.write_videofile(save_path, verbose=False, codec='mpeg4', logger=None, bitrate='2000k') 190 | 191 | save_base_dir = f'ckpts/{self.opt.ckpt_path}' 192 | 193 | torch.save(self.MPF.state_dict(), save_base_dir + "/" +"MPF_latest.pt") 194 | 195 | if self.opt.debugging: 196 | assert 0, "no bug" 197 | return 198 | 199 | def get_vit_feature(self, x): 200 | mean = torch.tensor([0.485, 0.456, 0.406], 201 | device=x.device).reshape(1, 3, 1, 1) 202 | std = torch.tensor([0.229, 0.224, 0.225], 203 | device=x.device).reshape(1, 3, 1, 1) 204 | x = F.interpolate(x, size=(224, 224)) 205 | x = (x - mean) / std 206 | return self.VitExtractor.get_feature_from_input(x)[-1][0, 0, :] 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 211 | parser.add_argument('--img_path', type=str, default="test_images/Syndney.jpg") 212 | parser.add_argument('--width', type=int, default=512) 213 | parser.add_argument('--height', type=int, default=512) 214 | parser.add_argument('--ckpt_path', type=str, default="Exp-X") 215 | parser.add_argument('--num_epochs', type=int, default=1) 216 | parser.add_argument('--resume_path', type=str, default=None) 217 | parser.add_argument('--batch_size', type=int, default=1) 218 | parser.add_argument('--debugging', default=False, action="store_true") 219 | parser.add_argument('--resume', default=False, action="store_true") 220 | parser.add_argument('--extrapolate_times', type=int, default=1) 221 | opt, _ = parser.parse_known_args() 222 | 223 | seed = 50 224 | seed_everything(seed) 225 | 226 | system = System(opt) 227 | 228 | trainer = Trainer(max_epochs=opt.num_epochs, 229 | resume_from_checkpoint=opt.resume_path, 230 | 
progress_bar_refresh_rate=1, 231 | gpus=1, 232 | num_sanity_val_steps=1) 233 | 234 | trainer.fit(system) -------------------------------------------------------------------------------- /utils_for_train.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps 2 | import cv2 3 | import torch 4 | import torchvision 5 | from torchvision import transforms 6 | 7 | class VGGPerceptualLoss(torch.nn.Module): 8 | def __init__(self, resize=True): 9 | super(VGGPerceptualLoss, self).__init__() 10 | blocks = [] 11 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 12 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 13 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 14 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 15 | for bl in blocks: 16 | for p in bl.parameters(): # freeze VGG weights (iterating `bl` itself yields modules, not parameters) 17 | p.requires_grad = False 18 | self.blocks = torch.nn.ModuleList(blocks) 19 | self.transform = torch.nn.functional.interpolate 20 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 21 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 22 | self.mean.requires_grad = False 23 | self.std.requires_grad = False 24 | self.resize = resize 25 | 26 | def forward(self, syn_imgs, gt_imgs): 27 | syn_imgs = (syn_imgs - self.mean) / self.std 28 | gt_imgs = (gt_imgs - self.mean) / self.std 29 | if self.resize: 30 | syn_imgs = self.transform(syn_imgs, mode="bilinear", size=(224, 224), 31 | align_corners=False) 32 | gt_imgs = self.transform(gt_imgs, mode="bilinear", size=(224, 224), 33 | align_corners=False) 34 | 35 | loss = 0.0 36 | x = syn_imgs 37 | y = gt_imgs 38 | for block in self.blocks: 39 | x = block(x) # keep the graph so gradients reach syn_imgs 40 | with torch.no_grad(): # target features need no gradient 41 | y = block(y) 42 | loss += 
torch.nn.functional.l1_loss(x, y) 43 | return loss 44 | 45 | 46 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 47 | value =
(image_pred-image_gt)**2 48 | if valid_mask is not None: 49 | value = value[valid_mask] 50 | if reduction == 'mean': 51 | return torch.mean(value) 52 | return value 53 | 54 | def image_to_tensor(img_path, unsqueeze=True): 55 | im = Image.open(img_path).convert('RGB') 56 | if img_path[-3:] == 'jpg': 57 | im = ImageOps.exif_transpose(im) 58 | rgb = transforms.ToTensor()(im) 59 | if unsqueeze: 60 | rgb = rgb.unsqueeze(0) 61 | return rgb 62 | 63 | def disparity_to_tensor(disp_path, unsqueeze=True): 64 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1) 65 | disp = torch.from_numpy(disp)[None, ...] 66 | if unsqueeze: 67 | disp = disp.unsqueeze(0) 68 | return disp.float() 69 | 70 | def tensor_to_depth(tensor): # BCHW 71 | model_type = "DPT_Large" 72 | midas_model = torch.hub.load("/home/pug/.cache/torch/hub/MiDaS", model_type, source='local').cuda() 73 | 74 | midas_transforms = torch.hub.load("/home/pug/.cache/torch/hub/MiDaS", "transforms", source='local') 75 | if model_type == "DPT_Large" or model_type == "DPT_Hybrid": 76 | transform = midas_transforms.dpt_transform 77 | else: 78 | transform = midas_transforms.small_transform 79 | 80 | input_batch = tensor 81 | with torch.no_grad(): 82 | prediction = midas_model(input_batch) 83 | output = prediction 84 | output = (output - output.min()) / (output.max() - output.min()) 85 | return output.unsqueeze(0) 86 | -------------------------------------------------------------------------------- /warpback/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | def get_edge_connect(weight_dir): 6 | inpaint_model = InpaintGenerator() 7 | inpaint_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_gen.pth")) 8 | inpaint_model.load_state_dict(inpaint_model_weight["generator"]) 9 | inpaint_model.eval() 10 | 11 | edge_model = EdgeGenerator() 12 | edge_model_weight = torch.load(os.path.join(weight_dir, 
"EdgeModel_gen.pth")) 13 | edge_model.load_state_dict(edge_model_weight["generator"]) 14 | edge_model.eval() 15 | 16 | disp_model = InpaintGenerator(in_channels=2, out_channels=1) 17 | disp_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_disp.pth")) 18 | disp_model.load_state_dict(disp_model_weight["generator"]) 19 | disp_model.eval() 20 | return edge_model, inpaint_model, disp_model 21 | 22 | class BaseNetwork(nn.Module): 23 | def __init__(self): 24 | super(BaseNetwork, self).__init__() 25 | 26 | def init_weights(self, init_type='normal', gain=0.02): 27 | ''' 28 | initialize network's weights 29 | init_type: normal | xavier | kaiming | orthogonal 30 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 31 | ''' 32 | 33 | def init_func(m): 34 | classname = m.__class__.__name__ 35 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 36 | if init_type == 'normal': 37 | nn.init.normal_(m.weight.data, 0.0, gain) 38 | elif init_type == 'xavier': 39 | nn.init.xavier_normal_(m.weight.data, gain=gain) 40 | elif init_type == 'kaiming': 41 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 42 | elif init_type == 'orthogonal': 43 | nn.init.orthogonal_(m.weight.data, gain=gain) 44 | 45 | if hasattr(m, 'bias') and m.bias is not None: 46 | nn.init.constant_(m.bias.data, 0.0) 47 | 48 | elif classname.find('BatchNorm2d') != -1: 49 | nn.init.normal_(m.weight.data, 1.0, gain) 50 | nn.init.constant_(m.bias.data, 0.0) 51 | 52 | self.apply(init_func) 53 | 54 | 55 | class InpaintGenerator(BaseNetwork): 56 | def __init__(self, residual_blocks=8, init_weights=True, in_channels=4, out_channels=3): 57 | super(InpaintGenerator, self).__init__() 58 | 59 | self.encoder = nn.Sequential( 60 | nn.ReflectionPad2d(3), 61 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), 62 | nn.InstanceNorm2d(64, track_running_stats=False), 
63 | nn.ReLU(True), 64 | 65 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(128, track_running_stats=False), 67 | nn.ReLU(True), 68 | 69 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), 70 | nn.InstanceNorm2d(256, track_running_stats=False), 71 | nn.ReLU(True) 72 | ) 73 | 74 | blocks = [] 75 | for _ in range(residual_blocks): 76 | block = ResnetBlock(256, 2) 77 | blocks.append(block) 78 | 79 | self.middle = nn.Sequential(*blocks) 80 | 81 | self.decoder = nn.Sequential( 82 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), 83 | nn.InstanceNorm2d(128, track_running_stats=False), 84 | nn.ReLU(True), 85 | 86 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), 87 | nn.InstanceNorm2d(64, track_running_stats=False), 88 | nn.ReLU(True), 89 | 90 | nn.ReflectionPad2d(3), 91 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0), 92 | ) 93 | 94 | if init_weights: 95 | self.init_weights() 96 | 97 | def forward(self, x): 98 | x = self.encoder(x) 99 | x = self.middle(x) 100 | x = self.decoder(x) 101 | x = (torch.tanh(x) + 1) / 2 102 | 103 | return x 104 | 105 | 106 | class EdgeGenerator(BaseNetwork): 107 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True): 108 | super(EdgeGenerator, self).__init__() 109 | 110 | self.encoder = nn.Sequential( 111 | nn.ReflectionPad2d(3), 112 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm), 113 | nn.InstanceNorm2d(64, track_running_stats=False), 114 | nn.ReLU(True), 115 | 116 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 117 | nn.InstanceNorm2d(128, track_running_stats=False), 118 | nn.ReLU(True), 119 | 120 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, 
padding=1), use_spectral_norm), 121 | nn.InstanceNorm2d(256, track_running_stats=False), 122 | nn.ReLU(True) 123 | ) 124 | 125 | blocks = [] 126 | for _ in range(residual_blocks): 127 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm) 128 | blocks.append(block) 129 | 130 | self.middle = nn.Sequential(*blocks) 131 | 132 | self.decoder = nn.Sequential( 133 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 134 | nn.InstanceNorm2d(128, track_running_stats=False), 135 | nn.ReLU(True), 136 | 137 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm), 138 | nn.InstanceNorm2d(64, track_running_stats=False), 139 | nn.ReLU(True), 140 | 141 | nn.ReflectionPad2d(3), 142 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0), 143 | ) 144 | 145 | if init_weights: 146 | self.init_weights() 147 | 148 | def forward(self, x): 149 | x = self.encoder(x) 150 | x = self.middle(x) 151 | x = self.decoder(x) 152 | x = torch.sigmoid(x) 153 | return x 154 | 155 | 156 | class ResnetBlock(nn.Module): 157 | def __init__(self, dim, dilation=1, use_spectral_norm=False): 158 | super(ResnetBlock, self).__init__() 159 | self.conv_block = nn.Sequential( 160 | nn.ReflectionPad2d(dilation), 161 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 162 | nn.InstanceNorm2d(dim, track_running_stats=False), 163 | nn.ReLU(True), 164 | 165 | nn.ReflectionPad2d(1), 166 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 167 | nn.InstanceNorm2d(dim, track_running_stats=False), 168 | ) 169 | 170 | def forward(self, x): 171 | out = x + self.conv_block(x) 172 | return out 173 | 174 | 175 | def spectral_norm(module, mode=True): 176 | if mode: 177 
| return nn.utils.spectral_norm(module) 178 | 179 | return module -------------------------------------------------------------------------------- /warpback/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import torch 4 | import torch.nn.functional as F 5 | from torchvision.utils import save_image 6 | from torchvision import transforms 7 | from pytorch3d.renderer.mesh import rasterize_meshes 8 | from pytorch3d.structures import Meshes 9 | from pytorch3d.ops import interpolate_face_attributes 10 | 11 | 12 | class RGBDRenderer: 13 | def __init__(self, device): 14 | self.device = device 15 | self.eps = 1e-4 16 | self.near_z = 1e-4 17 | self.far_z = 1e4 18 | 19 | def render_mesh(self, mesh_dict, cam_int, cam_ext): 20 | ''' 21 | input: 22 | mesh: the output for construct_mesh function 23 | cam_int: [b,3,3] normalized camera intrinsic matrix 24 | cam_ext: [b,3,4] camera extrinsic matrix with the same scale as depth map 25 | camera coord: x to right, z to front, y to down 26 | 27 | output: 28 | render: [b,3,h,w] 29 | disparity: [b,1,h,w] 30 | ''' 31 | vertice = mesh_dict["vertice"] # [b,h*w,3] 32 | faces = mesh_dict["faces"] # [b,nface,3] 33 | attributes = mesh_dict["attributes"] # [b,h*w,4] 34 | h, w = mesh_dict["size"] 35 | 36 | ############ 37 | # to NDC space 38 | vertice_homo = self.lift_to_homo(vertice) # [b,h*w,4] 39 | # [b,1,3,4] x [b,h*w,4,1] = [b,h*w,3,1] 40 | vertice_world = torch.matmul(cam_ext.unsqueeze(1), vertice_homo[..., None]).squeeze(-1) # [b,h*w,3] 41 | vertice_depth = vertice_world[..., -1:] # [b,h*w,1] 42 | attributes = torch.cat([attributes, vertice_depth], dim=-1) # [b,h*w,5] 43 | # [b,1,3,3] x [b,h*w,3,1] = [b,h*w,3,1] 44 | vertice_world_homo = self.lift_to_homo(vertice_world) 45 | persp = self.get_perspective_from_intrinsic(cam_int) # [b,4,4] 46 | 47 | # [b,1,4,4] x [b,h*w,4,1] = [b,h*w,4,1] 48 | vertice_ndc = torch.matmul(persp.unsqueeze(1), 
vertice_world_homo[..., None]).squeeze(-1) # [b,h*w,4] 49 | vertice_ndc = vertice_ndc[..., :-1] / vertice_ndc[..., -1:] 50 | vertice_ndc[..., :-1] *= -1 51 | vertice_ndc[..., 0] *= w / h 52 | 53 | ############ 54 | # render 55 | mesh = Meshes(vertice_ndc, faces) 56 | pix_to_face, _, bary_coords, _ = rasterize_meshes(mesh, (h, w), faces_per_pixel=1, blur_radius=1e-6) # [b,h,w,1] [b,h,w,1,3] 57 | 58 | b, nf, _ = faces.size() 59 | faces = faces.reshape(b, nf * 3, 1).repeat(1, 1, 5) # [b,3f,5] 60 | face_attributes = torch.gather(attributes, dim=1, index=faces) # [b,3f,5] 61 | face_attributes = face_attributes.reshape(b * nf, 3, 5) 62 | output = interpolate_face_attributes(pix_to_face, bary_coords, face_attributes) 63 | output = output.squeeze(-2).permute(0, 3, 1, 2) 64 | 65 | render = output[:, :3] 66 | mask = output[:, 3:4] 67 | disparity = torch.reciprocal(output[:, 4:] + self.eps) 68 | return render * mask, disparity * mask, mask 69 | 70 | def construct_mesh(self, rgbd, cam_int): 71 | ''' 72 | input: 73 | rgbd: [b,4,h,w] 74 | the first 3 channels for RGB 75 | the last channel for normalized disparity, in range [0,1] 76 | cam_int: [b,3,3] normalized camera intrinsic matrix 77 | 78 | output: 79 | mesh_dict: define mesh in camera space, includes the following keys 80 | vertice: [b,h*w,3] 81 | faces: [b,nface,3] 82 | attributes: [b,h*w,c] include color and mask 83 | ''' 84 | b, _, h, w = rgbd.size() 85 | 86 | ############ 87 | # get pixel coordinates 88 | pixel_2d = self.get_screen_pixel_coord(h, w) # [1,h,w,2] 89 | pixel_2d_homo = self.lift_to_homo(pixel_2d) # [1,h,w,3] 90 | 91 | ############ 92 | # project pixels to 3D space 93 | rgbd = rgbd.permute(0, 2, 3, 1) # [b,h,w,4] 94 | disparity = rgbd[..., -1:] # [b,h,w,1] 95 | depth = torch.reciprocal(disparity + self.eps) # [b,h,w,1] 96 | cam_int_inv = torch.inverse(cam_int) # [b,3,3] 97 | # [b,1,1,3,3] x [1,h,w,3,1] = [b,h,w,3,1] 98 | pixel_3d = torch.matmul(cam_int_inv[:, None, None, :, :], pixel_2d_homo[..., 
None]).squeeze(-1) # [b,h,w,3] 99 | pixel_3d = pixel_3d * depth # [b,h,w,3] 100 | vertice = pixel_3d.reshape(b, h * w, 3) # [b,h*w,3] 101 | 102 | ############ 103 | # construct faces 104 | faces = self.get_faces(h, w) # [1,nface,3] 105 | faces = faces.repeat(b, 1, 1).long() # [b,nface,3] 106 | 107 | ############ 108 | # compute attributes 109 | attr_color = rgbd[..., :-1].reshape(b, h * w, 3) # [b,h*w,3] 110 | attr_mask = self.get_visible_mask(disparity).reshape(b, h * w, 1) # [b,h*w,1] 111 | attr = torch.cat([attr_color, attr_mask], dim=-1) # [b,h*w,4] 112 | 113 | mesh_dict = { 114 | "vertice": vertice, 115 | "faces": faces, 116 | "attributes": attr, 117 | "size": [h, w], 118 | } 119 | return mesh_dict 120 | 121 | def get_screen_pixel_coord(self, h, w): 122 | ''' 123 | get normalized pixel coordinates on the screen 124 | x to right, y to down 125 | 126 | e.g. 127 | [0,0][1,0][2,0] 128 | [0,1][1,1][2,1] 129 | output: 130 | pixel_coord: [1,h,w,2] 131 | ''' 132 | x = torch.arange(w).to(self.device) # [w] 133 | y = torch.arange(h).to(self.device) # [h] 134 | x = (x + 0.5) / w 135 | y = (y + 0.5) / h 136 | x = x[None, None, ..., None].repeat(1, h, 1, 1) # [1,h,w,1] 137 | y = y[None, ..., None, None].repeat(1, 1, w, 1) # [1,h,w,1] 138 | pixel_coord = torch.cat([x, y], dim=-1) # [1,h,w,2] 139 | return pixel_coord 140 | 141 | def lift_to_homo(self, coord): 142 | ''' 143 | return the homo version of coord 144 | input: coord [..., k] 145 | output: homo_coord [...,k+1] 146 | ''' 147 | ones = torch.ones_like(coord[..., -1:]) 148 | return torch.cat([coord, ones], dim=-1) 149 | 150 | def get_faces(self, h, w): 151 | ''' 152 | get face connect information 153 | x to right, y to down 154 | e.g.
155 | [0,0][1,0][2,0] 156 | [0,1][1,1][2,1] 157 | faces: [1,nface,3] 158 | ''' 159 | x = torch.arange(w - 1).to(self.device) # [w-1] 160 | y = torch.arange(h - 1).to(self.device) # [h-1] 161 | x = x[None, None, ..., None].repeat(1, h - 1, 1, 1) # [1,h-1,w-1,1] 162 | y = y[None, ..., None, None].repeat(1, 1, w - 1, 1) # [1,h-1,w-1,1] 163 | 164 | tl = y * w + x 165 | tr = y * w + x + 1 166 | bl = (y + 1) * w + x 167 | br = (y + 1) * w + x + 1 168 | 169 | faces_l = torch.cat([tl, bl, br], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3] 170 | faces_r = torch.cat([br, tr, tl], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3] 171 | 172 | return torch.cat([faces_l, faces_r], dim=1) # [1,nface,3] 173 | 174 | def get_visible_mask(self, disparity, beta=10, alpha_threshold=0.3): 175 | ''' 176 | filter the disparity map using sobel kernel, then mask out the edge (depth discontinuity) 177 | input: 178 | disparity: [b,h,w,1] 179 | 180 | output: 181 | vis_mask: [b,h,w,1] 182 | ''' 183 | b, h, w, _ = disparity.size() 184 | disparity = disparity.reshape(b, 1, h, w) # [b,1,h,w] 185 | kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device) 186 | kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device) 187 | sobel_x = F.conv2d(disparity, kernel_x, padding=(1, 1)) # [b,1,h,w] 188 | sobel_y = F.conv2d(disparity, kernel_y, padding=(1, 1)) # [b,1,h,w] 189 | sobel_mag = torch.sqrt(sobel_x ** 2 + sobel_y ** 2).reshape(b, h, w, 1) # [b,h,w,1] 190 | alpha = torch.exp(-1.0 * beta * sobel_mag) # [b,h,w,1] 191 | vis_mask = torch.greater(alpha, alpha_threshold).float() 192 | return vis_mask 193 | 194 | def get_perspective_from_intrinsic(self, cam_int): 195 | ''' 196 | input: 197 | cam_int: [b,3,3] 198 | 199 | output: 200 | persp: [b,4,4] 201 | ''' 202 | fx, fy = cam_int[:, 0, 0], cam_int[:, 1, 1] # [b] 203 | cx, cy = cam_int[:, 0, 2], cam_int[:, 1, 2] # [b] 204 | 205 | one = torch.ones_like(cx) 
# [b] 206 | zero = torch.zeros_like(cx) # [b] 207 | 208 | near_z, far_z = self.near_z * one, self.far_z * one 209 | a = (near_z + far_z) / (far_z - near_z) 210 | b = -2.0 * near_z * far_z / (far_z - near_z) 211 | 212 | matrix = [[2.0 * fx, zero, 2.0 * cx - 1.0, zero], 213 | [zero, 2.0 * fy, 2.0 * cy - 1.0, zero], 214 | [zero, zero, a, b], 215 | [zero, zero, one, zero]] 216 | # -> [[b,4],[b,4],[b,4],[b,4]] -> [b,4,4] 217 | persp = torch.stack([torch.stack(row, dim=-1) for row in matrix], dim=-2) # [b,4,4] 218 | return persp 219 | 220 | 221 | ####################### 222 | # some helper I/O functions 223 | ####################### 224 | def image_to_tensor(img_path, unsqueeze=True): 225 | rgb = transforms.ToTensor()(Image.open(img_path)) 226 | if unsqueeze: 227 | rgb = rgb.unsqueeze(0) 228 | return rgb 229 | 230 | 231 | def disparity_to_tensor(disp_path, unsqueeze=True): 232 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1) 233 | disp = torch.from_numpy(disp)[None, ...] 234 | if unsqueeze: 235 | disp = disp.unsqueeze(0) 236 | return disp.float() 237 | 238 | 239 | ####################### 240 | # some helper geometry functions 241 | # adapt from https://github.com/mattpoggi/depthstillation 242 | ####################### 243 | def transformation_from_parameters(axisangle, translation, invert=False): 244 | R = rot_from_axisangle(axisangle) 245 | t = translation.clone() 246 | 247 | if invert: 248 | R = R.transpose(1, 2) 249 | t *= -1 250 | 251 | T = get_translation_matrix(t) 252 | 253 | if invert: 254 | M = torch.matmul(R, T) 255 | else: 256 | M = torch.matmul(T, R) 257 | 258 | return M 259 | 260 | 261 | def get_translation_matrix(translation_vector): 262 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 263 | t = translation_vector.contiguous().view(-1, 3, 1) 264 | T[:, 0, 0] = 1 265 | T[:, 1, 1] = 1 266 | T[:, 2, 2] = 1 267 | T[:, 3, 3] = 1 268 | T[:, :3, 3, None] = t 269 | return T 270 | 271 | 272 | def rot_from_axisangle(vec): 273 
| angle = torch.norm(vec, 2, 2, True) 274 | axis = vec / (angle + 1e-7) 275 | 276 | ca = torch.cos(angle) 277 | sa = torch.sin(angle) 278 | C = 1 - ca 279 | 280 | x = axis[..., 0].unsqueeze(1) 281 | y = axis[..., 1].unsqueeze(1) 282 | z = axis[..., 2].unsqueeze(1) 283 | 284 | xs = x * sa 285 | ys = y * sa 286 | zs = z * sa 287 | xC = x * C 288 | yC = y * C 289 | zC = z * C 290 | xyC = x * yC 291 | yzC = y * zC 292 | zxC = z * xC 293 | 294 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 295 | 296 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 297 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 298 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 299 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 300 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 301 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 302 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 303 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 304 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 305 | rot[:, 3, 3] = 1 306 | 307 | return rot 308 | 309 | --------------------------------------------------------------------------------
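`rot_from_axisangle` in `warpback/utils.py` expands Rodrigues' rotation formula, R = cos θ·I + sin θ·[k]ₓ + (1 - cos θ)·k kᵀ, entry by entry. `transformation_from_parameters` then composes that rotation with a translation: as T·R normally, or as Rᵀ·T(-t) when `invert=True`. The NumPy sketch below restates that math for a single, unbatched vector so it can be checked in isolation. It is an illustration, not code from the repository, and the helper names `rodrigues` and `pose_from_axisangle` are hypothetical.

```python
import numpy as np


def rodrigues(axisangle):
    """Rotation matrix from one axis-angle vector (Rodrigues' formula).

    Mirrors the math of rot_from_axisangle, without the batch dimension.
    """
    v = np.asarray(axisangle, dtype=float)
    theta = np.linalg.norm(v)
    k = v / (theta + 1e-7)  # unit axis, with the same epsilon guard as the repo code
    kx, ky, kz = k
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])  # skew-symmetric cross-product matrix [k]_x
    return (np.cos(theta) * np.eye(3)
            + np.sin(theta) * K
            + (1.0 - np.cos(theta)) * np.outer(k, k))


def pose_from_axisangle(axisangle, translation, invert=False):
    """4x4 pose: T @ R, or R^T @ T(-t) when invert=True (hypothetical helper)."""
    R = np.eye(4)
    R[:3, :3] = rodrigues(axisangle)
    t = np.asarray(translation, dtype=float)
    if invert:
        R = R.T  # inverse of a rotation is its transpose
        t = -t
    T = np.eye(4)
    T[:3, 3] = t
    return R @ T if invert else T @ R


# Rotate the x-axis by 90 degrees about z, then translate by +1 along x.
M = pose_from_axisangle([0.0, 0.0, np.pi / 2], [1.0, 0.0, 0.0])
p = M @ np.array([1.0, 0.0, 0.0, 1.0])
print(np.round(p, 6))
```

The rotated point lands at (1, 1, 0). The `invert=True` pose undoes the forward pose, which mirrors how the repository pairs `R.transpose(1, 2)` with `t *= -1`.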
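`get_perspective_from_intrinsic` turns a normalized intrinsic matrix into a 4×4 clip-space matrix. After the perspective divide, its first row gives x_ndc = 2(f_x·X/Z + c_x) - 1 = 2u - 1, where u is the normalized pixel coordinate that `get_screen_pixel_coord` produces. Its third row maps Z = near to -1 and Z = far to +1. The NumPy check below tests both properties. It assumes a single camera, the renderer's default `near_z`/`far_z`, and the normalized focal length and principal point that `train_mpi.py` uses before scaling `K` by the image size. The helper names are hypothetical.

```python
import numpy as np


def perspective_from_normalized_intrinsics(fx, fy, cx, cy, near=1e-4, far=1e4):
    """Rebuilds the matrix of get_perspective_from_intrinsic for one camera."""
    a = (near + far) / (far - near)
    b = -2.0 * near * far / (far - near)
    return np.array([[2.0 * fx, 0.0, 2.0 * cx - 1.0, 0.0],
                     [0.0, 2.0 * fy, 2.0 * cy - 1.0, 0.0],
                     [0.0, 0.0, a, b],
                     [0.0, 0.0, 1.0, 0.0]])


def project(persp, point_cam):
    """Camera-space point -> NDC via homogeneous multiply and perspective divide."""
    clip = persp @ np.append(point_cam, 1.0)
    return clip[:3] / clip[3]  # clip[3] equals the camera-space depth Z


fx = fy = 0.58  # normalized focal length, as in K of train_mpi.py
cx = cy = 0.5   # principal point at the image center
P = perspective_from_normalized_intrinsics(fx, fy, cx, cy)

point = np.array([0.3, -0.2, 2.0])     # camera coords: x right, y down, z forward
u = fx * point[0] / point[2] + cx      # normalized pixel coordinate in [0, 1]
v = fy * point[1] / point[2] + cy
ndc = project(P, point)
print(ndc[:2], 2 * u - 1, 2 * v - 1)
```

In `render_mesh`, the sign flip and the `w / h` scaling applied after this step appear to adapt the [-1, 1] range to PyTorch3D's NDC convention (+X left, +Y up, shorter image side spanning [-1, 1]). That reading is an assumption based on PyTorch3D's documented conventions.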
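`get_visible_mask` keeps a vertex only where exp(-β·|∇d|) > τ. That is equivalent to a plain threshold on the Sobel gradient magnitude of the normalized disparity: |∇d| < -ln(τ)/β, which is about 0.120 for the defaults β = 10 and τ = 0.3. Vertices on depth discontinuities get mask 0. Triangles spanning such an edge are therefore faded out by `render * mask` instead of being drawn as stretched geometry.

The sketch below re-implements the mask in NumPy for one 2-D disparity map and checks it on a step edge and a gentle ramp. It is an illustration only, and the name `visible_mask` is hypothetical.

```python
import numpy as np


def visible_mask(disparity, beta=10.0, alpha_threshold=0.3):
    """NumPy version of the Sobel-based edge mask for a single [h, w] map.

    Like F.conv2d with padding=(1, 1), this is a zero-padded cross-correlation.
    """
    kx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)
    ky = kx.T  # [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], same as kernel_y
    h, w = disparity.shape
    padded = np.pad(disparity, 1)  # zero padding on every side
    gx = np.zeros((h, w))
    gy = np.zeros((h, w))
    for i in range(h):
        for j in range(w):
            window = padded[i:i + 3, j:j + 3]
            gx[i, j] = np.sum(window * kx)
            gy[i, j] = np.sum(window * ky)
    alpha = np.exp(-beta * np.sqrt(gx ** 2 + gy ** 2))
    return (alpha > alpha_threshold).astype(float)


step = np.full((8, 8), 0.2)
step[:, 4:] = 0.8                             # disparity jump between columns 3 and 4
ramp = np.tile(0.01 * np.arange(8), (8, 1))   # smooth slope of 0.01 per pixel

step_mask = visible_mask(step)
ramp_mask = visible_mask(ramp)
print("gradient threshold:", -np.log(0.3) / 10)
print(step_mask[3], ramp_mask[3])
```

Because of the zero padding, border pixels can also be flagged as edges. The repository's `F.conv2d(..., padding=(1, 1))` behaves the same way.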