├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.gif ├── config.py ├── configs └── render.txt ├── core ├── __init__.py ├── depth_layering.py ├── inpainter.py ├── pcd.py ├── renderer.py ├── scene_flow.py └── utils.py ├── download.sh ├── environment.yml ├── examples ├── bottle │ ├── disp.npy │ └── input.jpg ├── burger │ ├── disp.npy │ └── input.jpg ├── camera │ ├── disp.npy │ └── input.jpg ├── car │ ├── disp.npy │ └── input.jpg ├── fireworks │ ├── disp.npy │ └── input.jpg ├── rook │ ├── disp.npy │ └── input.jpg └── spoon │ ├── disp.npy │ └── input.png ├── model.py ├── model_3dm.py ├── networks ├── __init__.py ├── img_decoder.py ├── inpainting_nets.py └── resunet.py ├── posenc.py ├── renderer.py ├── runs └── reshader │ ├── dir_model.pth │ └── model.pth └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Avinash Paliwal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReShader 2 | 3 | > ReShader: View-Dependent Highlights for Single Image View-Synthesis 4 | > [Avinash Paliwal](http://avinashpaliwal.com/), 5 | > [Brandon Nguyen](https://brandon.nguyen.vc/about/), 6 | > [Andrii Tsarov](https://www.linkedin.com/in/andrii-tsarov-b8a9bb13), 7 | > [Nima Khademi Kalantari](http://nkhademi.com/) 8 | > SIGGRAPH Asia 2023 (TOG) 9 | 10 | [![Paper](https://img.shields.io/badge/cs.CV-Paper-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.10689) 11 | [![Project Page](https://img.shields.io/badge/ReShader-Website-blue?logo=googlechrome&logoColor=blue)](https://people.engr.tamu.edu/nimak/Papers/SIGAsia2023_Reshader) 12 | [![Video](https://img.shields.io/badge/YouTube-Video-c4302b?logo=youtube&logoColor=red)](https://youtu.be/XW-tl48D3Ok) 13 | 14 | --------------------------------------------------- 15 |

16 | 17 | demo 18 | 19 |

20 | 21 | ## Prerequisites 22 | You can set up the Anaconda environment using: 23 | ``` 24 | conda env create -f environment.yml 25 | conda activate reshader 26 | ``` 27 | 28 | Download the pretrained models. 29 | The following script from [3D Moments](https://github.com/google-research/3d-moments) will download their pretrained models and [RGBD-inpainting networks](https://github.com/vt-vl-lab/3d-photo-inpainting). 30 | ``` 31 | ./download.sh 32 | ``` 33 | 34 | 35 | ## Demos 36 | We provide some examples in the `examples/` folder. You can render novel views with view-dependent highlights using: 37 | 38 | ``` 39 | python renderer.py --input_dir examples/camera/ --config configs/render.txt 40 | ``` 41 | 42 | ## Training 43 | Training code and dataset to be added. 44 | 45 | ## Citation 46 | ``` 47 | @article{paliwal2023reshader, 48 | author = {Paliwal, Avinash and Nguyen, Brandon G. and Tsarov, Andrii and Kalantari, Nima Khademi}, 49 | title = {ReShader: View-Dependent Highlights for Single Image View-Synthesis}, 50 | year = {2023}, 51 | issue_date = {December 2023}, 52 | volume = {42}, 53 | number = {6}, 54 | journal = {ACM Trans. Graph.}, 55 | month = {dec}, 56 | articleno = {216}, 57 | numpages = {9}, 58 | } 59 | ``` 60 | 61 | 62 | ## Acknowledgement 63 | The novel view synthesis part of the code is borrowed from [3D Moments](https://github.com/google-research/3d-moments). -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/assets/teaser.gif -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
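# ----------------------------------------------------------------------------
# [Editor's sketch, not part of the original config.py] The parser defined in
# this file relies on configargparse: `key = value` lines in the file passed
# via --config (e.g. configs/render.txt) are merged with command-line flags,
# and explicit command-line flags take precedence over config-file values.
# A minimal, self-contained illustration of that pattern (hypothetical helper):
def _example_config_usage(tmp_cfg='/tmp/example_render.txt'):
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument('--input_dir', type=str, default='')
    parser.add_argument('--dscl', type=int, default=1, help='depth scaling')
    with open(tmp_cfg, 'w') as f:
        f.write('dscl = 2\n')  # same key = value syntax as configs/render.txt
    args = parser.parse_args(['--config', tmp_cfg, '--input_dir', 'examples/camera/'])
    assert args.dscl == 2 and args.input_dir == 'examples/camera/'
    return args
# ----------------------------------------------------------------------------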
14 | 15 | import configargparse 16 | 17 | 18 | def config_parser(): 19 | parser = configargparse.ArgumentParser() 20 | parser.add_argument('--config', is_config_file=True, help='config file path') 21 | # general 22 | parser.add_argument('--rootdir', type=str, default='./', 23 | help='the path to the project root directory.') 24 | parser.add_argument("--expname", type=str, default='exp', help='experiment name') 25 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 26 | help='number of data loading workers (default: 8)') 27 | parser.add_argument('--distributed', action='store_true', help='if use distributed training') 28 | parser.add_argument("--local_rank", type=int, default=0, help='rank for distributed training') 29 | parser.add_argument("--eval_mode", action='store_true', help='if in eval mode') 30 | 31 | ########## dataset options ########## 32 | # train and eval dataset 33 | parser.add_argument("--train_dataset", type=str, default='vimeo', 34 | help='the training dataset') 35 | parser.add_argument("--dataset_weights", nargs='+', type=float, default=[], 36 | help='the weights for training datasets, used when multiple datasets are used.') 37 | parser.add_argument('--eval_dataset', type=str, default='vimeo', help='the dataset to evaluate') 38 | parser.add_argument("--batch_size", type=int, default=1, help='batch size, currently only support 1') 39 | 40 | ########## network architecture ########## 41 | parser.add_argument("--feature_dim", type=int, default=32, help='the dimension of the extracted features') 42 | 43 | ########## training options ########## 44 | parser.add_argument("--use_inpainting_mask_for_feature", action='store_true') 45 | parser.add_argument("--inpainting", action='store_true', help='if do inpainting') 46 | parser.add_argument("--train_raft", action='store_true', help='if train raft') 47 | parser.add_argument('--boundary_crop_ratio', type=float, default=0, help='crop the image before computing loss') 48 | parser.add_argument("--vary_pts_radius", action='store_true', help='if vary point radius as augmentation') 49 | parser.add_argument("--adaptive_pts_radius", action='store_true', help='if use adaptive point radius') 50 | parser.add_argument("--use_mask_for_decoding", action='store_true', help='if use mask for decoding') 51 | 52 | ########## rendering/evaluation ########## 53 | parser.add_argument("--use_depth_for_feature", action='store_true', 54 | help='if use depth map when extracting features') 55 | parser.add_argument("--use_depth_for_decoding", action='store_true', 56 | help='if use depth map when decoding') 57 | parser.add_argument("--point_radius", type=float, default=1.5, 58 | help='point radius for rasterization') 59 | parser.add_argument("--input_dir", type=str, default='', help='input folder that contains a pair of images') 60 | parser.add_argument("--visualize_rgbda_layers", action='store_true', 61 | help="if visualize rgbda layers, save in out dir") 62 | 63 | ########### iterations & learning rate options & loss ########## 64 | parser.add_argument("--n_iters", type=int, default=250000, help='num of iterations') 65 | parser.add_argument("--lr", type=float, default=3e-4, help='learning rate for feature extractor') 66 | parser.add_argument("--lr_raft", type=float, default=5e-6, help='learning rate for raft') 67 | parser.add_argument("--lrate_decay_factor", type=float, default=0.5, 68 | help='decay learning rate by a factor every specified number of steps') 69 | parser.add_argument("--lrate_decay_steps", type=int, default=50000, 70 | 
help='decay learning rate by a factor every specified number of steps') 71 | parser.add_argument('--loss_mode', type=str, default='lpips', 72 | help='the loss function to use') 73 | 74 | ########## checkpoints ########## 75 | parser.add_argument("--ckpt_path", type=str, default="", 76 | help='specific weights npy file to reload for coarse network') 77 | parser.add_argument("--no_reload", action='store_true', 78 | help='do not reload weights from saved ckpt') 79 | parser.add_argument("--no_load_opt", action='store_true', 80 | help='do not load optimizer when reloading') 81 | parser.add_argument("--no_load_scheduler", action='store_true', 82 | help='do not load scheduler when reloading') 83 | 84 | ########## logging/saving options ########## 85 | parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin') 86 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') 87 | parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') 88 | 89 | ############ demo parameters ############## 90 | parser.add_argument("--spec", action='store_true', help='use specular frames') 91 | parser.add_argument("--tung", action='store_true', help='using tungsten depths') 92 | parser.add_argument("--normalize_depth", action='store_true', help='normalize depth when depth map is euclidean distance') 93 | parser.add_argument("--fov", type=float, default=45.0, help='fov of camera') 94 | parser.add_argument("--spd", type=str, default="246", help='spec directory suffix') 95 | parser.add_argument("--dscl", type=int, default=1, help='depth scaling') 96 | 97 | #training schedule 98 | parser.add_argument('--num_iterations', type=int, default=300000, help='total epochs to train') 99 | parser.add_argument('-train_batch_size', type=int, default=10) 100 | parser.add_argument('-val_batch_size', type=int, default=4) 101 | parser.add_argument('-checkpoint', type=int, default=10, help='save checkpoint for every epochs. Be aware that! 
It will replace the previous checkpoint.') 102 | parser.add_argument('-tb_toc',type=int, default=100, help="print output to terminal for every tb_toc iterations") 103 | 104 | #lr schedule 105 | parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4, help='learning rate of the network') 106 | 107 | #loss 108 | parser.add_argument('-style_coeff', type=float, default=1, help='hyperparameter for style loss') 109 | parser.add_argument('-prcp_coeff', type=float, default=0.01, help='hyperparameter for perceptual loss') 110 | parser.add_argument('-mse_coeff', type=float, default=1.0, help='hyperparameter for MSE loss') 111 | parser.add_argument('-l1_coeff', type=float, default=0.1, help='hyperparameter for L1 loss') 112 | #training and eval data 113 | parser.add_argument('-dataset', type=str, default="/data2/avinash/datasets/specular_fixed/specular/", help='directory to the dataset') 114 | 115 | #training utility 116 | parser.add_argument('--model_dir', type=str, default="unet_prcp_gmm_mse", help='model (scene) directory which store in runs//') 117 | parser.add_argument('-clean', action='store_true', help='delete old weight without start training process') 118 | parser.add_argument('--clip', type=float, default=1.0) 119 | 120 | #model 121 | parser.add_argument('-multi', type=bool, default=True, help='append multi level direction vector') 122 | parser.add_argument('-use_mlp', type=bool, default=False, help='use mlp for feature vector from direction vector') 123 | parser.add_argument('--start_iter',type=int, default=0, help="starting iteration") 124 | parser.add_argument('-basis_out',type=int, default=8, help="num of basis functions") 125 | parser.add_argument('-pos_enc_freq',type=int, default=5, help="num of freqs in positional encoding") 126 | parser.add_argument('--losses', type=str, nargs='+', help='losses to use', default=['mse', 'prcp', 'gmm']) 127 | parser.add_argument('--ckpt', type=str, default=None, help='checkpopint to continue from') 128 | parser.add_argument('--example_index', type=str, default=None, help='example index for testing') 129 | parser.add_argument('--test_root', type=str, default="real_data/", help='test examples root dir') 130 | parser.add_argument('-pad', type=bool, default=False, help='use mlp for feature vector from direction vector') 131 | parser.add_argument('--use_depth_posenc', type=bool, default=False, help='use mlp for feature vector from direction vector') 132 | 133 | 134 | args = parser.parse_args() 135 | return args 136 | 137 | -------------------------------------------------------------------------------- /configs/render.txt: -------------------------------------------------------------------------------- 1 | no_load_opt = True 2 | no_load_scheduler = True 3 | distributed = False 4 | loss_mode = vgg19 5 | train_dataset = tiktok 6 | eval_dataset = jamie 7 | eval_mode = True 8 | 9 | use_depth_for_decoding = True 10 | adaptive_pts_radius = True 11 | train_raft = False 12 | visualize_rgbda_layers = False 13 | 14 | ckpt_path = pretrained/model_250000.pth 15 | 16 | 17 | model_dir = reshader 18 | ckpt = "" 19 | use_depth_posenc = True 20 | dscl = 2 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/core/__init__.py -------------------------------------------------------------------------------- /core/depth_layering.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from sklearn.cluster import AgglomerativeClustering 18 | 19 | 20 | def get_depth_bins(depth=None, disparity=None, num_bins=None): 21 | """ 22 | :param depth: [1, 1, H, W] 23 | :param disparity: [1, 1, H, W] 24 | :return: depth_bins 25 | """ 26 | 27 | assert (disparity is not None) or (depth is not None) 28 | if disparity is None: 29 | assert depth.min() > 1e-2 30 | disparity = 1. / depth 31 | 32 | if depth is None: 33 | depth = 1. / torch.clamp(disparity, min=1e-2) 34 | 35 | assert depth.shape[:2] == (1, 1) and disparity.shape[:2] == (1, 1) 36 | disparity_max = disparity.max().item() 37 | disparity_min = disparity.min().item() 38 | disparity_feat = disparity[:, :, ::10, ::10].reshape(-1, 1).cpu().numpy() 39 | disparity_feat = (disparity_feat - disparity_min) / (disparity_max - disparity_min) 40 | if num_bins is None: 41 | n_clusters = None 42 | distance_threshold = 5 43 | else: 44 | n_clusters = num_bins 45 | distance_threshold = None 46 | result = AgglomerativeClustering(n_clusters=n_clusters, distance_threshold=distance_threshold).fit(disparity_feat) 47 | num_bins = result.n_clusters_ if n_clusters is None else n_clusters 48 | depth_bins = [depth.min().item()] 49 | for i in range(num_bins): 50 | th = (disparity_feat[result.labels_ == i]).min() 51 | th = th * (disparity_max - disparity_min) + disparity_min 52 | depth_bins.append(1. / th) 53 | 54 | depth_bins = sorted(depth_bins) 55 | depth_bins[0] = depth.min() - 1e-6 56 | depth_bins[-1] = depth.max() + 1e-6 57 | return depth_bins 58 | 59 | 60 | -------------------------------------------------------------------------------- /core/inpainter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
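# ----------------------------------------------------------------------------
# [Editor's sketch, not part of the original inpainter.py] How the depth bins
# from core.depth_layering.get_depth_bins (defined in the file above) become
# the per-layer alpha masks that Inpainter.sequential_inpainting (below)
# iterates over: layer i keeps the pixels whose depth lies in
# [depth_bins[i], depth_bins[i+1]).
def _example_depth_layer_masks(depth):
    # depth: [1, 1, H, W] tensor with depth.min() > 1e-2 (see the assert in get_depth_bins)
    import torch
    from core.depth_layering import get_depth_bins
    depth_bins = get_depth_bins(depth=depth)  # bin edges from agglomerative clustering on disparity
    masks = [((depth >= lo) * (depth < hi)).float()
             for lo, hi in zip(depth_bins[:-1], depth_bins[1:])]
    return torch.cat(masks, dim=0)  # [n_layers, 1, H, W], one depth slice per layer
# ----------------------------------------------------------------------------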
14 | 15 | 16 | import torch 17 | from kornia.morphology import opening, erosion 18 | from kornia.filters import gaussian_blur2d 19 | from networks.inpainting_nets import Inpaint_Depth_Net, Inpaint_Color_Net 20 | from core.utils import masked_median_blur 21 | 22 | 23 | def refine_near_depth_discontinuity(depth, alpha, kernel_size=11): 24 | ''' 25 | median filtering the depth discontinuity boundary 26 | ''' 27 | depth = depth * alpha 28 | depth_median_blurred = masked_median_blur(depth, alpha, kernel_size=kernel_size) * alpha 29 | alpha_eroded = erosion(alpha, kernel=torch.ones(kernel_size, kernel_size).to(alpha.device)) 30 | depth[alpha_eroded == 0] = depth_median_blurred[alpha_eroded == 0] 31 | return depth 32 | 33 | 34 | def define_inpainting_bbox(alpha, border=40): 35 | ''' 36 | define the bounding box for inpainting 37 | :param alpha: alpha map [1, 1, h, w] 38 | :param border: the minimum distance from a valid pixel to the border of the bbox 39 | :return: [1, 1, h, w], a 0/1 map that indicates the inpainting region 40 | ''' 41 | assert alpha.ndim == 4 and alpha.shape[:2] == (1, 1) 42 | x, y = torch.nonzero(alpha)[:, -2:].T 43 | h, w = alpha.shape[-2:] 44 | row_min, row_max = x.min(), x.max() 45 | col_min, col_max = y.min(), y.max() 46 | out = torch.zeros_like(alpha) 47 | x0, x1 = max(row_min - border, 0), min(row_max + border, h - 1) 48 | y0, y1 = max(col_min - border, 0), min(col_max + border, w - 1) 49 | out[:, :, x0:x1, y0:y1] = 1 50 | return out 51 | 52 | 53 | class Inpainter(): 54 | def __init__(self, args, device='cuda'): 55 | self.args = args 56 | self.device = device 57 | print("Loading depth model...") 58 | depth_feat_model = Inpaint_Depth_Net() 59 | depth_feat_weight = torch.load('inpainting_ckpts/depth-model.pth', map_location=torch.device(device)) 60 | depth_feat_model.load_state_dict(depth_feat_weight) 61 | depth_feat_model = depth_feat_model.to(device) 62 | depth_feat_model.eval() 63 | self.depth_feat_model = depth_feat_model.to(device) 64 | print("Loading rgb model...") 65 | rgb_model = Inpaint_Color_Net() 66 | rgb_feat_weight = torch.load('inpainting_ckpts/color-model.pth', map_location=torch.device(device)) 67 | rgb_model.load_state_dict(rgb_feat_weight) 68 | rgb_model.eval() 69 | self.rgb_model = rgb_model.to(device) 70 | 71 | # kernels 72 | self.context_erosion_kernel = torch.ones(10, 10).to(self.device) 73 | self.alpha_kernel = torch.ones(3, 3).to(self.device) 74 | 75 | @staticmethod 76 | def process_depth_for_network(depth, context, log_depth=True): 77 | if log_depth: 78 | log_depth = torch.log(depth + 1e-8) * context 79 | mean_depth = torch.mean(log_depth[context > 0]) 80 | zero_mean_depth = (log_depth - mean_depth) * context 81 | else: 82 | zero_mean_depth = depth 83 | mean_depth = 0 84 | return zero_mean_depth, mean_depth 85 | 86 | @staticmethod 87 | def deprocess_depth(zero_mean_depth, mean_depth, log_depth=True): 88 | if log_depth: 89 | depth = torch.exp(zero_mean_depth + mean_depth) 90 | else: 91 | depth = zero_mean_depth 92 | return depth 93 | 94 | def inpaint_rgb(self, holes, context, context_rgb, edge): 95 | # inpaint rgb 96 | with torch.no_grad(): 97 | inpainted_rgb = self.rgb_model.forward_3P(holes, context, context_rgb, edge, 98 | unit_length=128, cuda=self.device) 99 | inpainted_rgb = inpainted_rgb.detach() * holes + context_rgb 100 | inpainted_a = holes + context 101 | inpainted_a = opening(inpainted_a, self.alpha_kernel) 102 | inpainted_rgba = torch.cat([inpainted_rgb, inpainted_a], dim=1) 103 | return inpainted_rgba 104 | 105 | def 
inpaint_depth(self, depth, holes, context, edge, depth_range): 106 | zero_mean_depth, mean_depth = self.process_depth_for_network(depth, context) 107 | with torch.no_grad(): 108 | inpainted_depth = self.depth_feat_model.forward_3P(holes, context, zero_mean_depth, edge, 109 | unit_length=128, cuda=self.device) 110 | inpainted_depth = self.deprocess_depth(inpainted_depth.detach(), mean_depth) 111 | inpainted_depth[context > 0.5] = depth[context > 0.5] 112 | inpainted_depth = gaussian_blur2d(inpainted_depth, (3, 3), (1.5, 1.5)) 113 | inpainted_depth[context > 0.5] = depth[context > 0.5] 114 | # if the inpainted depth in the background is smaller that the foreground depth, 115 | # then the inpainted content will mistakenly occlude the foreground. 116 | # Clipping the inpainted depth in this situation. 117 | mask_wrong_depth_ordering = inpainted_depth < depth 118 | inpainted_depth[mask_wrong_depth_ordering] = depth[mask_wrong_depth_ordering] * 1.01 119 | inpainted_depth = torch.clamp(inpainted_depth, min=min(depth_range)*0.9) 120 | return inpainted_depth 121 | 122 | def sequential_inpainting(self, rgb, depth, depth_bins): 123 | ''' 124 | :param rgb: [1, 3, H, W] 125 | :param depth: [1, 1, H, W] 126 | :return: rgba_layers: [N, 1, 3, H, W]: the inpainted RGBA layers 127 | depth_layers: [N, 1, 1, H, W]: the inpainted depth layers 128 | mask_layers: [N, 1, 1, H, W]: the original alpha layers (before inpainting) 129 | ''' 130 | 131 | num_bins = len(depth_bins) - 1 132 | 133 | rgba_layers = [] 134 | depth_layers = [] 135 | mask_layers = [] 136 | 137 | for i in range(num_bins): 138 | alpha_i = (depth >= depth_bins[i]) * (depth < depth_bins[i+1]) 139 | alpha_i = alpha_i.float() 140 | 141 | if i == 0: 142 | rgba_i = torch.cat([rgb*alpha_i, alpha_i], dim=1) 143 | rgba_layers.append(rgba_i) 144 | depth_i = refine_near_depth_discontinuity(depth, alpha_i) 145 | depth_layers.append(depth_i) 146 | mask_layers.append(alpha_i) 147 | pre_alpha = alpha_i.bool() 148 | pre_inpainted_depth = depth * alpha_i 149 | else: 150 | alpha_i_eroded = erosion(alpha_i, self.context_erosion_kernel) 151 | if alpha_i_eroded.sum() < 10: 152 | continue 153 | context = erosion((depth >= depth_bins[i]).float(), self.context_erosion_kernel) 154 | holes = 1. - context 155 | bbox = define_inpainting_bbox(context, border=40) 156 | holes *= bbox 157 | edge = torch.zeros_like(holes) 158 | context_rgb = rgb * context 159 | # inpaint depth 160 | inpainted_depth_i = self.inpaint_depth(depth, holes, context, edge, (depth_bins[i], depth_bins[i+1])) 161 | depth_near_mask = (inpainted_depth_i < depth_bins[i+1]).float() 162 | # inpaint rgb 163 | inpainted_rgba_i = self.inpaint_rgb(holes, context, context_rgb, edge) 164 | 165 | if i < num_bins - 1: 166 | # only keep the content whose depth is smaller than the upper limit of the current layer 167 | # otherwise the inpainted content on the far-depth edge will falsely occlude the next layer. 
168 | inpainted_rgba_i *= depth_near_mask 169 | inpainted_depth_i = refine_near_depth_discontinuity(inpainted_depth_i, inpainted_rgba_i[:, [-1]]) 170 | 171 | inpainted_alpha_i = inpainted_rgba_i[:, [-1]].bool() 172 | mask_wrong_ordering = (inpainted_depth_i <= pre_inpainted_depth) * inpainted_alpha_i 173 | inpainted_depth_i[mask_wrong_ordering] = pre_inpainted_depth[mask_wrong_ordering] * 1.05 174 | 175 | rgba_layers.append(inpainted_rgba_i) 176 | depth_layers.append(inpainted_depth_i) 177 | mask_layers.append(context * depth_near_mask) # original mask 178 | 179 | pre_alpha[inpainted_alpha_i] = True 180 | pre_inpainted_depth[inpainted_alpha_i > 0] = inpainted_depth_i[inpainted_alpha_i > 0] 181 | 182 | rgba_layers = torch.stack(rgba_layers) 183 | depth_layers = torch.stack(depth_layers) 184 | mask_layers = torch.stack(mask_layers) 185 | 186 | return rgba_layers, depth_layers, mask_layers 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /core/pcd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import sys 17 | sys.path.append('../') 18 | import config 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | from pytorch3d.renderer import ( 23 | PerspectiveCameras, 24 | PointsRasterizationSettings, 25 | PointsRasterizer, 26 | AlphaCompositor, 27 | ) 28 | 29 | args = config.config_parser() 30 | 31 | 32 | class PointsRenderer(nn.Module): 33 | """ 34 | A class for rendering a batch of points. The class should 35 | be initialized with a rasterizer and compositor class which each have a forward 36 | function. 37 | """ 38 | 39 | def __init__(self, rasterizer, compositor) -> None: 40 | super().__init__() 41 | self.rasterizer = rasterizer 42 | self.compositor = compositor 43 | 44 | def to(self, device): 45 | # Manually move to device rasterizer as the cameras 46 | # within the class are not of type nn.Module 47 | self.rasterizer = self.rasterizer.to(device) 48 | self.compositor = self.compositor.to(device) 49 | return self 50 | 51 | def forward(self, point_clouds, **kwargs) -> torch.Tensor: 52 | fragments = self.rasterizer(point_clouds, **kwargs) 53 | 54 | # Construct weights based on the distance of a point to the true point. 55 | # However, this could be done differently: e.g. predicted as opposed 56 | # to a function of the weights. 
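        # [Editorial comment] The weighting below follows the comment above:
        # each rasterized point contributes weight = 1 - dists2 / r^2, i.e. 1 at
        # the pixel centre and 0 at the edge of its splat radius, and the
        # AlphaCompositor blends the points_per_pixel nearest points with these
        # weights. When an adaptive per-point radius is used, `r` arrives as a
        # tensor and is first gathered per pixel via fragments.idx.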
57 | r = self.rasterizer.raster_settings.radius 58 | 59 | if type(r) == torch.Tensor: 60 | if r.shape[-1] > 1: 61 | idx = fragments.idx.clone() 62 | idx[idx == -1] = 0 63 | r = r[:, idx.squeeze().long()] 64 | r = r.permute(0, 3, 1, 2) 65 | 66 | dists2 = fragments.dists.permute(0, 3, 1, 2) 67 | weights = 1 - dists2 / (r * r) 68 | images = self.compositor( 69 | fragments.idx.long().permute(0, 3, 1, 2), 70 | weights, 71 | point_clouds.features_packed().permute(1, 0), 72 | **kwargs, 73 | ) 74 | 75 | # permute so image comes at the end 76 | images = images.permute(0, 2, 3, 1) 77 | 78 | return images 79 | 80 | 81 | def linear_interpolation(data0, data1, time): 82 | return (1. - time) * data0 + time * data1 83 | 84 | 85 | def create_pcd_renderer(h, w, intrinsics, R=None, T=None, radius=None, device="cuda"): 86 | # Initialize a camera. 87 | fx = intrinsics[0, 0] 88 | fy = intrinsics[1, 1] 89 | # print("CREATEPCD") 90 | if R is None: 91 | R = torch.eye(3)[None] # (1, 3, 3) 92 | # R[:, 1, 1] = -1 93 | # R[:, 2, 2] = -1 94 | # print("RNONE") 95 | # print(R) 96 | if T is None: 97 | T = torch.zeros(1, 3) # (1, 3) 98 | 99 | cameras = PerspectiveCameras(R=R, T=T, 100 | device=device, 101 | focal_length=((-fx, -fy),), 102 | principal_point=(tuple(intrinsics[:2, -1]),), 103 | image_size=((h, w),), 104 | in_ndc=False, 105 | ) 106 | 107 | # Define the settings for rasterization and shading. Here we set the output image to be of size 108 | # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 109 | # and blur_radius=0.0. Refer to raster_points.py for explanations of these parameters. 110 | if radius is None: 111 | radius = args.point_radius / min(h, w) * 2.0 112 | if args.vary_pts_radius: 113 | if np.random.choice([0, 1], p=[0.6, 0.4]): 114 | factor = 1 + (0.2 * (np.random.rand() - 0.5)) 115 | radius *= factor 116 | 117 | raster_settings = PointsRasterizationSettings( 118 | image_size=(h, w), 119 | radius=radius, 120 | points_per_pixel=8, 121 | ) 122 | 123 | # Create a points renderer by compositing points using an alpha compositor (nearer points 124 | # are weighted more heavily). See [1] for an explanation. 125 | rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 126 | renderer = PointsRenderer( 127 | rasterizer=rasterizer, 128 | compositor=AlphaCompositor() 129 | ) 130 | return renderer 131 | -------------------------------------------------------------------------------- /core/renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
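# ----------------------------------------------------------------------------
# [Editor's sketch, not part of the original renderer.py] The basic splatting
# step that ImgRenderer (below) builds on: unproject every pixel to a 3D point
# with the source intrinsics, wrap points and colors in a PyTorch3D
# Pointclouds object, and rasterize/composite it with the renderer from
# core.pcd. Inputs are assumed to be CUDA tensors produced by this project.
def _example_splat_rgb(rgb, depth, intrinsics, device='cuda'):
    # rgb: [1, 3, H, W], depth: [1, 1, H, W], intrinsics: [3, 3] (or padded [4, 4])
    from pytorch3d.structures import Pointclouds
    from core.pcd import create_pcd_renderer
    from core.utils import get_coord_grids_pt, unproject_pts_pt
    h, w = rgb.shape[-2:]
    coords = get_coord_grids_pt(h, w, device=device).float().reshape(-1, 2)
    pts = unproject_pts_pt(intrinsics, coords, depth.flatten())   # [H*W, 3] camera-space points
    colors = rgb.permute(0, 2, 3, 1).reshape(-1, 3)               # [H*W, 3]
    renderer = create_pcd_renderer(h, w, intrinsics.squeeze()[:3, :3], device=device)
    pcd = Pointclouds(points=[pts], features=[colors])
    return renderer(pcd)  # [1, H, W, 3], re-rendered from the source viewpoint
# ----------------------------------------------------------------------------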
14 | 15 | 16 | import os 17 | import imageio 18 | import torch.utils.data.distributed 19 | from pytorch3d.structures import Pointclouds 20 | from core.utils import * 21 | from core.depth_layering import get_depth_bins 22 | from core.pcd import linear_interpolation, create_pcd_renderer 23 | 24 | 25 | class ImgRenderer(): 26 | def __init__(self, args, model, scene_flow_estimator, inpainter, device): 27 | self.args = args 28 | self.model = model 29 | self.scene_flow_estimator = scene_flow_estimator 30 | self.inpainter = inpainter 31 | self.device = device 32 | 33 | def process_data(self, data): 34 | self.src_img1 = data['src_img1'].to(self.device) 35 | self.src_img2 = data['src_img2'].to(self.device) 36 | assert self.src_img1.shape == self.src_img2.shape 37 | self.h, self.w = self.src_img1.shape[-2:] 38 | self.src_depth1 = data['src_depth1'].to(self.device) 39 | self.src_depth2 = data['src_depth2'].to(self.device) 40 | self.intrinsic1 = data['intrinsic1'].to(self.device) 41 | self.intrinsic2 = data['intrinsic2'].to(self.device) 42 | 43 | self.pose = data['pose'].to(self.device) 44 | self.scale_shift1 = data['scale_shift1'][0] 45 | self.scale_shift2 = data['scale_shift2'][0] 46 | self.is_multi_view = data['multi_view'][0] 47 | self.src_rgb_file1 = data['src_rgb_file1'][0] 48 | self.src_rgb_file2 = data['src_rgb_file2'][0] 49 | if 'tgt_img' in data.keys(): 50 | self.tgt_img = data['tgt_img'].to(self.device) 51 | if 'tgt_intrinsic' in data.keys(): 52 | self.tgt_intrinsic = data['tgt_intrinsic'].to(self.device) 53 | if 'tgt_pose' in data.keys(): 54 | self.tgt_pose = data['tgt_pose'].to(self.device) 55 | if 'time' in data.keys(): 56 | self.time = data['time'].item() 57 | if 'src_mask1' in data.keys(): 58 | self.src_mask1 = data['src_mask1'].to(self.device) 59 | else: 60 | self.src_mask1 = torch.ones_like(self.src_depth1) 61 | if 'src_mask2' in data.keys(): 62 | self.src_mask2 = data['src_mask2'].to(self.device) 63 | else: 64 | self.src_mask2 = torch.ones_like(self.src_depth2) 65 | 66 | 67 | 68 | def process_data_single(self, data): 69 | self.src_img1 = data['src_img1'].to(self.device) 70 | self.h, self.w = self.src_img1.shape[-2:] 71 | self.src_depth1 = data['src_depth1'].to(self.device) 72 | self.intrinsic1 = data['intrinsic1'].to(self.device) 73 | 74 | self.pose = data['pose'].to(self.device) 75 | self.scale_shift1 = data['scale_shift1'][0] 76 | self.is_multi_view = data['multi_view'][0] 77 | self.src_rgb_file1 = data['src_rgb_file1'][0] 78 | if 'tgt_img' in data.keys(): 79 | self.tgt_img = data['tgt_img'].to(self.device) 80 | if 'tgt_intrinsic' in data.keys(): 81 | self.tgt_intrinsic = data['tgt_intrinsic'].to(self.device) 82 | if 'tgt_pose' in data.keys(): 83 | self.tgt_pose = data['tgt_pose'].to(self.device) 84 | if 'time' in data.keys(): 85 | self.time = data['time'].item() 86 | if 'src_mask1' in data.keys(): 87 | self.src_mask1 = data['src_mask1'].to(self.device) 88 | else: 89 | self.src_mask1 = torch.ones_like(self.src_depth1) 90 | 91 | def feature_extraction(self, rgba_layers, mask_layers, depth_layers): 92 | rgba_layers_in = rgba_layers.squeeze(1) 93 | 94 | if self.args.use_inpainting_mask_for_feature: 95 | rgba_layers_in = torch.cat([rgba_layers_in, mask_layers.squeeze(1)], dim=1) 96 | 97 | if self.args.use_depth_for_feature: 98 | rgba_layers_in = torch.cat([rgba_layers_in, 1. / torch.clamp(depth_layers.squeeze(1), min=1.)], dim=1) 99 | featmaps = self.model.feature_net(rgba_layers_in) 100 | return featmaps 101 | 102 | def apply_scale_shift(self, depth, scale, shift): 103 | disp = 1. 
/ torch.clamp(depth, min=1e-3) 104 | disp = scale * disp + shift 105 | return 1 / torch.clamp(disp, min=1e-3*scale) 106 | 107 | def masked_diffuse(self, x, mask, iter=10, kernel_size=35, median_blur=False): 108 | if median_blur: 109 | x = masked_median_blur(x, mask.repeat(1, x.shape[1], 1, 1), kernel_size=5) 110 | for _ in range(iter): 111 | x, mask = masked_smooth_filter(x, mask, kernel_size=kernel_size) 112 | return x, mask 113 | 114 | def compute_weight_for_two_frame_blending(self, time, disp1, disp2, alpha1, alpha2): 115 | alpha = 4 116 | weight1 = (1 - time) * torch.exp(alpha*disp1) * alpha1 117 | weight2 = time * torch.exp(alpha*disp2) * alpha2 118 | sum_weight = torch.clamp(weight1 + weight2, min=1e-6) 119 | out_weight1 = weight1 / sum_weight 120 | out_weight2 = weight2 / sum_weight 121 | return out_weight1, out_weight2 122 | 123 | def transform_all_pts(self, all_pts, pose): 124 | all_pts_out = [] 125 | for pts in all_pts: 126 | pts_out = transform_pts_in_3D(pts, pose) 127 | all_pts_out.append(pts_out) 128 | return all_pts_out 129 | 130 | def render_pcd(self, pts1, pts2, rgbs1, rgbs2, feats1, feats2, mask, side_ids, R=None, t=None, time=0): 131 | 132 | pts = linear_interpolation(pts1, pts2, time) 133 | rgbs = linear_interpolation(rgbs1, rgbs2, time) 134 | feats = linear_interpolation(feats1, feats2, time) 135 | rgb_feat = torch.cat([rgbs, feats], dim=-1) 136 | 137 | num_sides = side_ids.max() + 1 138 | assert num_sides == 1 or num_sides == 2 139 | 140 | if R is None: 141 | R = torch.eye(3, device=self.device) 142 | if t is None: 143 | t = torch.zeros(3, device=self.device) 144 | 145 | pts_ = (R.mm(pts.T) + t.unsqueeze(-1)).T 146 | if self.args.adaptive_pts_radius: 147 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 * pts[..., -1][None] / \ 148 | torch.clamp(pts_[..., -1][None], min=1e-6) 149 | else: 150 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 151 | 152 | if self.args.vary_pts_radius and np.random.choice([0, 1], p=[0.6, 0.4]): 153 | if type(radius) == torch.Tensor: 154 | factor = 1 + (0.2 * (torch.rand_like(radius) - 0.5)) 155 | else: 156 | factor = 1 + (0.2 * (np.random.rand() - 0.5)) 157 | radius *= factor 158 | 159 | if self.args.use_mask_for_decoding: 160 | rgb_feat = torch.cat([rgb_feat, mask], dim=-1) 161 | 162 | if self.args.use_depth_for_decoding: 163 | disp = normalize_0_1(1. 
/ torch.clamp(pts_[..., [-1]], min=1e-6)) 164 | rgb_feat = torch.cat([rgb_feat, disp], dim=-1) 165 | 166 | global_out_list = [] 167 | direct_color_out_list = [] 168 | meta = {} 169 | for j in range(num_sides): 170 | mask_side = side_ids == j 171 | renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3], 172 | radius=radius[:, mask_side] if type(radius) == torch.Tensor else radius) 173 | all_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgb_feat[mask_side]]) 174 | global_out_j = renderer(all_pcd_j) 175 | all_colored_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgbs[mask_side]]) 176 | direct_rgb_out_j = renderer(all_colored_pcd_j) 177 | 178 | global_out_list.append(global_out_j) 179 | direct_color_out_list.append(direct_rgb_out_j) 180 | 181 | w1, w2 = self.compute_weight_for_two_frame_blending(time, 182 | global_out_list[0][..., [-1]], 183 | global_out_list[-1][..., [-1]], 184 | global_out_list[0][..., [3]], 185 | global_out_list[-1][..., [3]] 186 | ) 187 | direct_rgb_out = w1 * direct_color_out_list[0] + w2 * direct_color_out_list[-1] 188 | pred_rgb = self.model.img_decoder(global_out_list[0].permute(0, 3, 1, 2), 189 | global_out_list[-1].permute(0, 3, 1, 2), 190 | time) 191 | 192 | direct_rgb = direct_rgb_out[..., :3].permute(0, 3, 1, 2) 193 | acc = 0.5 * (global_out_list[0][..., [3]] + global_out_list[1][..., [3]]).permute(0, 3, 1, 2) 194 | meta['acc'] = acc 195 | return pred_rgb, direct_rgb, meta 196 | 197 | 198 | def render_pcd_single(self, pts1, rgbs1, feats1, mask, side_ids, R=None, t=None, time=0): 199 | 200 | pts = pts1 201 | rgbs = rgbs1 202 | feats = feats1 203 | rgb_feat = torch.cat([rgbs, feats], dim=-1) 204 | 205 | num_sides = side_ids.max() + 1 206 | assert num_sides == 1 or num_sides == 2 207 | 208 | if R is None: 209 | R = torch.eye(3, device=self.device) 210 | if t is None: 211 | t = torch.zeros(3, device=self.device) 212 | 213 | pts_ = (R.mm(pts.T) + t.unsqueeze(-1)).T 214 | if self.args.adaptive_pts_radius: 215 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 * pts[..., -1][None] / \ 216 | torch.clamp(pts_[..., -1][None], min=1e-6) 217 | else: 218 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 219 | 220 | if self.args.vary_pts_radius and np.random.choice([0, 1], p=[0.6, 0.4]): 221 | if type(radius) == torch.Tensor: 222 | factor = 1 + (0.2 * (torch.rand_like(radius) - 0.5)) 223 | else: 224 | factor = 1 + (0.2 * (np.random.rand() - 0.5)) 225 | radius *= factor 226 | 227 | if self.args.use_mask_for_decoding: 228 | rgb_feat = torch.cat([rgb_feat, mask], dim=-1) 229 | 230 | if self.args.use_depth_for_decoding: 231 | disp = normalize_0_1(1. 
/ torch.clamp(pts_[..., [-1]], min=1e-6)) 232 | rgb_feat = torch.cat([rgb_feat, disp], dim=-1) 233 | 234 | global_out_list = [] 235 | direct_color_out_list = [] 236 | meta = {} 237 | for j in range(num_sides): 238 | mask_side = side_ids == j 239 | renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3], 240 | radius=radius[:, mask_side] if type(radius) == torch.Tensor else radius) 241 | all_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgb_feat[mask_side]]) 242 | global_out_j = renderer(all_pcd_j) 243 | all_colored_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgbs[mask_side]]) 244 | direct_rgb_out_j = renderer(all_colored_pcd_j) 245 | 246 | global_out_list.append(global_out_j) 247 | direct_color_out_list.append(direct_rgb_out_j) 248 | 249 | direct_rgb_out = direct_color_out_list[0] 250 | pred_rgb = self.model.img_decoder(global_out_list[0].permute(0, 3, 1, 2), 251 | global_out_list[0].permute(0, 3, 1, 2), 252 | time=0) 253 | 254 | direct_rgb = direct_rgb_out[..., :3].permute(0, 3, 1, 2) 255 | acc = (global_out_list[0][..., [3]]).permute(0, 3, 1, 2) 256 | meta['acc'] = acc 257 | return pred_rgb, direct_rgb, meta 258 | 259 | 260 | def get_reprojection_mask(self, pts, R, t): 261 | pts1_ = (R.mm(pts.T) + t.unsqueeze(-1)).T 262 | mask1 = torch.ones_like(self.src_img1[:, :1].reshape(-1, 1)) 263 | mask_renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3], 264 | radius=1.0 / min(self.h, self.w) * 4.) 265 | mask_pcd = Pointclouds(points=[pts1_], features=[mask1]) 266 | mask = mask_renderer(mask_pcd).permute(0, 3, 1, 2) 267 | mask = F.max_pool2d(mask, kernel_size=7, stride=1, padding=3) 268 | return mask 269 | 270 | def get_cropping_ids(self, mask): 271 | assert mask.shape[:2] == (1, 1) 272 | mask = mask.squeeze() 273 | h, w = mask.shape 274 | mask_mean_x_axis = mask.mean(dim=0) 275 | x_valid = torch.nonzero(mask_mean_x_axis > 0.5) 276 | bad = False 277 | if len(x_valid) < 0.75 * w: 278 | left, right = 0, w - 1 # invalid 279 | bad = True 280 | else: 281 | left, right = x_valid[0][0], x_valid[-1][0] 282 | mask_mean_y_axis = mask.mean(dim=1) 283 | y_valid = torch.nonzero(mask_mean_y_axis > 0.5) 284 | if len(y_valid) < 0.75 * h: 285 | top, bottom = 0, h - 1 # invalid 286 | bad = True 287 | else: 288 | top, bottom = y_valid[0][0], y_valid[-1][0] 289 | assert 0 <= top <= h - 1 and 0 <= bottom <= h - 1 and 0 <= left <= w - 1 and 0 <= right <= w - 1 290 | return top, bottom, left, right, bad 291 | 292 | def render_depth_from_mdi(self, depth_layers, alpha_layers): 293 | ''' 294 | :param depth_layers: [n_layers, 1, h, w] 295 | :param alpha_layers: [n_layers, 1, h, w] 296 | :return: rendered depth [1, 1, h, w] 297 | ''' 298 | num_layers = len(depth_layers) 299 | h, w = depth_layers.shape[-2:] 300 | layer_id = torch.arange(num_layers, device=self.device).float() 301 | layer_id_maps = layer_id[..., None, None, None, None].repeat(1, 1, 1, h, w) 302 | T = torch.cumprod(1. - alpha_layers, dim=0)[:-1] 303 | T = torch.cat([torch.ones_like(T[:1]), T], dim=0) 304 | weights = alpha_layers * T 305 | depth_map = torch.sum(weights * depth_layers, dim=0) 306 | depth_map = torch.clamp(depth_map, min=1.) 
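        # [Editorial comment] The block above is standard front-to-back "over"
        # compositing: T is the transmittance through the nearer layers
        # (cumulative product of 1 - alpha, shifted so the first layer sees T = 1),
        # weights = alpha * T, and the rendered depth is the weighted sum of the
        # layer depths. The same weights are reused just below to composite a
        # soft layer-index map.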
307 | layer_id_map = torch.sum(weights * layer_id_maps, dim=0) 308 | return depth_map, layer_id_map 309 | 310 | def render_rgbda_layers_from_one_view(self, return_pts=False): 311 | depth_bins = get_depth_bins(depth=self.src_depth1) 312 | rgba_layers, depth_layers, mask_layers = \ 313 | self.inpainter.sequential_inpainting(self.src_img1, self.src_depth1, depth_bins) 314 | coord1 = get_coord_grids_pt(self.h, self.w, device=self.device).float() 315 | src_depth1 = self.apply_scale_shift(self.src_depth1, self.scale_shift1[0], self.scale_shift1[1]) 316 | pts1 = unproject_pts_pt(self.intrinsic1, coord1.reshape(-1, 2), src_depth1.flatten()) 317 | 318 | featmaps = self.feature_extraction(rgba_layers, mask_layers, depth_layers) 319 | depth_layers = self.apply_scale_shift(depth_layers, self.scale_shift1[0], self.scale_shift1[1]) 320 | num_layers = len(rgba_layers) 321 | all_pts = [] 322 | all_rgbas = [] 323 | all_feats = [] 324 | all_masks = [] 325 | for i in range(num_layers): 326 | alpha_i = rgba_layers[i][:, -1] > 0.5 327 | rgba_i = rgba_layers[i] 328 | mask_i = mask_layers[i] 329 | featmap = featmaps[i][None] 330 | featmap = F.interpolate(featmap, size=(self.h, self.w), mode='bilinear', align_corners=True) 331 | pts1_i = unproject_pts_pt(self.intrinsic1, coord1.reshape(-1, 2), depth_layers[i].flatten()) 332 | pts1_i = pts1_i.reshape(1, self.h, self.w, 3) 333 | all_pts.append(pts1_i[alpha_i]) 334 | all_rgbas.append(rgba_i.permute(0, 2, 3, 1)[alpha_i]) 335 | all_feats.append(featmap.permute(0, 2, 3, 1)[alpha_i]) 336 | all_masks.append(mask_i.permute(0, 2, 3, 1)[alpha_i]) 337 | 338 | all_pts = torch.cat(all_pts) 339 | all_rgbas = torch.cat(all_rgbas) 340 | all_feats = torch.cat(all_feats) 341 | all_masks = torch.cat(all_masks) 342 | all_side_ids = torch.zeros_like(all_masks.squeeze(), dtype=torch.long) 343 | 344 | 345 | if return_pts: 346 | return all_pts, all_rgbas, all_feats, \ 347 | all_masks, all_side_ids 348 | 349 | R = self.tgt_pose[0, :3, :3] 350 | t = self.tgt_pose[0, :3, 3] 351 | 352 | pred_img, direct_rgb_out, meta = self.render_pcd(all_pts, all_pts, 353 | all_rgbas, all_rgbas, 354 | all_feats, all_feats, 355 | all_masks, all_side_ids, 356 | R, t, 0) 357 | 358 | mask = self.get_reprojection_mask(pts1, R, t) 359 | t, b, l, r, bad = self.get_cropping_ids(mask) 360 | skip = False 361 | if not skip and not self.args.eval_mode: 362 | pred_img = pred_img[:, :, t:b, l:r] 363 | mask = mask[:, :, t:b, l:r] 364 | direct_rgb_out = direct_rgb_out[:, :, t:b, l:r] 365 | else: 366 | skip = True 367 | 368 | res_dict = { 369 | 'src_img1': self.src_img1, 370 | 'pred_img': pred_img, 371 | 'mask': mask, 372 | 'direct_rgb_out': direct_rgb_out, 373 | 'skip': skip 374 | } 375 | return res_dict 376 | 377 | def compute_scene_flow_one_side(self, coord, pose, 378 | rgb1, rgb2, 379 | rgba_layers1, rgba_layers2, 380 | featmaps1, featmaps2, 381 | pts1, pts2, 382 | depth_layers1, depth_layers2, 383 | mask_layers1, mask_layers2, 384 | flow_f, flow_b, kernel, 385 | with_inpainted=False): 386 | 387 | num_layers1 = len(rgba_layers1) 388 | pts2 = transform_pts_in_3D(pts2, pose).T.reshape(1, 3, self.h, self.w) 389 | 390 | mask_mutual_flow = self.scene_flow_estimator.get_mutual_matches(flow_f, flow_b, th=5, return_mask=True).float() 391 | mask_mutual_flow = mask_mutual_flow.unsqueeze(1) 392 | 393 | coord1_corsp = coord + flow_f 394 | coord1_corsp_normed = normalize_for_grid_sample(coord1_corsp, self.h, self.w) 395 | pts2_sampled = F.grid_sample(pts2, coord1_corsp_normed, align_corners=True, 396 | mode='nearest', 
padding_mode="border") 397 | depth2_sampled = pts2_sampled[:, -1:] 398 | 399 | rgb2_sampled = F.grid_sample(rgb2, coord1_corsp_normed, align_corners=True, padding_mode="border") 400 | mask_layers2_ds = F.interpolate(mask_layers2.squeeze(1), size=featmaps2.shape[-2:], mode='area') 401 | featmap2 = torch.sum(featmaps2 * mask_layers2_ds, dim=0, keepdim=True) 402 | context2 = torch.sum(mask_layers2_ds, dim=0, keepdim=True) 403 | featmap2_sampled = F.grid_sample(featmap2, coord1_corsp_normed, align_corners=True, padding_mode="border") 404 | context2_sampled = F.grid_sample(context2, coord1_corsp_normed, align_corners=True, padding_mode="border") 405 | mask2_sampled = F.grid_sample(self.src_mask2, coord1_corsp_normed, align_corners=True, padding_mode="border") 406 | 407 | featmap2_sampled = featmap2_sampled / torch.clamp(context2_sampled, min=1e-6) 408 | context2_sampled = (context2_sampled > 0.5).float() 409 | last_pts2_i = torch.zeros_like(pts2.permute(0, 2, 3, 1)) 410 | last_alpha_i = torch.zeros_like(rgba_layers1[0][:, -1], dtype=torch.bool) 411 | 412 | all_pts = [] 413 | all_rgbas = [] 414 | all_feats = [] 415 | all_rgbas_end = [] 416 | all_feats_end = [] 417 | all_masks = [] 418 | all_pts_end = [] 419 | all_optical_flows = [] 420 | for i in range(num_layers1): 421 | alpha_i = (rgba_layers1[i][:, -1]*self.src_mask1.squeeze(1)*mask2_sampled.squeeze(1)) > 0.5 422 | rgba_i = rgba_layers1[i] 423 | mask_i = mask_layers1[i] 424 | mask_no_mutual_flow = mask_i * context2_sampled 425 | mask_gau_i = mask_no_mutual_flow * mask_mutual_flow 426 | mask_no_mutual_flow = erosion(mask_no_mutual_flow, kernel) 427 | mask_gau_i = erosion(mask_gau_i, kernel) 428 | 429 | featmap1 = featmaps1[i][None] 430 | featmap1 = F.interpolate(featmap1, size=(self.h, self.w), mode='bilinear', align_corners=True) 431 | pts1_i = unproject_pts_pt(self.intrinsic1, coord.reshape(-1, 2), depth_layers1[i].flatten()) 432 | pts1_i = pts1_i.reshape(1, self.h, self.w, 3) 433 | 434 | flow_inpainted, mask_no_mutual_flow_ = self.masked_diffuse(flow_f.permute(0, 3, 1, 2), 435 | mask_no_mutual_flow, 436 | kernel_size=15, iter=7) 437 | 438 | coord_inpainted = coord.clone() 439 | coord_inpainted_ = coord + flow_inpainted.permute(0, 2, 3, 1) 440 | mask_no_mutual_flow_bool = (mask_no_mutual_flow_ > 1e-6).squeeze(1) 441 | coord_inpainted[mask_no_mutual_flow_bool] = coord_inpainted_[mask_no_mutual_flow_bool] 442 | 443 | depth_inpainted = depth_layers1[i].clone() 444 | depth_inpainted_, mask_gau_i_ = self.masked_diffuse(depth2_sampled, mask_gau_i, 445 | kernel_size=15, iter=7) 446 | mask_gau_i_bool = (mask_gau_i_ > 1e-6).squeeze(1) 447 | depth_inpainted.squeeze(1)[mask_gau_i_bool] = depth_inpainted_.squeeze(1)[mask_gau_i_bool] 448 | pts2_i = unproject_pts_pt(self.intrinsic2, coord_inpainted.contiguous().reshape(-1, 2), 449 | depth_inpainted.flatten()).reshape(1, self.h, self.w, 3) 450 | 451 | if i > 0: 452 | mask_wrong_ordering = (pts2_i[..., -1] <= last_pts2_i[..., -1]) * last_alpha_i 453 | pts2_i[mask_wrong_ordering] = last_pts2_i[mask_wrong_ordering] * 1.01 454 | 455 | rgba_end = mask_gau_i * torch.cat([rgb2_sampled, mask_gau_i], dim=1) + (1 - mask_gau_i) * rgba_i 456 | feat_end = mask_gau_i * featmap2_sampled + (1 - mask_gau_i) * featmap1 457 | last_alpha_i[alpha_i] = True 458 | last_pts2_i[alpha_i] = pts2_i[alpha_i] 459 | 460 | if with_inpainted: 461 | mask_keep = alpha_i 462 | else: 463 | mask_keep = mask_i.squeeze(1).bool() 464 | 465 | all_pts.append(pts1_i[mask_keep]) 466 | all_rgbas.append(rgba_i.permute(0, 2, 3, 1)[mask_keep]) 467 | 
all_feats.append(featmap1.permute(0, 2, 3, 1)[mask_keep]) 468 | all_masks.append(mask_i.permute(0, 2, 3, 1)[mask_keep]) 469 | all_pts_end.append(pts2_i[mask_keep]) 470 | all_rgbas_end.append(rgba_end.permute(0, 2, 3, 1)[mask_keep]) 471 | all_feats_end.append(feat_end.permute(0, 2, 3, 1)[mask_keep]) 472 | all_optical_flows.append(flow_inpainted.permute(0, 2, 3, 1)[mask_keep]) 473 | 474 | return all_pts, all_pts_end, all_rgbas, all_rgbas_end, all_feats, all_feats_end, all_masks, all_optical_flows 475 | 476 | def render_rgbda_layers_with_scene_flow(self, return_pts=False): 477 | kernel = torch.ones(5, 5, device=self.device) 478 | flow_f = self.scene_flow_estimator.compute_optical_flow(self.src_img1, self.src_img2) 479 | flow_b = self.scene_flow_estimator.compute_optical_flow(self.src_img2, self.src_img1) 480 | 481 | depth_bins1 = get_depth_bins(depth=self.src_depth1) 482 | depth_bins2 = get_depth_bins(depth=self.src_depth2) 483 | 484 | rgba_layers1, depth_layers1, mask_layers1 = \ 485 | self.inpainter.sequential_inpainting(self.src_img1, self.src_depth1, depth_bins1) 486 | rgba_layers2, depth_layers2, mask_layers2 = \ 487 | self.inpainter.sequential_inpainting(self.src_img2, self.src_depth2, depth_bins2) 488 | if self.args.visualize_rgbda_layers: 489 | self.save_rgbda_layers(self.src_rgb_file1, rgba_layers1, depth_layers1, mask_layers1) 490 | self.save_rgbda_layers(self.src_rgb_file2, rgba_layers2, depth_layers2, mask_layers2) 491 | 492 | featmaps1 = self.feature_extraction(rgba_layers1, mask_layers1, depth_layers1) 493 | featmaps2 = self.feature_extraction(rgba_layers2, mask_layers2, depth_layers2) 494 | 495 | depth_layers1 = self.apply_scale_shift(depth_layers1, self.scale_shift1[0], self.scale_shift1[1]) 496 | depth_layers2 = self.apply_scale_shift(depth_layers2, self.scale_shift2[0], self.scale_shift2[1]) 497 | 498 | processed_depth1, layer_id_map1 = self.render_depth_from_mdi(depth_layers1, rgba_layers1[:, :, -1:]) 499 | processed_depth2, layer_id_map2 = self.render_depth_from_mdi(depth_layers2, rgba_layers2[:, :, -1:]) 500 | 501 | assert self.src_img1.shape[-2:] == self.src_img2.shape[-2:] 502 | h, w = self.src_img1.shape[-2:] 503 | coord = get_coord_grids_pt(h, w, device=self.device).float()[None] 504 | pts1 = unproject_pts_pt(self.intrinsic1, coord.reshape(-1, 2), processed_depth1.flatten()) 505 | pts2 = unproject_pts_pt(self.intrinsic2, coord.reshape(-1, 2), processed_depth2.flatten()) 506 | 507 | all_pts_11, all_pts_12, all_rgbas_11, all_rgbas_12, all_feats_11, all_feats_12,\ 508 | all_masks_1, all_optical_flow_1 = \ 509 | self.compute_scene_flow_one_side(coord, torch.inverse(self.pose), self.src_img1, self.src_img2, 510 | rgba_layers1, rgba_layers2, featmaps1, featmaps2, 511 | pts1, pts2, depth_layers1, depth_layers2, mask_layers1, mask_layers2, 512 | flow_f, flow_b, kernel, with_inpainted=True) 513 | 514 | all_pts_22, all_pts_21, all_rgbas_22, all_rgbas_21, all_feats_22, all_feats_21,\ 515 | all_masks_2, all_optical_flow_2 = \ 516 | self.compute_scene_flow_one_side(coord, self.pose, self.src_img2, self.src_img1, 517 | rgba_layers2, rgba_layers1, featmaps2, featmaps1, 518 | pts2, pts1, depth_layers2, depth_layers1, mask_layers2, mask_layers1, 519 | flow_b, flow_f, kernel, with_inpainted=True) 520 | 521 | if not torch.allclose(self.pose, torch.eye(4, device=self.device)): 522 | all_pts_21 = self.transform_all_pts(all_pts_21, torch.inverse(self.pose)) 523 | all_pts_22 = self.transform_all_pts(all_pts_22, torch.inverse(self.pose)) 524 | 525 | all_pts = torch.cat(all_pts_11+all_pts_21) 
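        # [Editorial comment] Points, colors and features from both source
        # frames are concatenated into a single cloud in this block;
        # all_side_ids (set a few lines below) records which frame each point
        # came from, so render_pcd can rasterize the two subsets separately and
        # blend them with the time-dependent weights from
        # compute_weight_for_two_frame_blending.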
526 | all_rgbas = torch.cat(all_rgbas_11+all_rgbas_21) 527 | all_feats = torch.cat(all_feats_11+all_feats_21) 528 | all_masks = torch.cat(all_masks_1+all_masks_2) 529 | all_pts_end = torch.cat(all_pts_12+all_pts_22) 530 | all_rgbas_end = torch.cat(all_rgbas_12+all_rgbas_22) 531 | all_feats_end = torch.cat(all_feats_12+all_feats_22) 532 | all_side_ids = torch.zeros_like(all_masks.squeeze(), dtype=torch.long) 533 | num_pts_2 = sum([len(x) for x in all_pts_21]) 534 | all_side_ids[-num_pts_2:] = 1 535 | all_optical_flow = torch.cat(all_optical_flow_1+all_optical_flow_2) 536 | 537 | if return_pts: 538 | return all_pts, all_pts_end, all_rgbas, all_rgbas_end, all_feats, all_feats_end, \ 539 | all_masks, all_side_ids, all_optical_flow 540 | 541 | R = self.tgt_pose[0, :3, :3] 542 | t = self.tgt_pose[0, :3, 3] 543 | pred_img, direct_rgb_out, meta = self.render_pcd(all_pts, all_pts_end, 544 | all_rgbas, all_rgbas_end, 545 | all_feats, all_feats_end, 546 | all_masks, all_side_ids, 547 | R, t, self.time) 548 | mask1 = self.get_reprojection_mask(pts1, R, t) 549 | pose2_to_tgt = self.tgt_pose.bmm(torch.inverse(self.pose)) 550 | mask2 = self.get_reprojection_mask(pts2, pose2_to_tgt[0, :3, :3], pose2_to_tgt[0, :3, 3]) 551 | mask = (mask1+mask2) * 0.5 552 | gt_img = self.tgt_img 553 | t, b, l, r, bad = self.get_cropping_ids(mask) 554 | skip = False 555 | if not skip and not self.args.eval_mode: 556 | pred_img = pred_img[:, :, t:b, l:r] 557 | mask = mask[:, :, t:b, l:r] 558 | direct_rgb_out = direct_rgb_out[:, :, t:b, l:r] 559 | gt_img = gt_img[:, :, t:b, l:r] 560 | else: 561 | skip = True 562 | 563 | res_dict = { 564 | 'src_img1': self.src_img1, 565 | 'src_img2': self.src_img2, 566 | 'pred_img': pred_img, 567 | 'gt_img': gt_img, 568 | 'mask': mask, 569 | 'direct_rgb_out': direct_rgb_out, 570 | 'alpha_layers1': rgba_layers1[:, :, [-1]], 571 | 'alpha_layers2': rgba_layers2[:, :, [-1]], 572 | 'mask_layers1': mask_layers1, 573 | 'mask_layers2': mask_layers2, 574 | 'skip': skip 575 | } 576 | return res_dict 577 | 578 | def dynamic_view_synthesis_with_inpainting(self): 579 | if self.is_multi_view: 580 | return self.render_rgbda_layers_from_one_view() 581 | else: 582 | return self.render_rgbda_layers_with_scene_flow() 583 | 584 | def get_prediction(self, data): 585 | # process data first 586 | self.process_data(data) 587 | return self.dynamic_view_synthesis_with_inpainting() 588 | 589 | def save_rgbda_layers(self, src_rgb_file, rgba_layers, depth_layers, mask_layers): 590 | frame_id = os.path.basename(src_rgb_file).split('.')[0] 591 | scene_id = src_rgb_file.split('/')[-3] 592 | out_dir = os.path.join(self.args.rootdir, 'out', self.args.expname, 'vis', 593 | '{}-{}'.format(scene_id, frame_id)) 594 | os.makedirs(out_dir, exist_ok=True) 595 | 596 | alpha_layers = rgba_layers[:, :, [-1]] 597 | for i, rgba_layer in enumerate(rgba_layers): 598 | save_filename = os.path.join(out_dir, 'rgb_original_{}.png'.format(i)) 599 | rgba_layer_ = rgba_layer * mask_layers[i] 600 | rgba_np = rgba_layer_.detach().squeeze().permute(1, 2, 0).cpu().numpy() 601 | imageio.imwrite(save_filename, float2uint8(rgba_np)) 602 | 603 | for i, rgba_layer in enumerate(rgba_layers): 604 | save_filename = os.path.join(out_dir, 'rgb_{}.png'.format(i)) 605 | rgba_np = rgba_layer.detach().squeeze().permute(1, 2, 0).cpu().numpy() 606 | imageio.imwrite(save_filename, float2uint8(rgba_np)) 607 | 608 | for i, depth_layer in enumerate(depth_layers): 609 | save_filename = os.path.join(out_dir, 'disparity_original_{}.png'.format(i)) 610 | disparity = (1. 
/ torch.clamp(depth_layer, min=1e-6)) * alpha_layers[i] 611 | disparity = torch.cat([disparity, disparity, disparity, alpha_layers[i]*mask_layers[i]], dim=1) 612 | disparity_np = disparity.detach().squeeze().cpu().numpy().transpose(1, 2, 0) 613 | imageio.imwrite(save_filename, float2uint8(disparity_np)) 614 | 615 | for i, depth_layer in enumerate(depth_layers): 616 | save_filename = os.path.join(out_dir, 'disparity_{}.png'.format(i)) 617 | disparity = (1. / torch.clamp(depth_layer, min=1e-6)) * alpha_layers[i] 618 | disparity = torch.cat([disparity, disparity, disparity, alpha_layers[i]], dim=1) 619 | disparity_np = disparity.detach().squeeze().cpu().numpy().transpose(1, 2, 0) 620 | imageio.imwrite(save_filename, float2uint8(disparity_np)) 621 | 622 | for i, mask_layer in enumerate(mask_layers): 623 | save_filename = os.path.join(out_dir, 'mask_{}.png'.format(i)) 624 | tri_mask = 0.5 * alpha_layers[i] + 0.5 * mask_layer 625 | tri_mask_np = tri_mask.detach().squeeze().cpu().numpy() 626 | imageio.imwrite(save_filename, float2uint8(tri_mask_np)) 627 | 628 | 629 | 630 | -------------------------------------------------------------------------------- /core/scene_flow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from third_party.RAFT.core.utils.utils import InputPadder 17 | from core.utils import * 18 | 19 | 20 | class SceneFlowEstimator(): 21 | def __init__(self, args, model): 22 | device = "cuda:{}".format(args.local_rank) 23 | self.device = device 24 | self.raft_model = model 25 | self.train_raft = args.train_raft 26 | 27 | def compute_optical_flow(self, img1, img2, return_np_array=False): 28 | ''' 29 | :param img1: [B, 3, H, W] 30 | :param img2: [B, 3, H, W] 31 | :return: optical_flow, [B, H, W, 2] 32 | ''' 33 | if not self.train_raft: 34 | with torch.no_grad(): 35 | assert img1.max() <= 1 and img2.max() <= 1 36 | image1 = img1 * 255. 37 | image2 = img2 * 255. 38 | padder = InputPadder(image1.shape) 39 | image1, image2 = padder.pad(image1, image2) 40 | 41 | flow_low, flow_up = self.raft_model.module(image1, image2, iters=20, test_mode=True, padder=padder) 42 | 43 | if return_np_array: 44 | return flow_up.cpu().numpy().transpose(0, 2, 3, 1) 45 | 46 | return flow_up.permute(0, 2, 3, 1).detach() # [B, h, w, 2] 47 | else: 48 | assert img1.max() <= 1 and img2.max() <= 1 49 | image1 = img1 * 255. 50 | image2 = img2 * 255. 
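            # [Editorial comment] As in the no-grad branch above: RAFT expects
            # images scaled to 0-255, and InputPadder pads H and W to multiples
            # of 8 so they match RAFT's 1/8-resolution feature maps; the
            # predicted flow is expressed in pixels.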
51 | padder = InputPadder(image1.shape) 52 | image1, image2 = padder.pad(image1, image2) 53 | flow_predictions = self.raft_model.module(image1, image2, iters=20, padder=padder) 54 | return flow_predictions[-1].permute(0, 2, 3, 1) # [B, h, w, 2] 55 | 56 | def get_mutual_matches(self, flow_f, flow_b, th=2., return_mask=False): 57 | assert flow_f.shape == flow_b.shape 58 | batch_size = flow_f.shape[0] 59 | assert flow_f.shape[1:3] == flow_b.shape[1:3] 60 | h, w = flow_f.shape[1:3] 61 | grid = get_coord_grids_pt(h, w, self.device)[None].float().repeat(batch_size, 1, 1, 1) # [B, h, w, 2] 62 | grid2 = grid + flow_f 63 | mask_boundary = (grid2[..., 0] >= 0) * (grid2[..., 0] <= w - 1) * \ 64 | (grid2[..., 1] >= 0) * (grid2[..., 1] <= h - 1) 65 | grid2_normed = normalize_for_grid_sample(grid2, h, w) 66 | flow_b_sampled = F.grid_sample(flow_b.permute(0, 3, 1, 2), grid2_normed, 67 | align_corners=True).permute(0, 2, 3, 1) 68 | grid1 = grid2 + flow_b_sampled 69 | mask_boundary *= (grid1[..., 0] >= 0) * (grid1[..., 0] <= w - 1) * \ 70 | (grid1[..., 1] >= 0) * (grid1[..., 1] <= h - 1) 71 | 72 | fb_map = flow_f + flow_b_sampled 73 | mask_valid = mask_boundary * (torch.norm(fb_map, dim=-1) < th) 74 | if return_mask: 75 | return mask_valid 76 | coords1 = grid[mask_valid] # [n, 2] 77 | coords2 = grid2[mask_valid] # [n, 2] 78 | return coords1, coords2 -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Tuple 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from kornia.filters import gaussian_blur2d, box_blur, median_blur 21 | from kornia.filters.kernels import get_binary_kernel2d 22 | from kornia.morphology import erosion 23 | from scipy.interpolate import interp1d 24 | from scipy.ndimage import median_filter 25 | 26 | 27 | def float2uint8(x): 28 | return (255. 
* x).astype(np.uint8) 29 | 30 | 31 | def float2uint16(x): 32 | return (65535 * x).astype(np.uint16) 33 | 34 | 35 | def normalize_0_1(x): 36 | x_min, x_max = x.min(), x.max() 37 | return (x - x_min) / (x_max - x_min) 38 | 39 | 40 | def homogenize_np(coord): 41 | """ 42 | append ones in the last dimension 43 | :param coord: [...., 2/3] 44 | :return: homogenous coordinates 45 | """ 46 | return np.concatenate([coord, np.ones_like(coord[..., :1])], axis=-1) 47 | 48 | 49 | def homogenize_pt(coord): 50 | return torch.cat([coord, torch.ones_like(coord[..., :1])], dim=-1) 51 | 52 | 53 | def get_coord_grids_pt(h, w, device, homogeneous=False): 54 | """ 55 | create pxiel coordinate grid 56 | :param h: height 57 | :param w: weight 58 | :param device: device 59 | :param homogeneous: if homogeneous coordinate 60 | :return: coordinates [h, w, 2] 61 | """ 62 | y = torch.arange(0, h).to(device) 63 | x = torch.arange(0, w).to(device) 64 | grid_y, grid_x = torch.meshgrid(y, x) 65 | if homogeneous: 66 | return torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1) 67 | return torch.stack([grid_x, grid_y], dim=-1) # [h, w, 2] 68 | 69 | 70 | def normalize_for_grid_sample(coords, h, w): 71 | device = coords.device 72 | coords_normed = coords / torch.tensor([w-1., h-1.]).to(device) * 2. - 1. 73 | return coords_normed 74 | 75 | 76 | def unproject_pts_np(intrinsics, coords, depth): 77 | if coords.shape[-1] == 2: 78 | coords = homogenize_np(coords) 79 | intrinsics = intrinsics.squeeze()[:3, :3] 80 | coords = np.linalg.inv(intrinsics).dot(coords.T) * depth.reshape(1, -1) 81 | return coords.T # [n, 3] 82 | 83 | 84 | def unproject_pts_pt(intrinsics, coords, depth, normalize_depth=False): 85 | if coords.shape[-1] == 2: 86 | coords = homogenize_pt(coords) 87 | intrinsics = intrinsics.squeeze()[:3, :3] 88 | coords = torch.inverse(intrinsics).mm(coords.T) 89 | if normalize_depth: 90 | coords = F.normalize(coords, dim=0, p=2.0) 91 | print("erer") 92 | coords = coords * depth.reshape(1, -1) 93 | return coords.T # [n, 3] 94 | 95 | 96 | def pixel2cam(depth, pixel_coords, intrinsics, is_homogeneous=True): 97 | """Transforms coordinates in the pixel frame to the camera frame. 
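    Each homogeneous pixel p = (u, v, 1)^T with depth d is back-projected as
    X_cam = d * K^{-1} p; if `is_homogeneous`, a fourth coordinate equal to 1 is appended.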
98 | Args: 99 | depth: [batch, height, width] 100 | pixel_coords: homogeneous pixel coordinates [batch, 3, height, width] 101 | intrinsics: camera intrinsics [batch, 3, 3] 102 | is_homogeneous: return in homogeneous coordinates 103 | Returns: 104 | Coords in the camera frame [batch, 3 (4 if homogeneous), height, width] 105 | """ 106 | if depth.ndim == 4: 107 | assert depth.shape[1] == 1 108 | depth = depth.squeeze(1) 109 | batch, height, width = depth.shape 110 | depth = depth.reshape(batch, 1, -1) 111 | pixel_coords = pixel_coords.reshape(batch, 3, -1) 112 | cam_coords = torch.inverse(intrinsics).bmm(pixel_coords) * depth 113 | if is_homogeneous: 114 | ones = torch.ones_like(depth) 115 | cam_coords = torch.cat([cam_coords, ones], dim=1) 116 | cam_coords = cam_coords.reshape(batch, -1, height, width) 117 | return cam_coords 118 | 119 | 120 | def transform_pts_in_3D(pts, pose, return_homogeneous=False): 121 | ''' 122 | :param pts: nx3, tensor 123 | :param pose: 4x4, tensor 124 | :return: nx3 or nx4, tensor 125 | ''' 126 | pts_h = homogenize_pt(pts) 127 | pose = pose.squeeze() 128 | assert pose.shape == (4, 4) 129 | transformed_pts_h = pose.mm(pts_h.T).T # [n, 4] 130 | if return_homogeneous: 131 | return transformed_pts_h 132 | return transformed_pts_h[..., :3] 133 | 134 | 135 | def crop_boundary(x, ratio): 136 | h, w = x.shape[-2:] 137 | crop_h = int(h * ratio) 138 | crop_w = int(w * ratio) 139 | return x[:, :, crop_h:h-crop_h, crop_w:w-crop_w] 140 | 141 | 142 | def masked_smooth_filter(x, mask, kernel_size=9, sigma=1): 143 | ''' 144 | :param x: [B, n, h, w] 145 | :param mask: [B, 1, h, w] 146 | :return: [B, n, h, w] 147 | ''' 148 | x_ = x * mask 149 | x_ = box_blur(x_, (kernel_size, kernel_size), border_type='constant') 150 | mask_ = box_blur(mask, (kernel_size, kernel_size), border_type='constant') 151 | x_ = x_ / torch.clamp(mask_, min=1e-6) 152 | mask_bool = (mask.repeat(1, x.shape[1], 1, 1) > 1e-6).float() 153 | out = mask_bool * x + (1. - mask_bool) * x_ 154 | return out, mask_ 155 | 156 | 157 | def remove_noise_in_dpt_disparity(disparity, kernel_size=5): 158 | return median_filter(disparity, size=kernel_size) 159 | 160 | 161 | def _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]: 162 | r"""Utility function that computes zero padding tuple.""" 163 | computed: List[int] = [(k - 1) // 2 for k in kernel_size] 164 | return computed[0], computed[1] 165 | 166 | 167 | def masked_median_blur(input, mask, kernel_size=9): 168 | assert input.shape == mask.shape 169 | if not isinstance(input, torch.Tensor): 170 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 171 | 172 | if not len(input.shape) == 4: 173 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. 
Got: {input.shape}") 174 | 175 | padding: Tuple[int, int] = _compute_zero_padding((kernel_size, kernel_size)) 176 | 177 | # prepare kernel 178 | kernel: torch.Tensor = get_binary_kernel2d((kernel_size, kernel_size)).to(input) 179 | b, c, h, w = input.shape 180 | 181 | # map the local window to single vector 182 | features: torch.Tensor = F.conv2d(input.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1) 183 | masks: torch.Tensor = F.conv2d(mask.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1) 184 | features = features.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2) # BxCxxHxWx(K_h * K_w) 185 | min_value, max_value = features.min(), features.max() 186 | masks = masks.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2) # BxCxHxWx(K_h * K_w) 187 | index_invalid = (1 - masks).nonzero(as_tuple=True) 188 | index_b, index_c, index_h, index_w, index_k = index_invalid 189 | features[(index_b[::2], index_c[::2], index_h[::2], index_w[::2], index_k[::2])] = min_value 190 | features[(index_b[1::2], index_c[1::2], index_h[1::2], index_w[1::2], index_k[1::2])] = max_value 191 | # compute the median along the feature axis 192 | median: torch.Tensor = torch.median(features, dim=-1)[0] 193 | 194 | return median 195 | 196 | 197 | def define_camera_path(num_frames, x, y, z, path_type='circle', return_t_only=False): 198 | generic_pose = np.eye(4) 199 | tgt_poses = [] 200 | if path_type == 'straight-line': 201 | corner_points = np.array([[0, 0, 0], [(0 + x) * 0.5, (0 + y) * 0.5, (0 + z) * 0.5], [x, y, z]]) 202 | corner_t = np.linspace(0, 1, len(corner_points)) 203 | t = np.linspace(0, 1, num_frames) 204 | cs = interp1d(corner_t, corner_points, axis=0, kind='quadratic') 205 | spline = cs(t) 206 | xs, ys, zs = [xx.squeeze() for xx in np.split(spline, 3, 1)] 207 | elif path_type == 'double-straight-line': 208 | corner_points = np.array([[-x, -y, -z], [0, 0, 0], [x, y, z]]) 209 | corner_t = np.linspace(0, 1, len(corner_points)) 210 | t = np.linspace(0, 1, num_frames) 211 | cs = interp1d(corner_t, corner_points, axis=0, kind='quadratic') 212 | spline = cs(t) 213 | xs, ys, zs = [xx.squeeze() for xx in np.split(spline, 3, 1)] 214 | elif path_type == 'circle': 215 | xs, ys, zs = [], [], [] 216 | for frame_id, bs_shift_val in enumerate(np.arange(-2.0, 2.0, (4./num_frames))): 217 | xs += [np.cos(bs_shift_val * np.pi) * 1 * x] 218 | ys += [np.sin(bs_shift_val * np.pi) * 1 * y] 219 | zs += [np.cos(bs_shift_val * np.pi/2.) * 1 * z] 220 | xs, ys, zs = np.array(xs), np.array(ys), np.array(zs) 221 | elif path_type == 'debug': 222 | xs = np.array([x, 0, -x, 0, 0]) 223 | ys = np.array([0, y, 0, -y, 0]) 224 | zs = np.array([0, 0, 0, 0, z]) 225 | else: 226 | raise NotImplementedError 227 | 228 | xs, ys, zs = np.array(xs), np.array(ys), np.array(zs) 229 | if return_t_only: 230 | return np.stack([xs, ys, zs], axis=1) # [n, 3] 231 | for xx, yy, zz in zip(xs, ys, zs): 232 | tgt_poses.append(generic_pose * 1.) 233 | tgt_poses[-1][:3, -1] = np.array([xx, yy, zz]) 234 | return tgt_poses -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # # download DPT (https://github.com/isl-org/DPT) pretrained weights into DPT/weights 17 | # wget https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt -P third_party/DPT/weights 18 | 19 | # # download RAFT (https://github.com/princeton-vl/RAFT) pretrained weights into RAFT/models/ 20 | # wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 21 | # unzip models.zip -d third_party/RAFT/ 22 | # rm -rf models.zip 23 | 24 | # download rgbd inpainting pretrained weights into inpainting_ckpts/ 25 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/color-model.pth 26 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/depth-model.pth 27 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/edge-model.pth 28 | mkdir inpainting_ckpts/ 29 | mv color-model.pth inpainting_ckpts/ 30 | mv depth-model.pth inpainting_ckpts/ 31 | mv edge-model.pth inpainting_ckpts/ 32 | 33 | # download the 3D moments pretrained model: 34 | gdown https://drive.google.com/uc?id=1keqdnl2roBO2MjXhbd0VbfYaGAhyjUed -O pretrained/ -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: reshader 2 | channels: 3 | - pytorch3d 4 | - pytorch 5 | - iopath 6 | - fvcore 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=4.5=1_gnu 12 | - blas=1.0=mkl 13 | - ca-certificates=2021.10.8=ha878542_0 14 | - certifi=2021.10.8=py37h89c1867_0 15 | - colorama=0.4.4=pyh9f0ad1d_0 16 | - cudatoolkit=10.1.243=h6bb024c_0 17 | - freetype=2.10.4=h5ab3b9f_0 18 | - fvcore=0.1.5.post20210915=py37 19 | - intel-openmp=2021.3.0=h06a4308_3350 20 | - iopath=0.1.9=py37 21 | - jpeg=9b=h024ee3a_2 22 | - lcms2=2.12=h3be6417_0 23 | - ld_impl_linux-64=2.35.1=h7274673_9 24 | - libffi=3.3=he6710b0_2 25 | - libgcc-ng=9.3.0=h5101ec6_17 26 | - libgomp=9.3.0=h5101ec6_17 27 | - libpng=1.6.37=hbc83047_0 28 | - libstdcxx-ng=9.3.0=hd4cf53a_17 29 | - libtiff=4.2.0=h85742a9_0 30 | - libuv=1.40.0=h7b6447c_0 31 | - libwebp-base=1.2.0=h27cfd23_0 32 | - lz4-c=1.9.3=h295c915_1 33 | - mkl=2021.3.0=h06a4308_520 34 | - mkl-service=2.4.0=py37h7f8727e_0 35 | - mkl_fft=1.3.0=py37h42c9631_2 36 | - mkl_random=1.2.2=py37h51133e4_0 37 | - ncurses=6.2=he6710b0_1 38 | - ninja=1.10.2=hff7bd54_1 39 | - numpy=1.20.3=py37hf144106_0 40 | - numpy-base=1.20.3=py37h74d4b33_0 41 | - olefile=0.46=py37_0 42 | - openjpeg=2.4.0=h3ad879b_0 43 | - openssl=1.1.1l=h7f8727e_0 44 | - pillow=8.3.1=py37h2c7a002_0 45 | - portalocker=2.3.2=py37h89c1867_0 46 | - python=3.7.11=h12debd9_0 47 | - python_abi=3.7=2_cp37m 48 | - pytorch=1.7.1=py3.7_cuda10.1.243_cudnn7.6.3_0 49 | - pytorch3d=0.6.0=py37_cu101_pyt171 50 | - pyyaml=5.4.1=py37h5e8e339_0 51 | - readline=8.1=h27cfd23_0 52 | - six=1.16.0=pyhd3eb1b0_0 53 | - sqlite=3.36.0=hc218d9a_0 54 | - tabulate=0.8.9=pyhd8ed1ab_0 55 | - termcolor=1.1.0=py_2 56 | - tk=8.6.11=h1ccaba5_0 57 | - torchaudio=0.7.2=py37 58 | - torchvision=0.8.2=py37_cu101 59 | - 
tqdm=4.62.3=pyhd8ed1ab_0 60 | - typing_extensions=3.10.0.2=pyh06a4308_0 61 | - wheel=0.37.0=pyhd3eb1b0_1 62 | - xz=5.2.5=h7b6447c_0 63 | - yacs=0.1.6=py_0 64 | - yaml=0.2.5=h516909a_0 65 | - zlib=1.2.11=h7b6447c_3 66 | - zstd=1.4.9=haebb681_0 67 | - pip: 68 | - absl-py==0.15.0 69 | - beautifulsoup4==4.10.0 70 | - cached-property==1.5.2 71 | - cachetools==4.2.4 72 | - charset-normalizer==2.0.7 73 | - configargparse==1.5.3 74 | - cupy-cuda101==9.6.0 75 | - cycler==0.10.0 76 | - dill==0.3.4 77 | - fastrlock==0.8 78 | - filelock==3.3.0 79 | - gdown==4.4.0 80 | - google-auth==2.3.0 81 | - google-auth-oauthlib==0.4.6 82 | - grpcio==1.41.0 83 | - h5py==3.5.0 84 | - idna==3.2 85 | - imageio==2.9.0 86 | - imageio-ffmpeg==0.4.5 87 | - importlib-metadata==4.8.1 88 | - joblib==1.1.0 89 | - kiwisolver==1.3.2 90 | - kornia==0.5.11 91 | - lpips==0.1.4 92 | - markdown==3.3.4 93 | - matplotlib==3.4.3 94 | - networkx==2.6.3 95 | - oauthlib==3.1.1 96 | - opencv-python==4.5.3.56 97 | - packaging==21.0 98 | - pandas==1.3.4 99 | - pip==22.1.2 100 | - plotly==5.3.1 101 | - protobuf==3.18.1 102 | - pyasn1==0.4.8 103 | - pyasn1-modules==0.2.8 104 | - pyparsing==2.4.7 105 | - pyquaternion==0.9.9 106 | - pysocks==1.7.1 107 | - python-dateutil==2.8.2 108 | - pytz==2021.3 109 | - pywavelets==1.1.1 110 | - requests==2.26.0 111 | - requests-oauthlib==1.3.0 112 | - rsa==4.7.2 113 | - scikit-image==0.18.3 114 | - scikit-learn==1.0 115 | - scipy==1.7.1 116 | - setuptools==62.4.0 117 | - sklearn==0.0 118 | - soupsieve==2.2.1 119 | - tenacity==8.0.1 120 | - tensorboard==2.7.0 121 | - tensorboard-data-server==0.6.1 122 | - tensorboard-plugin-wit==1.8.0 123 | - tensorboardx==2.4 124 | - threadpoolctl==3.0.0 125 | - tifffile==2021.10.12 126 | - timm==0.4.5 127 | - urllib3==1.26.7 128 | - werkzeug==2.0.2 129 | - zipp==3.6.0 130 | prefix: /phoenix/S7/qw246/anaconda3/envs/3d_moments 131 | -------------------------------------------------------------------------------- /examples/bottle/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/bottle/disp.npy -------------------------------------------------------------------------------- /examples/bottle/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/bottle/input.jpg -------------------------------------------------------------------------------- /examples/burger/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/burger/disp.npy -------------------------------------------------------------------------------- /examples/burger/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/burger/input.jpg -------------------------------------------------------------------------------- /examples/camera/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/camera/disp.npy -------------------------------------------------------------------------------- 
/examples/camera/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/camera/input.jpg -------------------------------------------------------------------------------- /examples/car/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/car/disp.npy -------------------------------------------------------------------------------- /examples/car/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/car/input.jpg -------------------------------------------------------------------------------- /examples/fireworks/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/fireworks/disp.npy -------------------------------------------------------------------------------- /examples/fireworks/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/fireworks/input.jpg -------------------------------------------------------------------------------- /examples/rook/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/rook/disp.npy -------------------------------------------------------------------------------- /examples/rook/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/rook/input.jpg -------------------------------------------------------------------------------- /examples/spoon/disp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/spoon/disp.npy -------------------------------------------------------------------------------- /examples/spoon/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/spoon/input.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class down(nn.Module): 7 | """ 8 | A class for creating neural network blocks containing layers: 9 | 10 | Average Pooling --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU 11 | 12 | This is used in the UNet Class to create a UNet like NN architecture. 13 | 14 | ... 15 | 16 | Methods 17 | ------- 18 | forward(x) 19 | Returns output tensor after passing input `x` to the neural network 20 | block. 
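    Each block halves the spatial resolution with 2 x 2 average pooling before the two
    convolutions, so stacking several of them produces the coarse bottleneck of the UNet below.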
21 | """ 22 | 23 | 24 | def __init__(self, inChannels, outChannels, filterSize): 25 | """ 26 | Parameters 27 | ---------- 28 | inChannels : int 29 | number of input channels for the first convolutional layer. 30 | outChannels : int 31 | number of output channels for the first convolutional layer. 32 | This is also used as input and output channels for the 33 | second convolutional layer. 34 | filterSize : int 35 | filter size for the convolution filter. input N would create 36 | a N x N filter. 37 | """ 38 | 39 | 40 | super(down, self).__init__() 41 | # Initialize convolutional layers. 42 | self.conv1 = nn.Conv2d(inChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) 43 | self.conv2 = nn.Conv2d(outChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) 44 | 45 | def forward(self, x): 46 | """ 47 | Returns output tensor after passing input `x` to the neural network 48 | block. 49 | 50 | Parameters 51 | ---------- 52 | x : tensor 53 | input to the NN block. 54 | 55 | Returns 56 | ------- 57 | tensor 58 | output of the NN block. 59 | """ 60 | 61 | 62 | # Average pooling with kernel size 2 (2 x 2). 63 | x = F.avg_pool2d(x, 2) 64 | # Convolution + Leaky ReLU 65 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 66 | # Convolution + Leaky ReLU 67 | x = F.leaky_relu(self.conv2(x), negative_slope = 0.1) 68 | return x 69 | 70 | class up(nn.Module): 71 | """ 72 | A class for creating neural network blocks containing layers: 73 | 74 | Bilinear interpolation --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU 75 | 76 | This is used in the UNet Class to create a UNet like NN architecture. 77 | 78 | ... 79 | 80 | Methods 81 | ------- 82 | forward(x, skpCn) 83 | Returns output tensor after passing input `x` to the neural network 84 | block. 85 | """ 86 | 87 | 88 | def __init__(self, inChannels, outChannels, vChannels=0): 89 | """ 90 | Parameters 91 | ---------- 92 | inChannels : int 93 | number of input channels for the first convolutional layer. 94 | outChannels : int 95 | number of output channels for the first convolutional layer. 96 | This is also used for setting input and output channels for 97 | the second convolutional layer. 98 | """ 99 | 100 | 101 | super(up, self).__init__() 102 | self.vChannels = vChannels 103 | # Initialize convolutional layers. 104 | self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1) 105 | # (2 * outChannels) is used for accommodating skip connection. 106 | self.conv2 = nn.Conv2d(2 * outChannels+vChannels, outChannels, 3, stride=1, padding=1) 107 | 108 | def forward(self, x, skpCn, v=None): 109 | """ 110 | Returns output tensor after passing input `x` to the neural network 111 | block. 112 | 113 | Parameters 114 | ---------- 115 | x : tensor 116 | input to the NN block. 117 | skpCn : tensor 118 | skip connection input to the NN block. 119 | 120 | Returns 121 | ------- 122 | tensor 123 | output of the NN block. 124 | """ 125 | 126 | # Bilinear interpolation with scaling 2. 
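        # Upsample by 2x, then fuse with the encoder skip connection; when the conditioning
        # vector `v` (e.g. an encoded view direction) is given, it is broadcast to the spatial
        # size and concatenated before the second convolution.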
127 | x = F.interpolate(x, scale_factor=2, mode='bilinear') 128 | # Convolution + Leaky ReLU 129 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 130 | # Convolution + Leaky ReLU on (`x`, `skpCn`) 131 | if v is None: 132 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1) 133 | else: 134 | n, _, h, w = x.shape 135 | v = v.expand(n, self.vChannels, h, w) 136 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn, v), 1)), negative_slope = 0.1) 137 | return x 138 | 139 | 140 | 141 | class UNet(nn.Module): 142 | """ 143 | A class for creating UNet like architecture as specified by the 144 | Super SloMo paper. 145 | 146 | ... 147 | 148 | Methods 149 | ------- 150 | forward(x) 151 | Returns output tensor after passing input `x` to the neural network 152 | block. 153 | """ 154 | 155 | 156 | def __init__(self, inChannels, outChannels, vChannels=128): 157 | """ 158 | Parameters 159 | ---------- 160 | inChannels : int 161 | number of input channels for the UNet. 162 | outChannels : int 163 | number of output channels for the UNet. 164 | """ 165 | 166 | 167 | super(UNet, self).__init__() 168 | # Initialize neural network blocks. 169 | self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3) 170 | self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3) 171 | self.down1 = down(32, 64, 5) 172 | self.down2 = down(64, 128, 3) 173 | self.down3 = down(128, 256, 3) 174 | self.down4 = down(256, 512, 3) 175 | self.down5 = down(512, 512, 3) 176 | self.up1 = up(512, 512, vChannels) 177 | self.up2 = up(512, 256) 178 | self.up3 = up(256, 128) 179 | self.up4 = up(128, 64) 180 | self.up5 = up(64, 32) 181 | self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1) 182 | 183 | def forward(self, x, v): 184 | """ 185 | Returns output tensor after passing input `x` to the neural network. 186 | 187 | Parameters 188 | ---------- 189 | x : tensor 190 | input to the UNet. 191 | 192 | Returns 193 | ------- 194 | tensor 195 | output of the UNet. 196 | """ 197 | inp = x 198 | v = v[:, :, None, None] 199 | 200 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 201 | s1 = F.leaky_relu(self.conv2(x), negative_slope = 0.1) 202 | s2 = self.down1(s1) 203 | s3 = self.down2(s2) 204 | s4 = self.down3(s3) 205 | s5 = self.down4(s4) 206 | x = self.down5(s5) 207 | x = self.up1(x, s5, v) 208 | x = self.up2(x, s4) 209 | x = self.up3(x, s3) 210 | x = self.up4(x, s2) 211 | x = self.up5(x, s1) 212 | x = torch.tanh(self.conv3(x)) 213 | 214 | x = x + inp[:, :3] 215 | 216 | return x 217 | 218 | class FCN(nn.Module): 219 | 220 | def __init__(self): 221 | 222 | super(FCN, self).__init__() 223 | 224 | self.model = nn.Sequential( 225 | nn.Linear(3, 8), 226 | nn.LeakyReLU(), 227 | nn.Linear(8, 32), 228 | nn.LeakyReLU(), 229 | nn.Linear(32, 125), 230 | nn.LeakyReLU(), 231 | nn.Linear(125, 125) 232 | ) 233 | 234 | def forward(self, inp): 235 | 236 | return torch.cat((self.model(inp), inp), 1) -------------------------------------------------------------------------------- /model_3dm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import torch 18 | from networks.resunet import ResUNet 19 | from networks.img_decoder import ImgDecoder 20 | from utils import de_parallel 21 | 22 | 23 | class Namespace: 24 | 25 | def __init__(self, **kwargs): 26 | for name in kwargs: 27 | setattr(self, name, kwargs[name]) 28 | 29 | def __eq__(self, other): 30 | if not isinstance(other, Namespace): 31 | return NotImplemented 32 | return vars(self) == vars(other) 33 | 34 | def __contains__(self, key): 35 | return key in self.__dict__ 36 | 37 | 38 | 39 | ######################################################################################################################## 40 | # creation/saving/loading of the model 41 | ######################################################################################################################## 42 | 43 | 44 | class SpaceTimeModel(object): 45 | def __init__(self, args): 46 | self.args = args 47 | load_opt = not args.no_load_opt 48 | load_scheduler = not args.no_load_scheduler 49 | device = torch.device('cuda:{}'.format(args.local_rank)) 50 | 51 | # initialize feature extraction network 52 | feat_in_ch = 4 53 | if args.use_inpainting_mask_for_feature: 54 | feat_in_ch += 1 55 | if args.use_depth_for_feature: 56 | feat_in_ch += 1 57 | self.feature_net = ResUNet(args, in_ch=feat_in_ch, out_ch=args.feature_dim).to(device) 58 | # initialize decoder 59 | decoder_in_ch = args.feature_dim + 4 60 | decoder_out_ch = 3 61 | 62 | if args.use_depth_for_decoding: 63 | decoder_in_ch += 1 64 | if args.use_mask_for_decoding: 65 | decoder_in_ch += 1 66 | 67 | self.img_decoder = ImgDecoder(args, in_ch=decoder_in_ch, out_ch=decoder_out_ch).to(device) 68 | 69 | learnable_params = list(self.feature_net.parameters()) 70 | learnable_params += list(self.img_decoder.parameters()) 71 | 72 | self.learnable_params = learnable_params 73 | self.optimizer = torch.optim.Adam(learnable_params, lr=args.lr, weight_decay=1e-4, betas=(0.9, 0.999)) 74 | 75 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 76 | step_size=args.lrate_decay_steps, 77 | gamma=args.lrate_decay_factor) 78 | 79 | out_folder = os.path.join(args.rootdir, 'out', args.expname) 80 | self.start_step = self.load_from_ckpt(out_folder, 81 | load_opt=load_opt, 82 | load_scheduler=load_scheduler) 83 | 84 | if args.distributed: 85 | self.feature_net = torch.nn.parallel.DistributedDataParallel( 86 | self.feature_net, 87 | device_ids=[args.local_rank], 88 | output_device=args.local_rank, 89 | ) 90 | 91 | self.img_decoder = torch.nn.parallel.DistributedDataParallel( 92 | self.img_decoder, 93 | device_ids=[args.local_rank], 94 | output_device=args.local_rank, 95 | ) 96 | 97 | def switch_to_eval(self): 98 | self.feature_net.eval() 99 | self.img_decoder.eval() 100 | 101 | def switch_to_train(self): 102 | self.feature_net.train() 103 | self.img_decoder.train() 104 | 105 | def save_model(self, filename): 106 | to_save = {'optimizer': self.optimizer.state_dict(), 107 | 'scheduler': self.scheduler.state_dict(), 108 | 'feature_net': de_parallel(self.feature_net).state_dict(), 109 | 'img_decoder': 
de_parallel(self.img_decoder).state_dict() 110 | } 111 | torch.save(to_save, filename) 112 | 113 | def load_model(self, filename, load_opt=True, load_scheduler=True): 114 | if self.args.distributed: 115 | to_load = torch.load(filename, map_location='cuda:{}'.format(self.args.local_rank)) 116 | else: 117 | to_load = torch.load(filename) 118 | 119 | if load_opt: 120 | self.optimizer.load_state_dict(to_load['optimizer']) 121 | if load_scheduler: 122 | self.scheduler.load_state_dict(to_load['scheduler']) 123 | 124 | self.feature_net.load_state_dict(to_load['feature_net']) 125 | self.img_decoder.load_state_dict(to_load['img_decoder']) 126 | 127 | def load_from_ckpt(self, out_folder, 128 | load_opt=True, 129 | load_scheduler=True, 130 | force_latest_ckpt=False): 131 | ''' 132 | load model from existing checkpoints and return the current step 133 | :param out_folder: the directory that stores ckpts 134 | :return: the current starting step 135 | ''' 136 | 137 | # all existing ckpts 138 | ckpts = [] 139 | if os.path.exists(out_folder): 140 | ckpts = [os.path.join(out_folder, f) 141 | for f in sorted(os.listdir(out_folder)) if f.endswith('.pth')] 142 | 143 | if self.args.ckpt_path is not None and not force_latest_ckpt: 144 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt 145 | ckpts = [self.args.ckpt_path] 146 | 147 | if len(ckpts) > 0 and not self.args.no_reload: 148 | fpath = ckpts[-1] 149 | self.load_model(fpath, load_opt, load_scheduler) 150 | step = int(fpath[-10:-4]) 151 | print('Reloading from {}, starting at step={}'.format(fpath, step)) 152 | else: 153 | print('No ckpts found, training from scratch...') 154 | step = 0 155 | 156 | return step 157 | 158 | 159 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/networks/__init__.py -------------------------------------------------------------------------------- /networks/img_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
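# A rough sketch of what ImgDecoder.forward does before running the U-Net defined below,
# assuming x0 and x1 are [B, C, H, W] stacks with alpha at channel 3 and disparity as the
# last channel:
#
#   w0 = (1 - time) * exp(self.alpha * disp0) * alpha0
#   w1 = time       * exp(self.alpha * disp1) * alpha1
#   x  = (w0 * x0 + w1 * x1) / clamp(w0 + w1, min=1e-6)   # soft blending in disparity space
#
# i.e. the two warped inputs are blended with time- and disparity-dependent weights and the
# blended stack is then decoded into the target image.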
14 | 15 | 16 | """ Parts of the U-Net model """ 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | 23 | class DoubleConv(nn.Module): 24 | """(convolution => [GN] => PReLU) * 2""" 25 | 26 | def __init__(self, in_channels, out_channels, mid_channels=None, group=None): 27 | super().__init__() 28 | if not mid_channels: 29 | mid_channels = out_channels 30 | #if not group: 31 | # group = out_channels//16 32 | self.double_conv1 = nn.Sequential( 33 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode='reflect'), 34 | nn.GroupNorm(mid_channels//32, mid_channels), 35 | nn.PReLU() 36 | ) 37 | self.double_conv2 = nn.Sequential( 38 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, padding_mode='reflect'), 39 | nn.GroupNorm(out_channels//32, out_channels), 40 | nn.PReLU() 41 | ) 42 | 43 | def forward(self, x): 44 | x1 = self.double_conv1(x) 45 | x2 = self.double_conv2(x1) 46 | return x1, x2 47 | 48 | 49 | class Down(nn.Module): 50 | """Downscaling with maxpool then double conv""" 51 | 52 | def __init__(self, in_channels, out_channels, group=None): 53 | super().__init__() 54 | self.maxpool_conv = nn.Sequential( 55 | nn.MaxPool2d(2), 56 | DoubleConv(in_channels, out_channels, group) 57 | ) 58 | 59 | def forward(self, x): 60 | return self.maxpool_conv(x) 61 | 62 | 63 | class ConcatDoubleConv(nn.Module): 64 | """(convolution => [GN] => PReLU) * 2""" 65 | 66 | def __init__(self, in_channels, out_channels, mid_channels=None, group=None): 67 | super().__init__() 68 | if not mid_channels: 69 | mid_channels = out_channels 70 | #if not group: 71 | # group = out_channels//16 72 | self.double_conv1 = nn.Sequential( 73 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode='reflect'), 74 | nn.GroupNorm(mid_channels//32, mid_channels), 75 | nn.PReLU() 76 | ) 77 | self.double_conv2 = nn.Sequential( 78 | nn.Conv2d(mid_channels*2, out_channels, kernel_size=3, padding=1, padding_mode='reflect'), 79 | nn.GroupNorm(out_channels//32, out_channels), 80 | nn.PReLU() 81 | ) 82 | 83 | def forward(self, x, xc1): 84 | x1 = self.double_conv1(x) 85 | x2 = self.double_conv2(torch.cat([xc1, x1], dim=1)) 86 | return x2 87 | 88 | 89 | class Up(nn.Module): 90 | """Upscaling then double conv""" 91 | def __init__(self, in_channels, out_channels, mid_channels, bilinear=True): 92 | super().__init__() 93 | 94 | # if bilinear, use the normal convolutions to reduce the number of channels 95 | if bilinear: 96 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 97 | self.conv = ConcatDoubleConv(in_channels, out_channels, mid_channels) 98 | else: 99 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 100 | self.conv = ConcatDoubleConv(in_channels, out_channels, mid_channels) 101 | 102 | def forward(self, x, xc1, xc2): 103 | x1 = self.up(x) 104 | # input is CHW 105 | diffY = xc1.size()[2] - x1.size()[2] 106 | diffX = xc1.size()[3] - x1.size()[3] 107 | 108 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 109 | diffY // 2, diffY - diffY // 2], mode='reflect') 110 | # if you have padding issues, see 111 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 112 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 113 | x = torch.cat([xc1, x1], dim=1) 114 | return self.conv(x, xc2) 115 | 116 | 117 | class OutConv(nn.Module): 118 | def __init__(self, in_channels, out_channels): 
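        # Final prediction head: upsample 2x back to the stem resolution, concatenate the
        # stem features, apply a reflection-padded 3x3 convolution, and map the output to
        # [0, 1] with a sigmoid.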
119 | super(OutConv, self).__init__() 120 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 121 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect') 122 | #self.gn = nn.GroupNorm(1, out_channels) 123 | self.act = nn.Sigmoid() 124 | 125 | def forward(self, x, xc1): 126 | x1 = self.up(x) 127 | x2 = torch.cat([xc1, x1], dim=1) 128 | #return self.act(self.gn(self.conv(x2))) 129 | return self.act(self.conv(x2)) 130 | 131 | 132 | class ImgDecoder(nn.Module): 133 | def __init__(self, args, in_ch, out_ch): 134 | super(ImgDecoder, self).__init__() 135 | self.firstconv = nn.Conv2d(in_ch, 32, kernel_size=7, stride=1, padding=3, bias=False) #256x256 136 | self.firstprelu = nn.PReLU() 137 | self.down1 = Down(32, 64) #128x128 138 | self.down2 = Down(64, 128) #64x64 139 | self.down3 = Down(128, 256) #32x32 140 | self.down4 = Down(256, 512) #16x16 141 | self.conv5 = DoubleConv(512, 512) #16x16 142 | 143 | self.up1 = Up(1024, 256, 512) 144 | self.up2 = Up(512, 128, 256) 145 | self.up3 = Up(256, 64, 128) 146 | self.up4 = Up(128, 32, 64) 147 | self.outc = OutConv(64, out_ch) 148 | self.alpha = nn.Parameter(torch.tensor(1.)) 149 | 150 | def compute_weight_for_two_frame_blending(self, time, disp1, disp2, alpha0, alpha1): 151 | weight1 = (1 - time) * torch.exp(self.alpha*disp1) * alpha0 152 | weight2 = time * torch.exp(self.alpha*disp2) * alpha1 153 | sum_weight = torch.clamp(weight1 + weight2, min=1e-6) 154 | out_weight1 = weight1 / sum_weight 155 | out_weight2 = weight2 / sum_weight 156 | return out_weight1, out_weight2 157 | 158 | def forward(self, x0, x1, time): 159 | disp0 = x0[:, [-1]] 160 | disp1 = x1[:, [-1]] 161 | alpha0 = x0[:, [3]] 162 | alpha1 = x1[:, [3]] 163 | w0, w1 = self.compute_weight_for_two_frame_blending(time, disp0, disp1, alpha0, alpha1) 164 | x = w0 * x0 + w1 * x1 165 | 166 | x0 = self.firstprelu(self.firstconv(x)) 167 | x20, x21 = self.down1(x0) 168 | x30, x31 = self.down2(x21) 169 | x40, x41 = self.down3(x31) 170 | x50, x51 = self.down4(x41) 171 | x60, x61 = self.conv5(x51) 172 | 173 | xt1 = self.up1(x61, x51, x50) 174 | xt2 = self.up2(xt1, x41, x40) 175 | xt3 = self.up3(xt2, x31, x30) 176 | xt4 = self.up4(xt3, x21, x20) 177 | target_img = self.outc(xt4, x0) 178 | 179 | return target_img -------------------------------------------------------------------------------- /networks/inpainting_nets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
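# RGB-D inpainting networks (edge / depth / color) adapted from the 3D Photo Inpainting
# codebase (vt-vl-lab/3d-photo-inpainting); download.sh fetches their pretrained weights
# into inpainting_ckpts/. Each network exposes a forward_3P() helper that zero-pads the
# input so both sides become multiples of `unit_length` (128 by default) before calling
# forward() and then crops the result back. The depth and color networks are built from
# partial convolutions, so only valid (context) pixels contribute to each activation,
# while the edge network is a spectral-normalized residual generator.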
14 | 15 | 16 | import math 17 | import torch 18 | import torch.nn as nn 19 | import numpy as np 20 | import torch.nn.functional as F 21 | 22 | 23 | class BaseNetwork(nn.Module): 24 | def __init__(self): 25 | super(BaseNetwork, self).__init__() 26 | 27 | def init_weights(self, init_type='normal', gain=0.02): 28 | ''' 29 | initialize network's weights 30 | init_type: normal | xavier | kaiming | orthogonal 31 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 32 | ''' 33 | 34 | def init_func(m): 35 | classname = m.__class__.__name__ 36 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | nn.init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | nn.init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'kaiming': 42 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 43 | elif init_type == 'orthogonal': 44 | nn.init.orthogonal_(m.weight.data, gain=gain) 45 | 46 | if hasattr(m, 'bias') and m.bias is not None: 47 | nn.init.constant_(m.bias.data, 0.0) 48 | 49 | elif classname.find('BatchNorm2d') != -1: 50 | nn.init.normal_(m.weight.data, 1.0, gain) 51 | nn.init.constant_(m.bias.data, 0.0) 52 | 53 | self.apply(init_func) 54 | 55 | 56 | def weights_init(init_type='gaussian'): 57 | def init_fun(m): 58 | classname = m.__class__.__name__ 59 | if (classname.find('Conv') == 0 or classname.find( 60 | 'Linear') == 0) and hasattr(m, 'weight'): 61 | if init_type == 'gaussian': 62 | nn.init.normal_(m.weight, 0.0, 0.02) 63 | elif init_type == 'xavier': 64 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) 65 | elif init_type == 'kaiming': 66 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 67 | elif init_type == 'orthogonal': 68 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) 69 | elif init_type == 'default': 70 | pass 71 | else: 72 | assert 0, "Unsupported initialization: {}".format(init_type) 73 | if hasattr(m, 'bias') and m.bias is not None: 74 | nn.init.constant_(m.bias, 0.0) 75 | 76 | return init_fun 77 | 78 | 79 | class PartialConv(nn.Module): 80 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 81 | padding=0, dilation=1, groups=1, bias=True): 82 | super().__init__() 83 | self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, 84 | stride, padding, dilation, groups, bias) 85 | self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, 86 | stride, padding, dilation, groups, False) 87 | self.input_conv.apply(weights_init('kaiming')) 88 | self.slide_winsize = in_channels * kernel_size * kernel_size 89 | 90 | torch.nn.init.constant_(self.mask_conv.weight, 1.0) 91 | 92 | # mask is not updated 93 | for param in self.mask_conv.parameters(): 94 | param.requires_grad = False 95 | 96 | def forward(self, input, mask): 97 | # http://masc.cs.gmu.edu/wiki/partialconv 98 | # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M) 99 | # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0) 100 | output = self.input_conv(input * mask) 101 | if self.input_conv.bias is not None: 102 | output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as( 103 | output) 104 | else: 105 | output_bias = torch.zeros_like(output) 106 | 107 | with torch.no_grad(): 108 | output_mask = self.mask_conv(mask) 109 | 110 | no_update_holes = output_mask == 0 111 | 112 | mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) 113 | 114 | output_pre = ((output - output_bias) * 
self.slide_winsize) / mask_sum + output_bias 115 | output = output_pre.masked_fill_(no_update_holes, 0.0) 116 | 117 | new_mask = torch.ones_like(output) 118 | new_mask = new_mask.masked_fill_(no_update_holes, 0.0) 119 | 120 | return output, new_mask 121 | 122 | 123 | class PCBActiv(nn.Module): 124 | def __init__(self, in_ch, out_ch, bn=True, no_tracking_stats=False, sample='none-3', activ='relu', 125 | conv_bias=False): 126 | super().__init__() 127 | if sample == 'down-5': 128 | self.conv = PartialConv(in_ch, out_ch, 5, 2, 2, bias=conv_bias) 129 | elif sample == 'down-7': 130 | self.conv = PartialConv(in_ch, out_ch, 7, 2, 3, bias=conv_bias) 131 | elif sample == 'down-3': 132 | self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias) 133 | else: 134 | self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias) 135 | 136 | if bn: 137 | if no_tracking_stats: 138 | self.bn = nn.BatchNorm2d(out_ch, track_running_stats=False, affine=True) 139 | else: 140 | self.bn = nn.BatchNorm2d(out_ch) 141 | if activ == 'relu': 142 | self.activation = nn.ReLU() 143 | elif activ == 'leaky': 144 | self.activation = nn.LeakyReLU(negative_slope=0.2) 145 | 146 | def forward(self, input, input_mask): 147 | h, h_mask = self.conv(input, input_mask) 148 | if hasattr(self, 'bn'): 149 | h = self.bn(h) 150 | if hasattr(self, 'activation'): 151 | h = self.activation(h) 152 | return h, h_mask 153 | 154 | 155 | class Inpaint_Depth_Net(nn.Module): 156 | def __init__(self, layer_size=7, upsampling_mode='nearest'): 157 | super().__init__() 158 | in_channels = 4 159 | out_channels = 1 160 | self.freeze_enc_bn = False 161 | self.upsampling_mode = upsampling_mode 162 | self.layer_size = layer_size 163 | self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True) 164 | self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True) 165 | self.enc_3 = PCBActiv(128, 256, sample='down-5') 166 | self.enc_4 = PCBActiv(256, 512, sample='down-3') 167 | for i in range(4, self.layer_size): 168 | name = 'enc_{:d}'.format(i + 1) 169 | setattr(self, name, PCBActiv(512, 512, sample='down-3')) 170 | 171 | for i in range(4, self.layer_size): 172 | name = 'dec_{:d}'.format(i + 1) 173 | setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky')) 174 | self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky') 175 | self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky') 176 | self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky') 177 | self.dec_1 = PCBActiv(64 + in_channels, out_channels, 178 | bn=False, activ=None, conv_bias=True) 179 | 180 | def add_border(self, input, mask_flag, PCONV=True): 181 | with torch.no_grad(): 182 | h = input.shape[-2] 183 | w = input.shape[-1] 184 | require_len_unit = 2 ** self.layer_size 185 | residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit 186 | residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit 187 | enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device) 188 | if mask_flag: 189 | if PCONV is False: 190 | enlarge_input += 1.0 191 | enlarge_input = enlarge_input.clamp(0.0, 1.0) 192 | else: 193 | enlarge_input[:, 2, ...] 
= 0.0 194 | anchor_h = residual_h//2 195 | anchor_w = residual_w//2 196 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input 197 | 198 | return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w] 199 | 200 | def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None): 201 | input = torch.cat((depth, edge, context, mask), dim=1) 202 | n, c, h, w = input.shape 203 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) 204 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) 205 | anchor_h = residual_h//2 206 | anchor_w = residual_w//2 207 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda) 208 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input 209 | # enlarge_input[:, 3] = 1. - enlarge_input[:, 3] 210 | depth_output = self.forward(enlarge_input) 211 | depth_output = depth_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] 212 | # import pdb; pdb.set_trace() 213 | 214 | return depth_output 215 | 216 | def forward(self, input_feat, refine_border=False, sample=False, PCONV=True): 217 | input = input_feat 218 | input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1) 219 | 220 | vis_input = input.cpu().data.numpy() 221 | vis_input_mask = input_mask.cpu().data.numpy() 222 | H, W = input.shape[-2:] 223 | if refine_border is True: 224 | input, anchor = self.add_border(input, mask_flag=False) 225 | input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV) 226 | h_dict = {} # for the output of enc_N 227 | h_mask_dict = {} # for the output of enc_N 228 | h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask 229 | 230 | h_key_prev = 'h_0' 231 | for i in range(1, self.layer_size + 1): 232 | l_key = 'enc_{:d}'.format(i) 233 | h_key = 'h_{:d}'.format(i) 234 | h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)( 235 | h_dict[h_key_prev], h_mask_dict[h_key_prev]) 236 | h_key_prev = h_key 237 | 238 | h_key = 'h_{:d}'.format(self.layer_size) 239 | h, h_mask = h_dict[h_key], h_mask_dict[h_key] 240 | 241 | for i in range(self.layer_size, 0, -1): 242 | enc_h_key = 'h_{:d}'.format(i - 1) 243 | dec_l_key = 'dec_{:d}'.format(i) 244 | 245 | h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode) 246 | h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest') 247 | 248 | h = torch.cat([h, h_dict[enc_h_key]], dim=1) 249 | h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1) 250 | h, h_mask = getattr(self, dec_l_key)(h, h_mask) 251 | output = h 252 | if refine_border is True: 253 | h_mask = h_mask[..., anchor[0]:anchor[1], anchor[2]:anchor[3]] 254 | output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]] 255 | 256 | return output 257 | 258 | 259 | class Inpaint_Edge_Net(BaseNetwork): 260 | def __init__(self, residual_blocks=8, init_weights=True): 261 | super(Inpaint_Edge_Net, self).__init__() 262 | in_channels = 7 263 | out_channels = 1 264 | self.encoder = [] 265 | # 0 266 | self.encoder_0 = nn.Sequential( 267 | nn.ReflectionPad2d(3), 268 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True), 269 | nn.InstanceNorm2d(64, track_running_stats=False), 270 | nn.ReLU(True)) 271 | # 1 272 | self.encoder_1 = nn.Sequential( 273 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True), 274 | nn.InstanceNorm2d(128, track_running_stats=False), 275 | nn.ReLU(True)) 276 | # 2 277 | self.encoder_2 = nn.Sequential( 278 | 
spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True), 279 | nn.InstanceNorm2d(256, track_running_stats=False), 280 | nn.ReLU(True)) 281 | # 3 282 | blocks = [] 283 | for _ in range(residual_blocks): 284 | block = ResnetBlock(256, 2) 285 | blocks.append(block) 286 | 287 | self.middle = nn.Sequential(*blocks) 288 | # + 3 289 | self.decoder_0 = nn.Sequential( 290 | spectral_norm(nn.ConvTranspose2d(in_channels=256+256, out_channels=128, kernel_size=4, stride=2, padding=1), True), 291 | nn.InstanceNorm2d(128, track_running_stats=False), 292 | nn.ReLU(True)) 293 | # + 2 294 | self.decoder_1 = nn.Sequential( 295 | spectral_norm(nn.ConvTranspose2d(in_channels=128+128, out_channels=64, kernel_size=4, stride=2, padding=1), True), 296 | nn.InstanceNorm2d(64, track_running_stats=False), 297 | nn.ReLU(True)) 298 | # + 1 299 | self.decoder_2 = nn.Sequential( 300 | nn.ReflectionPad2d(3), 301 | nn.Conv2d(in_channels=64+64, out_channels=out_channels, kernel_size=7, padding=0), 302 | ) 303 | 304 | if init_weights: 305 | self.init_weights() 306 | 307 | def add_border(self, input, channel_pad_1=None): 308 | h = input.shape[-2] 309 | w = input.shape[-1] 310 | require_len_unit = 16 311 | residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit 312 | residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit 313 | enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device) 314 | if channel_pad_1 is not None: 315 | for channel in channel_pad_1: 316 | enlarge_input[:, channel] = 1 317 | anchor_h = residual_h//2 318 | anchor_w = residual_w//2 319 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input 320 | 321 | return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w] 322 | 323 | def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None): 324 | input = torch.cat((rgb, disp/disp.max(), edge, context, mask), dim=1) 325 | n, c, h, w = input.shape 326 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) 327 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) 328 | anchor_h = residual_h//2 329 | anchor_w = residual_w//2 330 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda) 331 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input 332 | edge_output = self.forward(enlarge_input) 333 | edge_output = edge_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] 334 | return edge_output 335 | 336 | def forward(self, x, refine_border=False): 337 | if refine_border: 338 | x, anchor = self.add_border(x, [5]) 339 | x1 = self.encoder_0(x) 340 | x2 = self.encoder_1(x1) 341 | x3 = self.encoder_2(x2) 342 | x4 = self.middle(x3) 343 | x5 = self.decoder_0(torch.cat((x4, x3), dim=1)) 344 | x6 = self.decoder_1(torch.cat((x5, x2), dim=1)) 345 | x7 = self.decoder_2(torch.cat((x6, x1), dim=1)) 346 | x = torch.sigmoid(x7) 347 | if refine_border: 348 | x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]] 349 | 350 | return x 351 | 352 | 353 | class Inpaint_Color_Net(nn.Module): 354 | def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False): 355 | super().__init__() 356 | self.freeze_enc_bn = False 357 | self.upsampling_mode = upsampling_mode 358 | self.layer_size = layer_size 359 | in_channels = 6 360 | self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7') 361 | 
self.enc_2 = PCBActiv(64, 128, sample='down-5') 362 | self.enc_3 = PCBActiv(128, 256, sample='down-5') 363 | self.enc_4 = PCBActiv(256, 512, sample='down-3') 364 | self.enc_5 = PCBActiv(512, 512, sample='down-3') 365 | self.enc_6 = PCBActiv(512, 512, sample='down-3') 366 | self.enc_7 = PCBActiv(512, 512, sample='down-3') 367 | 368 | self.dec_7 = PCBActiv(512+512, 512, activ='leaky') 369 | self.dec_6 = PCBActiv(512+512, 512, activ='leaky') 370 | 371 | self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky') 372 | self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky') 373 | self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky') 374 | self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky') 375 | self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True) 376 | ''' 377 | self.dec_5B = PCBActiv(512 + 512, 512, activ='leaky') 378 | self.dec_4B = PCBActiv(512 + 256, 256, activ='leaky') 379 | self.dec_3B = PCBActiv(256 + 128, 128, activ='leaky') 380 | self.dec_2B = PCBActiv(128 + 64, 64, activ='leaky') 381 | self.dec_1B = PCBActiv(64 + 4, 1, bn=False, activ=None, conv_bias=True) 382 | ''' 383 | def cat(self, A, B): 384 | return torch.cat((A, B), dim=1) 385 | 386 | def upsample(self, feat, mask): 387 | feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode) 388 | mask = F.interpolate(mask, scale_factor=2, mode='nearest') 389 | 390 | return feat, mask 391 | 392 | def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None): 393 | input = torch.cat((rgb, edge, context, mask), dim=1) 394 | n, c, h, w = input.shape 395 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) # + 128 396 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) # + 256 397 | anchor_h = residual_h//2 398 | anchor_w = residual_w//2 399 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda) 400 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input 401 | # enlarge_input[:, 3] = 1. 
- enlarge_input[:, 3] 402 | enlarge_input = enlarge_input.to(cuda) 403 | rgb_output = self.forward(enlarge_input) 404 | rgb_output = rgb_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] 405 | return rgb_output 406 | 407 | def forward(self, input, add_border=False): 408 | input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1) 409 | H, W = input.shape[-2:] 410 | f_0, h_0 = input, input_mask.repeat((1,input.shape[1],1,1)) 411 | f_1, h_1 = self.enc_1(f_0, h_0) 412 | f_2, h_2 = self.enc_2(f_1, h_1) 413 | f_3, h_3 = self.enc_3(f_2, h_2) 414 | f_4, h_4 = self.enc_4(f_3, h_3) 415 | f_5, h_5 = self.enc_5(f_4, h_4) 416 | f_6, h_6 = self.enc_6(f_5, h_5) 417 | f_7, h_7 = self.enc_7(f_6, h_6) 418 | 419 | o_7, k_7 = self.upsample(f_7, h_7) 420 | o_6, k_6 = self.dec_7(self.cat(o_7, f_6), self.cat(k_7, h_6)) 421 | o_6, k_6 = self.upsample(o_6, k_6) 422 | o_5, k_5 = self.dec_6(self.cat(o_6, f_5), self.cat(k_6, h_5)) 423 | o_5, k_5 = self.upsample(o_5, k_5) 424 | o_5A, k_5A = o_5, k_5 425 | o_5B, k_5B = o_5, k_5 426 | ############### 427 | o_4A, k_4A = self.dec_5A(self.cat(o_5A, f_4), self.cat(k_5A, h_4)) 428 | o_4A, k_4A = self.upsample(o_4A, k_4A) 429 | o_3A, k_3A = self.dec_4A(self.cat(o_4A, f_3), self.cat(k_4A, h_3)) 430 | o_3A, k_3A = self.upsample(o_3A, k_3A) 431 | o_2A, k_2A = self.dec_3A(self.cat(o_3A, f_2), self.cat(k_3A, h_2)) 432 | o_2A, k_2A = self.upsample(o_2A, k_2A) 433 | o_1A, k_1A = self.dec_2A(self.cat(o_2A, f_1), self.cat(k_2A, h_1)) 434 | o_1A, k_1A = self.upsample(o_1A, k_1A) 435 | o_0A, k_0A = self.dec_1A(self.cat(o_1A, f_0), self.cat(k_1A, h_0)) 436 | 437 | return torch.sigmoid(o_0A) 438 | 439 | def train(self, mode=True): 440 | """ 441 | Override the default train() to freeze the BN parameters 442 | """ 443 | super().train(mode) 444 | if self.freeze_enc_bn: 445 | for name, module in self.named_modules(): 446 | if isinstance(module, nn.BatchNorm2d) and 'enc' in name: 447 | module.eval() 448 | 449 | 450 | class Discriminator(BaseNetwork): 451 | def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None): 452 | super(Discriminator, self).__init__() 453 | self.use_sigmoid = use_sigmoid 454 | self.conv1 = self.features = nn.Sequential( 455 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 456 | nn.LeakyReLU(0.2, inplace=True), 457 | ) 458 | 459 | self.conv2 = nn.Sequential( 460 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 461 | nn.LeakyReLU(0.2, inplace=True), 462 | ) 463 | 464 | self.conv3 = nn.Sequential( 465 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 466 | nn.LeakyReLU(0.2, inplace=True), 467 | ) 468 | 469 | self.conv4 = nn.Sequential( 470 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 471 | nn.LeakyReLU(0.2, inplace=True), 472 | ) 473 | 474 | self.conv5 = nn.Sequential( 475 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 476 | ) 477 | 478 | if init_weights: 479 | self.init_weights() 480 | 481 | def forward(self, x): 482 | conv1 = self.conv1(x) 483 | conv2 = self.conv2(conv1) 484 | conv3 = self.conv3(conv2) 485 | conv4 = self.conv4(conv3) 486 | 
conv5 = self.conv5(conv4) 487 | 488 | outputs = conv5 489 | if self.use_sigmoid: 490 | outputs = torch.sigmoid(conv5) 491 | 492 | return outputs, [conv1, conv2, conv3, conv4, conv5] 493 | 494 | 495 | class ResnetBlock(nn.Module): 496 | def __init__(self, dim, dilation=1): 497 | super(ResnetBlock, self).__init__() 498 | self.conv_block = nn.Sequential( 499 | nn.ReflectionPad2d(dilation), 500 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not True), True), 501 | nn.InstanceNorm2d(dim, track_running_stats=False), 502 | nn.LeakyReLU(negative_slope=0.2), 503 | 504 | nn.ReflectionPad2d(1), 505 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not True), True), 506 | nn.InstanceNorm2d(dim, track_running_stats=False), 507 | ) 508 | 509 | def forward(self, x): 510 | out = x + self.conv_block(x) 511 | 512 | # Remove ReLU at the end of the residual block 513 | # http://torch.ch/blog/2016/02/04/resnets.html 514 | 515 | return out 516 | 517 | 518 | def spectral_norm(module, mode=True): 519 | if mode: 520 | return nn.utils.spectral_norm(module) 521 | 522 | return module 523 | -------------------------------------------------------------------------------- /networks/resunet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
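# This module defines ResUNet: a ResNet-style encoder (a 7x7 stem followed by
# three BasicBlock stages) paired with a small upconv/skip-connection decoder.
# Given an in_ch-channel image it returns an out_ch-channel feature map at 1/4
# of the input resolution.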
14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import importlib 20 | 21 | 22 | def class_for_name(module_name, class_name): 23 | # load the module, will raise ImportError if module cannot be loaded 24 | m = importlib.import_module(module_name) 25 | return getattr(m, class_name) 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect') 32 | 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, padding_mode='reflect') 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, norm_layer=None): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 81 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 82 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 83 | # This variant is also known as ResNet V1.5 and improves accuracy according to 84 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
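# Note that ResUNet further below hard-codes block = BasicBlock, so this
# Bottleneck variant is not instantiated anywhere in this file.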
85 | 86 | expansion = 4 87 | 88 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 89 | base_width=64, dilation=1, norm_layer=None): 90 | super(Bottleneck, self).__init__() 91 | if norm_layer is None: 92 | norm_layer = nn.BatchNorm2d 93 | width = int(planes * (base_width / 64.)) * groups 94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 95 | self.conv1 = conv1x1(inplanes, width) 96 | self.bn1 = norm_layer(width, track_running_stats=False, affine=True) 97 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 98 | self.bn2 = norm_layer(width, track_running_stats=False, affine=True) 99 | self.conv3 = conv1x1(width, planes * self.expansion) 100 | self.bn3 = norm_layer(planes * self.expansion, track_running_stats=False, affine=True) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | self.stride = stride 104 | 105 | def forward(self, x): 106 | identity = x 107 | 108 | out = self.conv1(x) 109 | out = self.bn1(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv2(out) 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv3(out) 117 | out = self.bn3(out) 118 | 119 | if self.downsample is not None: 120 | identity = self.downsample(x) 121 | 122 | out += identity 123 | out = self.relu(out) 124 | 125 | return out 126 | 127 | 128 | class conv(nn.Module): 129 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): 130 | super(conv, self).__init__() 131 | self.kernel_size = kernel_size 132 | self.conv = nn.Conv2d(num_in_layers, 133 | num_out_layers, 134 | kernel_size=kernel_size, 135 | stride=stride, 136 | padding=(self.kernel_size - 1) // 2, 137 | padding_mode='reflect') 138 | self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True) 139 | 140 | def forward(self, x): 141 | return F.elu(self.bn(self.conv(x)), inplace=True) 142 | 143 | 144 | class upconv(nn.Module): 145 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): 146 | super(upconv, self).__init__() 147 | self.scale = scale 148 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) 149 | 150 | def forward(self, x): 151 | x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear') 152 | return self.conv(x) 153 | 154 | 155 | class ResUNet(nn.Module): 156 | def __init__(self, args, 157 | encoder='resnet34', 158 | in_ch=8, 159 | out_ch=32, 160 | norm_layer=None, 161 | ): 162 | 163 | super(ResUNet, self).__init__() 164 | assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type" 165 | if encoder in ['resnet18', 'resnet34']: 166 | filters = [64, 128, 256, 512] 167 | else: 168 | filters = [256, 512, 1024, 2048] 169 | # resnet = class_for_name("torchvision.models", encoder)(pretrained=True).to("cuda:{}".format(args.local_rank)) 170 | 171 | # original 172 | layers = [3, 4, 6, 3] 173 | if norm_layer is None: 174 | norm_layer = nn.BatchNorm2d 175 | # norm_layer = nn.InstanceNorm2d 176 | self._norm_layer = norm_layer 177 | self.dilation = 1 178 | block = BasicBlock 179 | replace_stride_with_dilation = [False, False, False] 180 | self.inplanes = 64 181 | self.groups = 1 182 | self.base_width = 64 183 | self.conv1 = nn.Conv2d(in_ch, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False, 184 | padding_mode='reflect') 185 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) 186 | self.relu = nn.ReLU(inplace=True) 187 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 188 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 189 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 190 | dilate=replace_stride_with_dilation[0]) 191 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 192 | dilate=replace_stride_with_dilation[1]) 193 | 194 | # if in_ch != 3: # Number of input channels 195 | # self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 196 | # else: 197 | # self.conv1 = resnet.conv1 # H/2 198 | # self.bn1 = resnet.bn1 199 | # self.relu = resnet.relu 200 | # self.maxpool = resnet.maxpool # H/4 201 | # 202 | # # encoder 203 | # self.layer1 = resnet.layer1 # H/4 204 | # self.layer2 = resnet.layer2 # H/8 205 | # self.layer3 = resnet.layer3 # H/16 206 | 207 | # decoder 208 | self.upconv3 = upconv(filters[2], 128, 3, 2) 209 | self.iconv3 = conv(filters[1] + 128, 128, 3, 1) 210 | self.upconv2 = upconv(128, 64, 3, 2) 211 | self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1) 212 | 213 | # fine-level conv 214 | self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1) 215 | 216 | for m in self.modules(): 217 | if isinstance(m, nn.Conv2d): 218 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 219 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 220 | nn.init.constant_(m.weight, 1) 221 | nn.init.constant_(m.bias, 0) 222 | 223 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 224 | norm_layer = self._norm_layer 225 | downsample = None 226 | previous_dilation = self.dilation 227 | if dilate: 228 | self.dilation *= stride 229 | stride = 1 230 | if stride != 1 or self.inplanes != planes * block.expansion: 231 | downsample = nn.Sequential( 232 | conv1x1(self.inplanes, planes * block.expansion, stride), 233 | norm_layer(planes * block.expansion, track_running_stats=False, affine=True), 234 | ) 235 | 236 | layers = [] 237 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 238 | self.base_width, previous_dilation, norm_layer)) 239 | self.inplanes = planes * block.expansion 240 | for _ in range(1, blocks): 241 | layers.append(block(self.inplanes, planes, groups=self.groups, 242 | base_width=self.base_width, dilation=self.dilation, 243 | norm_layer=norm_layer)) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def skipconnect(self, x1, x2): 248 | diffY = x2.size()[2] - x1.size()[2] 249 | diffX = x2.size()[3] - x1.size()[3] 250 | 251 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 252 | diffY // 2, diffY - diffY // 2)) 253 | 254 | # for padding issues, see 255 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 256 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 257 | 258 | x = torch.cat([x2, x1], dim=1) 259 | return x 260 | 261 | def forward(self, x): 262 | x = self.relu(self.bn1(self.conv1(x))) 263 | x = self.maxpool(x) 264 | 265 | x1 = self.layer1(x) 266 | x2 = self.layer2(x1) 267 | x3 = self.layer3(x2) 268 | 269 | x = self.upconv3(x3) 270 | x = self.skipconnect(x2, x) 271 | x = self.iconv3(x) 272 | 273 | x = self.upconv2(x) 274 | x = self.skipconnect(x1, x) 275 | x = self.iconv2(x) 276 | 277 | x_out = self.out_conv(x) 278 | 279 | return x_out 280 | -------------------------------------------------------------------------------- /posenc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class 
Embedder: 4 | 5 | def __init__(self, **kwargs): 6 | 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2.**np.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = np.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, 30 | freq=freq: p_fn(x * freq)) 31 | out_dim += d 32 | 33 | self.embed_fns = embed_fns 34 | self.out_dim = out_dim 35 | 36 | def embed(self, inputs): 37 | 38 | # print([fn(inputs) for fn in self.embed_fns]) 39 | # exit() 40 | return np.stack([fn(inputs) for fn in self.embed_fns], -1) 41 | 42 | 43 | def get_embedder(multires, i=0): 44 | 45 | if i == -1: 46 | return np.identity, 3 47 | 48 | embed_kwargs = { 49 | 'include_input': True, 50 | 'input_dims': 3, 51 | 'max_freq_log2': multires-1, 52 | 'num_freqs': multires, 53 | 'log_sampling': True, 54 | 'periodic_fns': [np.sin, np.cos], 55 | } 56 | 57 | embedder_obj = Embedder(**embed_kwargs) 58 | def embed(x, eo=embedder_obj): return eo.embed(x) 59 | return embed -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | import time 17 | import imageio 18 | import cv2 19 | from PIL import ImageFile 20 | 21 | import config 22 | import math 23 | import torchvision 24 | from torchvision.transforms import Pad 25 | import torch.utils.data.distributed 26 | from tqdm import tqdm 27 | from utils import * 28 | from model_3dm import SpaceTimeModel 29 | from core.utils import * 30 | from core.renderer import ImgRenderer 31 | from core.inpainter import Inpainter 32 | from model import UNet, FCN 33 | from posenc import get_embedder 34 | ImageFile.LOAD_TRUNCATED_IMAGES = True 35 | 36 | 37 | def process_boundary_mask(mask): 38 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) 39 | closing = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=5) 40 | dilation = cv2.dilate(closing, kernel, iterations=1, borderType=cv2.BORDER_CONSTANT, borderValue=0.) 41 | return dilation 42 | 43 | 44 | def get_input_data(args, ds_factor=1): 45 | to_tensor = torchvision.transforms.ToTensor() 46 | input_dir = args.input_dir 47 | img_file = (sorted(glob.glob(os.path.join(input_dir, 'input.png'))) + \ 48 | sorted(glob.glob(os.path.join(input_dir, 'input.jpg'))))[0] 49 | src_img = imageio.imread(img_file) / 255. 
50 | h1, w1 = src_img.shape[:2] 51 | 52 | disparity = (np.abs(np.load(os.path.join(input_dir, 'disp.npy')))) 53 | disparity = np.maximum(disparity, 1e-3) 54 | src_depth = (65535/disparity) / args.dscl 55 | 56 | fy = w1 / (2 * np.tan(np.deg2rad(args.fov) / 2)) 57 | fx = fy 58 | 59 | intrinsic = np.array([[fx, 0, w1 // 2], 60 | [0, fy, h1 // 2], 61 | [0, 0, 1]]) 62 | 63 | pose = np.eye(4) 64 | return { 65 | 'src_img1': to_tensor(src_img).float()[None], 66 | 'src_depth1': to_tensor(src_depth).float()[None], 67 | 'intrinsic1': torch.from_numpy(intrinsic).float()[None], 68 | 'tgt_intrinsic': torch.from_numpy(intrinsic).float()[None], 69 | 'pose': torch.from_numpy(pose).float()[None], 70 | 'scale_shift1': torch.tensor([1., 0.]).float()[None], 71 | 'src_rgb_file1': [img_file], 72 | 'multi_view': [False] 73 | } 74 | 75 | def reshading(reshader, dir_model, rgbd, xyz): 76 | 77 | xyz = xyz[None] * torch.tensor([-1, 1, 1]).cuda() 78 | pos_feat = dir_model(xyz) 79 | pred = reshader(rgbd, pos_feat) 80 | pred = torch.clamp(pred, 0, 1) 81 | 82 | return pred 83 | 84 | def render(args): 85 | to_tensor = torchvision.transforms.ToTensor() 86 | device = "cuda:{}".format(args.local_rank) 87 | 88 | print('========================= Reshading Init...=========================') 89 | runpath = "runs/" 90 | 91 | reshader = UNet(14, 3).cuda() 92 | dir_model = FCN().cuda() 93 | reshader.load_state_dict(torch.load(f'{runpath}{args.model_dir}/model{args.ckpt}.pth')) 94 | dir_model.load_state_dict(torch.load(f'{runpath}{args.model_dir}/dir_model{args.ckpt}.pth')) 95 | embed_fn = get_embedder(args.pos_enc_freq) 96 | 97 | 98 | print('=========================run 3D Moments...=========================') 99 | 100 | data = get_input_data(args) 101 | rgb_file1 = data['src_rgb_file1'][0] 102 | frame_id1 = os.path.basename(rgb_file1).split('.')[0] 103 | scene_id = rgb_file1.split('/')[-3] 104 | 105 | video_out_folder = os.path.join(args.input_dir, 'out') 106 | os.makedirs(video_out_folder, exist_ok=True) 107 | 108 | im_h, im_w = data['src_img1'].shape[2:] 109 | pad_h, pad_w = (32 * math.ceil(im_h / 32) - im_h), (32 * math.ceil(im_w / 32) - im_w) 110 | padder = Pad(padding=(0, 0, pad_w, pad_h)) 111 | 112 | model = SpaceTimeModel(args) 113 | if model.start_step == 0: 114 | raise Exception('no pretrained model found! 
please check the model path.') 115 | 116 | inpainter = Inpainter(args) 117 | renderer = ImgRenderer(args, model, None, inpainter, device) 118 | 119 | model.switch_to_eval() 120 | with torch.no_grad(): 121 | renderer.process_data_single(data) 122 | 123 | pts1, rgb1, feat1, mask, side_ids = \ 124 | renderer.render_rgbda_layers_from_one_view(return_pts=True) 125 | 126 | num_frames = 60#[60, 60, 60, 90] 127 | video_paths = ['circle']#['up-down', 'zoom-in', 'side', 'circle'] 128 | Ts = [ 129 | # define_camera_path(num_frames[0], 0., -0.08, 0., path_type='double-straight-line', return_t_only=True), 130 | # define_camera_path(num_frames[1], 0., 0., -0.24, path_type='straight-line', return_t_only=True), 131 | # define_camera_path(num_frames[2], -0.09, 0, -0, path_type='double-straight-line', return_t_only=True), 132 | # define_camera_path(num_frames, -0.15, -0.15, -0.15, path_type='circle', return_t_only=True), 133 | # define_camera_path(num_frames, -0.14, -0.14, -0.14, path_type='circle', return_t_only=True), 134 | define_camera_path(num_frames, -0.14, -0.14, -0.14, path_type='circle', return_t_only=True), 135 | ] 136 | crop = 32 137 | ##### the max value of the relative coordinates (above) should not exceed 0.3 (max used in training) 138 | 139 | ref_input = data['src_img1'] 140 | for j, T in enumerate(Ts): 141 | print(video_paths[j]) 142 | T = torch.from_numpy(T).float().to(renderer.device) 143 | time_steps = np.linspace(0, 1, num_frames) 144 | frames = [] 145 | reshaded_frames = [] 146 | 147 | for i, t_step in tqdm(enumerate(time_steps), total=len(time_steps), 148 | desc='generating video of {} camera trajectory'.format(video_paths[j])): 149 | 150 | ######### RESHADING INSERT############# 151 | disparity = 1 / data['src_depth1'] 152 | # disparity is scaled by 4 during training (should not exceed 0.25 (max during training)) 153 | disparity = (disparity / 4) 154 | disparity = torch.Tensor(embed_fn(disparity[0])[0]).permute(2, 0, 1)[None] 155 | 156 | rgbd = padder(torch.cat((ref_input, disparity), 1).cuda()) 157 | reshaded = reshading(reshader, dir_model, rgbd, T[i])[:, :, :im_h, :im_w] 158 | reshaded_frames.append((255. * reshaded.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8)) 159 | 160 | data['src_img1'] = reshaded 161 | 162 | renderer.process_data_single(data) 163 | 164 | pts1, rgb1, feat1, mask, side_ids = \ 165 | renderer.render_rgbda_layers_from_one_view(return_pts=True) 166 | ################################# 167 | 168 | pred_img, _, meta = renderer.render_pcd_single(pts1, rgb1, 169 | feat1, mask, side_ids, 170 | t=T[i], R=None, time=0) 171 | frame = (255. 
* pred_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8) 172 | # mask out fuzzy image boundaries due to no outpainting 173 | img_boundary_mask = (meta['acc'] > 0.5).detach().cpu().squeeze().numpy().astype(np.uint8) 174 | img_boundary_mask_cleaned = process_boundary_mask(img_boundary_mask) 175 | frame = frame * img_boundary_mask_cleaned[..., None] 176 | frame = frame[crop:-crop, crop:-crop] 177 | frames.append(frame) 178 | 179 | video_out_file = os.path.join(video_out_folder, '{}_{}-{}.mp4'.format( 180 | video_paths[j], scene_id, frame_id1)) 181 | imageio.mimwrite(video_out_file, frames, fps=25, quality=8) 182 | imageio.mimwrite(f"{video_out_folder}/out_{args.model_dir}_model{args.ckpt}.mp4", reshaded_frames, fps=25, quality=8) 183 | 184 | 185 | print('output videos have been saved in {}.'.format(video_out_folder)) 186 | 187 | if __name__ == '__main__': 188 | args = config.config_parser() 189 | 190 | render(args) -------------------------------------------------------------------------------- /runs/reshader/dir_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/runs/reshader/dir_model.pth -------------------------------------------------------------------------------- /runs/reshader/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/runs/reshader/model.pth -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import numpy as np 18 | import torch 19 | from datetime import datetime 20 | import shutil 21 | 22 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 23 | 24 | 25 | def de_parallel(model): 26 | return model.module if hasattr(model, 'module') else model 27 | 28 | 29 | def cycle(iterable): 30 | while True: 31 | for x in iterable: 32 | yield x 33 | 34 | 35 | def dict_to_device(dict_): 36 | for k in dict_.keys(): 37 | if type(dict_[k]) == torch.Tensor: 38 | dict_[k] = dict_[k].cuda() 39 | 40 | return dict_ 41 | 42 | 43 | def save_current_code(outdir): 44 | now = datetime.now() # current date and time 45 | date_time = now.strftime("%m_%d-%H:%M:%S") 46 | src_dir = '.' 
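# Snapshot the working tree into <outdir>/code/<timestamp> so a run keeps a
# copy of the exact code it was launched with; pretrained weights, logs,
# rendered media and other large artifacts are excluded via the ignore
# patterns passed to copytree below.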
47 | code_out_dir = os.path.join(outdir, 'code') 48 | os.makedirs(code_out_dir, exist_ok=True) 49 | dst_dir = os.path.join(code_out_dir, '{}'.format(date_time)) 50 | shutil.copytree(src_dir, dst_dir, 51 | ignore=shutil.ignore_patterns('pretrained*', '*logs*', 'out*', '*.png', '*.mp4', 'eval*', 52 | '*__pycache__*', '*.git*', '*.idea*', '*.zip', '*.jpg')) 53 | 54 | 55 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 56 | # assert isinstance(input, torch.Tensor) 57 | if posinf is None: 58 | posinf = torch.finfo(input.dtype).max 59 | if neginf is None: 60 | neginf = torch.finfo(input.dtype).min 61 | assert nan == 0 62 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 63 | 64 | 65 | def img2mse(x, y, mask=None): 66 | ''' 67 | :param x: img 1, [(...), 3] 68 | :param y: img 2, [(...), 3] 69 | :param mask: optional, [(...)] 70 | :return: mse score 71 | ''' 72 | if mask is None: 73 | return torch.mean((x - y) * (x - y)) 74 | else: 75 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER) 76 | 77 | 78 | mse2psnr = lambda x: -10. * np.log(x+TINY_NUMBER) / np.log(10.) 79 | 80 | 81 | def img2psnr(x, y, mask=None): 82 | return mse2psnr(img2mse(x, y, mask).item()) 83 | 84 | --------------------------------------------------------------------------------
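A minimal sketch (not part of the repository) of how `posenc.get_embedder` expands a per-pixel disparity map before `renderer.py` concatenates it with the RGB image. The frequency count of 5 is an assumption, chosen because identity plus 5 sin/cos pairs gives 11 channels, which together with the 3 RGB channels matches the `UNet(14, 3)` constructed in `render()`; the value actually used is `args.pos_enc_freq` from the render config, and `renderer.py` additionally converts depth back to disparity and rescales it by 1/4 before encoding.
```
import numpy as np
from posenc import get_embedder

# Disparity map for one of the bundled examples (assumed to be a 2D array).
disp = np.load('examples/camera/disp.npy').astype(np.float32)

# 5 frequency bands (assumed): identity + 5 * (sin, cos) = 11 output channels.
embed_fn = get_embedder(5)
encoded = embed_fn(disp)          # shape (H, W, 11)

# Concatenated with the 3 RGB channels, this yields the 14-channel input that
# the reshading UNet(14, 3) in renderer.py expects.
print(encoded.shape)
```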