├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.gif
├── config.py
├── configs
│   └── render.txt
├── core
│   ├── __init__.py
│   ├── depth_layering.py
│   ├── inpainter.py
│   ├── pcd.py
│   ├── renderer.py
│   ├── scene_flow.py
│   └── utils.py
├── download.sh
├── environment.yml
├── examples
│   ├── bottle
│   │   ├── disp.npy
│   │   └── input.jpg
│   ├── burger
│   │   ├── disp.npy
│   │   └── input.jpg
│   ├── camera
│   │   ├── disp.npy
│   │   └── input.jpg
│   ├── car
│   │   ├── disp.npy
│   │   └── input.jpg
│   ├── fireworks
│   │   ├── disp.npy
│   │   └── input.jpg
│   ├── rook
│   │   ├── disp.npy
│   │   └── input.jpg
│   └── spoon
│       ├── disp.npy
│       └── input.png
├── model.py
├── model_3dm.py
├── networks
│   ├── __init__.py
│   ├── img_decoder.py
│   ├── inpainting_nets.py
│   └── resunet.py
├── posenc.py
├── renderer.py
├── runs
│   └── reshader
│       ├── dir_model.pth
│       └── model.pth
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Avinash Paliwal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReShader
2 |
3 | > ReShader: View-Dependent Highlights for Single Image View-Synthesis
4 | > [Avinash Paliwal](http://avinashpaliwal.com/),
5 | > [Brandon Nguyen](https://brandon.nguyen.vc/about/),
6 | > [Andrii Tsarov](https://www.linkedin.com/in/andrii-tsarov-b8a9bb13),
7 | > [Nima Khademi Kalantari](http://nkhademi.com/)
8 | > SIGGRAPH Asia 2023 (TOG)
9 |
10 | [arXiv](https://arxiv.org/abs/2309.10689)
11 | [Project Page](https://people.engr.tamu.edu/nimak/Papers/SIGAsia2023_Reshader)
12 | [Video](https://youtu.be/XW-tl48D3Ok)
13 |
14 | ---------------------------------------------------
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Prerequisites
22 | You can set up the Anaconda environment using:
23 | ```
24 | conda env create -f environment.yml
25 | conda activate reshader
26 | ```
27 |
28 | Download the pretrained models.
29 | The following script from [3D Moments](https://github.com/google-research/3d-moments) downloads their pretrained models along with the [RGBD-inpainting networks](https://github.com/vt-vl-lab/3d-photo-inpainting).
30 | ```
31 | ./download.sh
32 | ```
33 |
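You can quickly check that the weights ended up where the rest of the code expects them; the paths below are the ones referenced by `configs/render.txt` and `core/inpainter.py`, assuming `download.sh` places everything relative to the repository root:
```
# minimal sanity check (paths taken from configs/render.txt and core/inpainter.py)
import os
for path in ["pretrained/model_250000.pth",
             "inpainting_ckpts/depth-model.pth",
             "inpainting_ckpts/color-model.pth"]:
    print(path, "found" if os.path.isfile(path) else "MISSING")
```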
34 |
35 | ## Demos
36 | We provide some examples in the `examples/` folder. You can render novel views with view-dependent highlights using:
37 |
38 | ```
39 | python renderer.py --input_dir examples/camera/ --config configs/render.txt
40 | ```
41 |
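The other folders under `examples/` share the same layout (an input image and a precomputed `disp.npy` disparity map), so the same command should work on any of them by swapping `--input_dir`, for example:

```
python renderer.py --input_dir examples/bottle/ --config configs/render.txt
```
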
42 | ## Training
43 | Training code and dataset to be added.
44 |
45 | ## Citation
46 | ```
47 | @article{paliwal2023reshader,
48 | author = {Paliwal, Avinash and Nguyen, Brandon G. and Tsarov, Andrii and Kalantari, Nima Khademi},
49 | title = {ReShader: View-Dependent Highlights for Single Image View-Synthesis},
50 | year = {2023},
51 | issue_date = {December 2023},
52 | volume = {42},
53 | number = {6},
54 | journal = {ACM Trans. Graph.},
55 | month = {dec},
56 | articleno = {216},
57 | numpages = {9},
58 | }
59 | ```
60 |
61 |
62 | ## Acknowledgement
63 | The novel view synthesis part of the code is borrowed from [3D Moments](https://github.com/google-research/3d-moments).
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/assets/teaser.gif
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import configargparse
16 |
17 |
18 | def config_parser():
19 | parser = configargparse.ArgumentParser()
20 | parser.add_argument('--config', is_config_file=True, help='config file path')
21 | # general
22 | parser.add_argument('--rootdir', type=str, default='./',
23 | help='the path to the project root directory.')
24 | parser.add_argument("--expname", type=str, default='exp', help='experiment name')
25 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
26 | help='number of data loading workers (default: 8)')
27 | parser.add_argument('--distributed', action='store_true', help='if use distributed training')
28 | parser.add_argument("--local_rank", type=int, default=0, help='rank for distributed training')
29 | parser.add_argument("--eval_mode", action='store_true', help='if in eval mode')
30 |
31 | ########## dataset options ##########
32 | # train and eval dataset
33 | parser.add_argument("--train_dataset", type=str, default='vimeo',
34 | help='the training dataset')
35 | parser.add_argument("--dataset_weights", nargs='+', type=float, default=[],
36 | help='the weights for training datasets, used when multiple datasets are used.')
37 | parser.add_argument('--eval_dataset', type=str, default='vimeo', help='the dataset to evaluate')
38 | parser.add_argument("--batch_size", type=int, default=1, help='batch size, currently only support 1')
39 |
40 | ########## network architecture ##########
41 | parser.add_argument("--feature_dim", type=int, default=32, help='the dimension of the extracted features')
42 |
43 | ########## training options ##########
44 | parser.add_argument("--use_inpainting_mask_for_feature", action='store_true')
45 | parser.add_argument("--inpainting", action='store_true', help='if do inpainting')
46 | parser.add_argument("--train_raft", action='store_true', help='if train raft')
47 | parser.add_argument('--boundary_crop_ratio', type=float, default=0, help='crop the image before computing loss')
48 | parser.add_argument("--vary_pts_radius", action='store_true', help='if vary point radius as augmentation')
49 | parser.add_argument("--adaptive_pts_radius", action='store_true', help='if use adaptive point radius')
50 | parser.add_argument("--use_mask_for_decoding", action='store_true', help='if use mask for decoding')
51 |
52 | ########## rendering/evaluation ##########
53 | parser.add_argument("--use_depth_for_feature", action='store_true',
54 | help='if use depth map when extracting features')
55 | parser.add_argument("--use_depth_for_decoding", action='store_true',
56 | help='if use depth map when decoding')
57 | parser.add_argument("--point_radius", type=float, default=1.5,
58 | help='point radius for rasterization')
59 | parser.add_argument("--input_dir", type=str, default='', help='input folder that contains a pair of images')
60 | parser.add_argument("--visualize_rgbda_layers", action='store_true',
61 | help="if visualize rgbda layers, save in out dir")
62 |
63 | ########### iterations & learning rate options & loss ##########
64 | parser.add_argument("--n_iters", type=int, default=250000, help='num of iterations')
65 | parser.add_argument("--lr", type=float, default=3e-4, help='learning rate for feature extractor')
66 | parser.add_argument("--lr_raft", type=float, default=5e-6, help='learning rate for raft')
67 | parser.add_argument("--lrate_decay_factor", type=float, default=0.5,
68 | help='decay learning rate by a factor every specified number of steps')
69 | parser.add_argument("--lrate_decay_steps", type=int, default=50000,
70 | help='decay learning rate by a factor every specified number of steps')
71 | parser.add_argument('--loss_mode', type=str, default='lpips',
72 | help='the loss function to use')
73 |
74 | ########## checkpoints ##########
75 | parser.add_argument("--ckpt_path", type=str, default="",
76 | help='specific weights npy file to reload for coarse network')
77 | parser.add_argument("--no_reload", action='store_true',
78 | help='do not reload weights from saved ckpt')
79 | parser.add_argument("--no_load_opt", action='store_true',
80 | help='do not load optimizer when reloading')
81 | parser.add_argument("--no_load_scheduler", action='store_true',
82 | help='do not load scheduler when reloading')
83 |
84 | ########## logging/saving options ##########
85 | parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
86 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
87 | parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
88 |
89 | ############ demo parameters ##############
90 | parser.add_argument("--spec", action='store_true', help='use specular frames')
91 | parser.add_argument("--tung", action='store_true', help='using tungsten depths')
92 | parser.add_argument("--normalize_depth", action='store_true', help='normalize depth when depth map is euclidean distance')
93 | parser.add_argument("--fov", type=float, default=45.0, help='fov of camera')
94 | parser.add_argument("--spd", type=str, default="246", help='spec directory suffix')
95 | parser.add_argument("--dscl", type=int, default=1, help='depth scaling')
96 |
97 | #training schedule
98 | parser.add_argument('--num_iterations', type=int, default=300000, help='total iterations to train')
99 | parser.add_argument('-train_batch_size', type=int, default=10)
100 | parser.add_argument('-val_batch_size', type=int, default=4)
101 | parser.add_argument('-checkpoint', type=int, default=10, help='save a checkpoint every this many epochs. Be aware that it will replace the previous checkpoint.')
102 | parser.add_argument('-tb_toc', type=int, default=100, help="print output to the terminal every tb_toc iterations")
103 |
104 | #lr schedule
105 | parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4, help='learning rate of the network')
106 |
107 | #loss
108 | parser.add_argument('-style_coeff', type=float, default=1, help='hyperparameter for style loss')
109 | parser.add_argument('-prcp_coeff', type=float, default=0.01, help='hyperparameter for perceptual loss')
110 | parser.add_argument('-mse_coeff', type=float, default=1.0, help='hyperparameter for MSE loss')
111 | parser.add_argument('-l1_coeff', type=float, default=0.1, help='hyperparameter for L1 loss')
112 | #training and eval data
113 | parser.add_argument('-dataset', type=str, default="/data2/avinash/datasets/specular_fixed/specular/", help='directory to the dataset')
114 |
115 | #training utility
116 | parser.add_argument('--model_dir', type=str, default="unet_prcp_gmm_mse", help='model (scene) directory, stored under runs/')
117 | parser.add_argument('-clean', action='store_true', help='delete old weight without start training process')
118 | parser.add_argument('--clip', type=float, default=1.0)
119 |
120 | #model
121 | parser.add_argument('-multi', type=bool, default=True, help='append multi level direction vector')
122 | parser.add_argument('-use_mlp', type=bool, default=False, help='use mlp for feature vector from direction vector')
123 | parser.add_argument('--start_iter',type=int, default=0, help="starting iteration")
124 | parser.add_argument('-basis_out',type=int, default=8, help="num of basis functions")
125 | parser.add_argument('-pos_enc_freq',type=int, default=5, help="num of freqs in positional encoding")
126 | parser.add_argument('--losses', type=str, nargs='+', help='losses to use', default=['mse', 'prcp', 'gmm'])
127 | parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to continue from')
128 | parser.add_argument('--example_index', type=str, default=None, help='example index for testing')
129 | parser.add_argument('--test_root', type=str, default="real_data/", help='test examples root dir')
130 | parser.add_argument('-pad', type=bool, default=False, help='use mlp for feature vector from direction vector')
131 | parser.add_argument('--use_depth_posenc', type=bool, default=False, help='use positional encoding of depth')
132 |
133 |
134 | args = parser.parse_args()
135 | return args
136 |
137 |
--------------------------------------------------------------------------------
/configs/render.txt:
--------------------------------------------------------------------------------
1 | no_load_opt = True
2 | no_load_scheduler = True
3 | distributed = False
4 | loss_mode = vgg19
5 | train_dataset = tiktok
6 | eval_dataset = jamie
7 | eval_mode = True
8 |
9 | use_depth_for_decoding = True
10 | adaptive_pts_radius = True
11 | train_raft = False
12 | visualize_rgbda_layers = False
13 |
14 | ckpt_path = pretrained/model_250000.pth
15 |
16 |
17 | model_dir = reshader
18 | ckpt = ""
19 | use_depth_posenc = True
20 | dscl = 2
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/core/__init__.py
--------------------------------------------------------------------------------
/core/depth_layering.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | from sklearn.cluster import AgglomerativeClustering
18 |
19 |
20 | def get_depth_bins(depth=None, disparity=None, num_bins=None):
21 | """
22 | :param depth: [1, 1, H, W]
23 | :param disparity: [1, 1, H, W]
24 | :return: depth_bins
25 | """
26 |
27 | assert (disparity is not None) or (depth is not None)
28 | if disparity is None:
29 | assert depth.min() > 1e-2
30 | disparity = 1. / depth
31 |
32 | if depth is None:
33 | depth = 1. / torch.clamp(disparity, min=1e-2)
34 |
35 | assert depth.shape[:2] == (1, 1) and disparity.shape[:2] == (1, 1)
36 | disparity_max = disparity.max().item()
37 | disparity_min = disparity.min().item()
38 | disparity_feat = disparity[:, :, ::10, ::10].reshape(-1, 1).cpu().numpy()
39 | disparity_feat = (disparity_feat - disparity_min) / (disparity_max - disparity_min)
40 | if num_bins is None:
41 | n_clusters = None
42 | distance_threshold = 5
43 | else:
44 | n_clusters = num_bins
45 | distance_threshold = None
46 | result = AgglomerativeClustering(n_clusters=n_clusters, distance_threshold=distance_threshold).fit(disparity_feat)
47 | num_bins = result.n_clusters_ if n_clusters is None else n_clusters
48 | depth_bins = [depth.min().item()]
49 | for i in range(num_bins):
50 | th = (disparity_feat[result.labels_ == i]).min()
51 | th = th * (disparity_max - disparity_min) + disparity_min
52 | depth_bins.append(1. / th)
53 |
54 | depth_bins = sorted(depth_bins)
55 | depth_bins[0] = depth.min() - 1e-6
56 | depth_bins[-1] = depth.max() + 1e-6
57 | return depth_bins
58 |
59 |
60 |
--------------------------------------------------------------------------------
/core/inpainter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | from kornia.morphology import opening, erosion
18 | from kornia.filters import gaussian_blur2d
19 | from networks.inpainting_nets import Inpaint_Depth_Net, Inpaint_Color_Net
20 | from core.utils import masked_median_blur
21 |
22 |
23 | def refine_near_depth_discontinuity(depth, alpha, kernel_size=11):
24 | '''
25 | median filtering the depth discontinuity boundary
26 | '''
27 | depth = depth * alpha
28 | depth_median_blurred = masked_median_blur(depth, alpha, kernel_size=kernel_size) * alpha
29 | alpha_eroded = erosion(alpha, kernel=torch.ones(kernel_size, kernel_size).to(alpha.device))
30 | depth[alpha_eroded == 0] = depth_median_blurred[alpha_eroded == 0]
31 | return depth
32 |
33 |
34 | def define_inpainting_bbox(alpha, border=40):
35 | '''
36 | define the bounding box for inpainting
37 | :param alpha: alpha map [1, 1, h, w]
38 | :param border: the minimum distance from a valid pixel to the border of the bbox
39 | :return: [1, 1, h, w], a 0/1 map that indicates the inpainting region
40 | '''
41 | assert alpha.ndim == 4 and alpha.shape[:2] == (1, 1)
42 | x, y = torch.nonzero(alpha)[:, -2:].T
43 | h, w = alpha.shape[-2:]
44 | row_min, row_max = x.min(), x.max()
45 | col_min, col_max = y.min(), y.max()
46 | out = torch.zeros_like(alpha)
47 | x0, x1 = max(row_min - border, 0), min(row_max + border, h - 1)
48 | y0, y1 = max(col_min - border, 0), min(col_max + border, w - 1)
49 | out[:, :, x0:x1, y0:y1] = 1
50 | return out
51 |
52 |
53 | class Inpainter():
54 | def __init__(self, args, device='cuda'):
55 | self.args = args
56 | self.device = device
57 | print("Loading depth model...")
58 | depth_feat_model = Inpaint_Depth_Net()
59 | depth_feat_weight = torch.load('inpainting_ckpts/depth-model.pth', map_location=torch.device(device))
60 | depth_feat_model.load_state_dict(depth_feat_weight)
61 | depth_feat_model = depth_feat_model.to(device)
62 | depth_feat_model.eval()
63 | self.depth_feat_model = depth_feat_model.to(device)
64 | print("Loading rgb model...")
65 | rgb_model = Inpaint_Color_Net()
66 | rgb_feat_weight = torch.load('inpainting_ckpts/color-model.pth', map_location=torch.device(device))
67 | rgb_model.load_state_dict(rgb_feat_weight)
68 | rgb_model.eval()
69 | self.rgb_model = rgb_model.to(device)
70 |
71 | # kernels
72 | self.context_erosion_kernel = torch.ones(10, 10).to(self.device)
73 | self.alpha_kernel = torch.ones(3, 3).to(self.device)
74 |
75 | @staticmethod
76 | def process_depth_for_network(depth, context, log_depth=True):
77 | if log_depth:
78 | log_depth = torch.log(depth + 1e-8) * context
79 | mean_depth = torch.mean(log_depth[context > 0])
80 | zero_mean_depth = (log_depth - mean_depth) * context
81 | else:
82 | zero_mean_depth = depth
83 | mean_depth = 0
84 | return zero_mean_depth, mean_depth
85 |
86 | @staticmethod
87 | def deprocess_depth(zero_mean_depth, mean_depth, log_depth=True):
88 | if log_depth:
89 | depth = torch.exp(zero_mean_depth + mean_depth)
90 | else:
91 | depth = zero_mean_depth
92 | return depth
93 |
94 | def inpaint_rgb(self, holes, context, context_rgb, edge):
95 | # inpaint rgb
96 | with torch.no_grad():
97 | inpainted_rgb = self.rgb_model.forward_3P(holes, context, context_rgb, edge,
98 | unit_length=128, cuda=self.device)
99 | inpainted_rgb = inpainted_rgb.detach() * holes + context_rgb
100 | inpainted_a = holes + context
101 | inpainted_a = opening(inpainted_a, self.alpha_kernel)
102 | inpainted_rgba = torch.cat([inpainted_rgb, inpainted_a], dim=1)
103 | return inpainted_rgba
104 |
105 | def inpaint_depth(self, depth, holes, context, edge, depth_range):
106 | zero_mean_depth, mean_depth = self.process_depth_for_network(depth, context)
107 | with torch.no_grad():
108 | inpainted_depth = self.depth_feat_model.forward_3P(holes, context, zero_mean_depth, edge,
109 | unit_length=128, cuda=self.device)
110 | inpainted_depth = self.deprocess_depth(inpainted_depth.detach(), mean_depth)
111 | inpainted_depth[context > 0.5] = depth[context > 0.5]
112 | inpainted_depth = gaussian_blur2d(inpainted_depth, (3, 3), (1.5, 1.5))
113 | inpainted_depth[context > 0.5] = depth[context > 0.5]
114 | # if the inpainted depth in the background is smaller than the foreground depth,
115 | # then the inpainted content will mistakenly occlude the foreground.
116 | # Clipping the inpainted depth in this situation.
117 | mask_wrong_depth_ordering = inpainted_depth < depth
118 | inpainted_depth[mask_wrong_depth_ordering] = depth[mask_wrong_depth_ordering] * 1.01
119 | inpainted_depth = torch.clamp(inpainted_depth, min=min(depth_range)*0.9)
120 | return inpainted_depth
121 |
122 | def sequential_inpainting(self, rgb, depth, depth_bins):
123 | '''
124 | :param rgb: [1, 3, H, W]
125 | :param depth: [1, 1, H, W]
126 | :return: rgba_layers: [N, 1, 3, H, W]: the inpainted RGBA layers
127 | depth_layers: [N, 1, 1, H, W]: the inpainted depth layers
128 | mask_layers: [N, 1, 1, H, W]: the original alpha layers (before inpainting)
129 | '''
130 |
131 | num_bins = len(depth_bins) - 1
132 |
133 | rgba_layers = []
134 | depth_layers = []
135 | mask_layers = []
136 |
137 | for i in range(num_bins):
138 | alpha_i = (depth >= depth_bins[i]) * (depth < depth_bins[i+1])
139 | alpha_i = alpha_i.float()
140 |
141 | if i == 0:
142 | rgba_i = torch.cat([rgb*alpha_i, alpha_i], dim=1)
143 | rgba_layers.append(rgba_i)
144 | depth_i = refine_near_depth_discontinuity(depth, alpha_i)
145 | depth_layers.append(depth_i)
146 | mask_layers.append(alpha_i)
147 | pre_alpha = alpha_i.bool()
148 | pre_inpainted_depth = depth * alpha_i
149 | else:
150 | alpha_i_eroded = erosion(alpha_i, self.context_erosion_kernel)
151 | if alpha_i_eroded.sum() < 10:
152 | continue
153 | context = erosion((depth >= depth_bins[i]).float(), self.context_erosion_kernel)
154 | holes = 1. - context
155 | bbox = define_inpainting_bbox(context, border=40)
156 | holes *= bbox
157 | edge = torch.zeros_like(holes)
158 | context_rgb = rgb * context
159 | # inpaint depth
160 | inpainted_depth_i = self.inpaint_depth(depth, holes, context, edge, (depth_bins[i], depth_bins[i+1]))
161 | depth_near_mask = (inpainted_depth_i < depth_bins[i+1]).float()
162 | # inpaint rgb
163 | inpainted_rgba_i = self.inpaint_rgb(holes, context, context_rgb, edge)
164 |
165 | if i < num_bins - 1:
166 | # only keep the content whose depth is smaller than the upper limit of the current layer
167 | # otherwise the inpainted content on the far-depth edge will falsely occlude the next layer.
168 | inpainted_rgba_i *= depth_near_mask
169 | inpainted_depth_i = refine_near_depth_discontinuity(inpainted_depth_i, inpainted_rgba_i[:, [-1]])
170 |
171 | inpainted_alpha_i = inpainted_rgba_i[:, [-1]].bool()
172 | mask_wrong_ordering = (inpainted_depth_i <= pre_inpainted_depth) * inpainted_alpha_i
173 | inpainted_depth_i[mask_wrong_ordering] = pre_inpainted_depth[mask_wrong_ordering] * 1.05
174 |
175 | rgba_layers.append(inpainted_rgba_i)
176 | depth_layers.append(inpainted_depth_i)
177 | mask_layers.append(context * depth_near_mask) # original mask
178 |
179 | pre_alpha[inpainted_alpha_i] = True
180 | pre_inpainted_depth[inpainted_alpha_i > 0] = inpainted_depth_i[inpainted_alpha_i > 0]
181 |
182 | rgba_layers = torch.stack(rgba_layers)
183 | depth_layers = torch.stack(depth_layers)
184 | mask_layers = torch.stack(mask_layers)
185 |
186 | return rgba_layers, depth_layers, mask_layers
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
--------------------------------------------------------------------------------
/core/pcd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import sys
17 | sys.path.append('../')
18 | import config
19 | import numpy as np
20 | import torch
21 | import torch.nn as nn
22 | from pytorch3d.renderer import (
23 | PerspectiveCameras,
24 | PointsRasterizationSettings,
25 | PointsRasterizer,
26 | AlphaCompositor,
27 | )
28 |
29 | args = config.config_parser()
30 |
31 |
32 | class PointsRenderer(nn.Module):
33 | """
34 | A class for rendering a batch of points. The class should
35 | be initialized with a rasterizer and compositor class which each have a forward
36 | function.
37 | """
38 |
39 | def __init__(self, rasterizer, compositor) -> None:
40 | super().__init__()
41 | self.rasterizer = rasterizer
42 | self.compositor = compositor
43 |
44 | def to(self, device):
45 | # Manually move to device rasterizer as the cameras
46 | # within the class are not of type nn.Module
47 | self.rasterizer = self.rasterizer.to(device)
48 | self.compositor = self.compositor.to(device)
49 | return self
50 |
51 | def forward(self, point_clouds, **kwargs) -> torch.Tensor:
52 | fragments = self.rasterizer(point_clouds, **kwargs)
53 |
54 | # Construct weights based on the distance of a point to the true point.
55 | # However, this could be done differently: e.g. predicted as opposed
56 | # to a function of the weights.
57 | r = self.rasterizer.raster_settings.radius
58 |
59 | if type(r) == torch.Tensor:
60 | if r.shape[-1] > 1:
61 | idx = fragments.idx.clone()
62 | idx[idx == -1] = 0
63 | r = r[:, idx.squeeze().long()]
64 | r = r.permute(0, 3, 1, 2)
65 |
66 | dists2 = fragments.dists.permute(0, 3, 1, 2)
67 | weights = 1 - dists2 / (r * r)
68 | images = self.compositor(
69 | fragments.idx.long().permute(0, 3, 1, 2),
70 | weights,
71 | point_clouds.features_packed().permute(1, 0),
72 | **kwargs,
73 | )
74 |
75 | # permute so image comes at the end
76 | images = images.permute(0, 2, 3, 1)
77 |
78 | return images
79 |
80 |
81 | def linear_interpolation(data0, data1, time):
82 | return (1. - time) * data0 + time * data1
83 |
84 |
85 | def create_pcd_renderer(h, w, intrinsics, R=None, T=None, radius=None, device="cuda"):
86 | # Initialize a camera.
87 | fx = intrinsics[0, 0]
88 | fy = intrinsics[1, 1]
89 | # print("CREATEPCD")
90 | if R is None:
91 | R = torch.eye(3)[None] # (1, 3, 3)
92 | # R[:, 1, 1] = -1
93 | # R[:, 2, 2] = -1
94 | # print("RNONE")
95 | # print(R)
96 | if T is None:
97 | T = torch.zeros(1, 3) # (1, 3)
98 |
99 | cameras = PerspectiveCameras(R=R, T=T,
100 | device=device,
101 | focal_length=((-fx, -fy),),
102 | principal_point=(tuple(intrinsics[:2, -1]),),
103 | image_size=((h, w),),
104 | in_ndc=False,
105 | )
106 |
107 | # Define the settings for rasterization and shading. Here we set the output image to be of size
108 | # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1
109 | # and blur_radius=0.0. Refer to raster_points.py for explanations of these parameters.
110 | if radius is None:
111 | radius = args.point_radius / min(h, w) * 2.0
112 | if args.vary_pts_radius:
113 | if np.random.choice([0, 1], p=[0.6, 0.4]):
114 | factor = 1 + (0.2 * (np.random.rand() - 0.5))
115 | radius *= factor
116 |
117 | raster_settings = PointsRasterizationSettings(
118 | image_size=(h, w),
119 | radius=radius,
120 | points_per_pixel=8,
121 | )
122 |
123 | # Create a points renderer by compositing points using an alpha compositor (nearer points
124 | # are weighted more heavily). See [1] for an explanation.
125 | rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
126 | renderer = PointsRenderer(
127 | rasterizer=rasterizer,
128 | compositor=AlphaCompositor()
129 | )
130 | return renderer
131 |
--------------------------------------------------------------------------------
/core/renderer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import imageio
18 | import torch.utils.data.distributed
19 | from pytorch3d.structures import Pointclouds
20 | from core.utils import *
21 | from core.depth_layering import get_depth_bins
22 | from core.pcd import linear_interpolation, create_pcd_renderer
23 |
24 |
25 | class ImgRenderer():
26 | def __init__(self, args, model, scene_flow_estimator, inpainter, device):
27 | self.args = args
28 | self.model = model
29 | self.scene_flow_estimator = scene_flow_estimator
30 | self.inpainter = inpainter
31 | self.device = device
32 |
33 | def process_data(self, data):
34 | self.src_img1 = data['src_img1'].to(self.device)
35 | self.src_img2 = data['src_img2'].to(self.device)
36 | assert self.src_img1.shape == self.src_img2.shape
37 | self.h, self.w = self.src_img1.shape[-2:]
38 | self.src_depth1 = data['src_depth1'].to(self.device)
39 | self.src_depth2 = data['src_depth2'].to(self.device)
40 | self.intrinsic1 = data['intrinsic1'].to(self.device)
41 | self.intrinsic2 = data['intrinsic2'].to(self.device)
42 |
43 | self.pose = data['pose'].to(self.device)
44 | self.scale_shift1 = data['scale_shift1'][0]
45 | self.scale_shift2 = data['scale_shift2'][0]
46 | self.is_multi_view = data['multi_view'][0]
47 | self.src_rgb_file1 = data['src_rgb_file1'][0]
48 | self.src_rgb_file2 = data['src_rgb_file2'][0]
49 | if 'tgt_img' in data.keys():
50 | self.tgt_img = data['tgt_img'].to(self.device)
51 | if 'tgt_intrinsic' in data.keys():
52 | self.tgt_intrinsic = data['tgt_intrinsic'].to(self.device)
53 | if 'tgt_pose' in data.keys():
54 | self.tgt_pose = data['tgt_pose'].to(self.device)
55 | if 'time' in data.keys():
56 | self.time = data['time'].item()
57 | if 'src_mask1' in data.keys():
58 | self.src_mask1 = data['src_mask1'].to(self.device)
59 | else:
60 | self.src_mask1 = torch.ones_like(self.src_depth1)
61 | if 'src_mask2' in data.keys():
62 | self.src_mask2 = data['src_mask2'].to(self.device)
63 | else:
64 | self.src_mask2 = torch.ones_like(self.src_depth2)
65 |
66 |
67 |
68 | def process_data_single(self, data):
69 | self.src_img1 = data['src_img1'].to(self.device)
70 | self.h, self.w = self.src_img1.shape[-2:]
71 | self.src_depth1 = data['src_depth1'].to(self.device)
72 | self.intrinsic1 = data['intrinsic1'].to(self.device)
73 |
74 | self.pose = data['pose'].to(self.device)
75 | self.scale_shift1 = data['scale_shift1'][0]
76 | self.is_multi_view = data['multi_view'][0]
77 | self.src_rgb_file1 = data['src_rgb_file1'][0]
78 | if 'tgt_img' in data.keys():
79 | self.tgt_img = data['tgt_img'].to(self.device)
80 | if 'tgt_intrinsic' in data.keys():
81 | self.tgt_intrinsic = data['tgt_intrinsic'].to(self.device)
82 | if 'tgt_pose' in data.keys():
83 | self.tgt_pose = data['tgt_pose'].to(self.device)
84 | if 'time' in data.keys():
85 | self.time = data['time'].item()
86 | if 'src_mask1' in data.keys():
87 | self.src_mask1 = data['src_mask1'].to(self.device)
88 | else:
89 | self.src_mask1 = torch.ones_like(self.src_depth1)
90 |
91 | def feature_extraction(self, rgba_layers, mask_layers, depth_layers):
92 | rgba_layers_in = rgba_layers.squeeze(1)
93 |
94 | if self.args.use_inpainting_mask_for_feature:
95 | rgba_layers_in = torch.cat([rgba_layers_in, mask_layers.squeeze(1)], dim=1)
96 |
97 | if self.args.use_depth_for_feature:
98 | rgba_layers_in = torch.cat([rgba_layers_in, 1. / torch.clamp(depth_layers.squeeze(1), min=1.)], dim=1)
99 | featmaps = self.model.feature_net(rgba_layers_in)
100 | return featmaps
101 |
102 | def apply_scale_shift(self, depth, scale, shift):
103 | disp = 1. / torch.clamp(depth, min=1e-3)
104 | disp = scale * disp + shift
105 | return 1 / torch.clamp(disp, min=1e-3*scale)
106 |
107 | def masked_diffuse(self, x, mask, iter=10, kernel_size=35, median_blur=False):
108 | if median_blur:
109 | x = masked_median_blur(x, mask.repeat(1, x.shape[1], 1, 1), kernel_size=5)
110 | for _ in range(iter):
111 | x, mask = masked_smooth_filter(x, mask, kernel_size=kernel_size)
112 | return x, mask
113 |
114 | def compute_weight_for_two_frame_blending(self, time, disp1, disp2, alpha1, alpha2):
115 | alpha = 4
116 | weight1 = (1 - time) * torch.exp(alpha*disp1) * alpha1
117 | weight2 = time * torch.exp(alpha*disp2) * alpha2
118 | sum_weight = torch.clamp(weight1 + weight2, min=1e-6)
119 | out_weight1 = weight1 / sum_weight
120 | out_weight2 = weight2 / sum_weight
121 | return out_weight1, out_weight2
122 |
123 | def transform_all_pts(self, all_pts, pose):
124 | all_pts_out = []
125 | for pts in all_pts:
126 | pts_out = transform_pts_in_3D(pts, pose)
127 | all_pts_out.append(pts_out)
128 | return all_pts_out
129 |
130 | def render_pcd(self, pts1, pts2, rgbs1, rgbs2, feats1, feats2, mask, side_ids, R=None, t=None, time=0):
131 |
132 | pts = linear_interpolation(pts1, pts2, time)
133 | rgbs = linear_interpolation(rgbs1, rgbs2, time)
134 | feats = linear_interpolation(feats1, feats2, time)
135 | rgb_feat = torch.cat([rgbs, feats], dim=-1)
136 |
137 | num_sides = side_ids.max() + 1
138 | assert num_sides == 1 or num_sides == 2
139 |
140 | if R is None:
141 | R = torch.eye(3, device=self.device)
142 | if t is None:
143 | t = torch.zeros(3, device=self.device)
144 |
145 | pts_ = (R.mm(pts.T) + t.unsqueeze(-1)).T
146 | if self.args.adaptive_pts_radius:
147 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 * pts[..., -1][None] / \
148 | torch.clamp(pts_[..., -1][None], min=1e-6)
149 | else:
150 | radius = self.args.point_radius / min(self.h, self.w) * 2.0
151 |
152 | if self.args.vary_pts_radius and np.random.choice([0, 1], p=[0.6, 0.4]):
153 | if type(radius) == torch.Tensor:
154 | factor = 1 + (0.2 * (torch.rand_like(radius) - 0.5))
155 | else:
156 | factor = 1 + (0.2 * (np.random.rand() - 0.5))
157 | radius *= factor
158 |
159 | if self.args.use_mask_for_decoding:
160 | rgb_feat = torch.cat([rgb_feat, mask], dim=-1)
161 |
162 | if self.args.use_depth_for_decoding:
163 | disp = normalize_0_1(1. / torch.clamp(pts_[..., [-1]], min=1e-6))
164 | rgb_feat = torch.cat([rgb_feat, disp], dim=-1)
165 |
166 | global_out_list = []
167 | direct_color_out_list = []
168 | meta = {}
169 | for j in range(num_sides):
170 | mask_side = side_ids == j
171 | renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3],
172 | radius=radius[:, mask_side] if type(radius) == torch.Tensor else radius)
173 | all_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgb_feat[mask_side]])
174 | global_out_j = renderer(all_pcd_j)
175 | all_colored_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgbs[mask_side]])
176 | direct_rgb_out_j = renderer(all_colored_pcd_j)
177 |
178 | global_out_list.append(global_out_j)
179 | direct_color_out_list.append(direct_rgb_out_j)
180 |
181 | w1, w2 = self.compute_weight_for_two_frame_blending(time,
182 | global_out_list[0][..., [-1]],
183 | global_out_list[-1][..., [-1]],
184 | global_out_list[0][..., [3]],
185 | global_out_list[-1][..., [3]]
186 | )
187 | direct_rgb_out = w1 * direct_color_out_list[0] + w2 * direct_color_out_list[-1]
188 | pred_rgb = self.model.img_decoder(global_out_list[0].permute(0, 3, 1, 2),
189 | global_out_list[-1].permute(0, 3, 1, 2),
190 | time)
191 |
192 | direct_rgb = direct_rgb_out[..., :3].permute(0, 3, 1, 2)
193 | acc = 0.5 * (global_out_list[0][..., [3]] + global_out_list[1][..., [3]]).permute(0, 3, 1, 2)
194 | meta['acc'] = acc
195 | return pred_rgb, direct_rgb, meta
196 |
197 |
198 | def render_pcd_single(self, pts1, rgbs1, feats1, mask, side_ids, R=None, t=None, time=0):
199 |
200 | pts = pts1
201 | rgbs = rgbs1
202 | feats = feats1
203 | rgb_feat = torch.cat([rgbs, feats], dim=-1)
204 |
205 | num_sides = side_ids.max() + 1
206 | assert num_sides == 1 or num_sides == 2
207 |
208 | if R is None:
209 | R = torch.eye(3, device=self.device)
210 | if t is None:
211 | t = torch.zeros(3, device=self.device)
212 |
213 | pts_ = (R.mm(pts.T) + t.unsqueeze(-1)).T
214 | if self.args.adaptive_pts_radius:
215 | radius = self.args.point_radius / min(self.h, self.w) * 2.0 * pts[..., -1][None] / \
216 | torch.clamp(pts_[..., -1][None], min=1e-6)
217 | else:
218 | radius = self.args.point_radius / min(self.h, self.w) * 2.0
219 |
220 | if self.args.vary_pts_radius and np.random.choice([0, 1], p=[0.6, 0.4]):
221 | if type(radius) == torch.Tensor:
222 | factor = 1 + (0.2 * (torch.rand_like(radius) - 0.5))
223 | else:
224 | factor = 1 + (0.2 * (np.random.rand() - 0.5))
225 | radius *= factor
226 |
227 | if self.args.use_mask_for_decoding:
228 | rgb_feat = torch.cat([rgb_feat, mask], dim=-1)
229 |
230 | if self.args.use_depth_for_decoding:
231 | disp = normalize_0_1(1. / torch.clamp(pts_[..., [-1]], min=1e-6))
232 | rgb_feat = torch.cat([rgb_feat, disp], dim=-1)
233 |
234 | global_out_list = []
235 | direct_color_out_list = []
236 | meta = {}
237 | for j in range(num_sides):
238 | mask_side = side_ids == j
239 | renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3],
240 | radius=radius[:, mask_side] if type(radius) == torch.Tensor else radius)
241 | all_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgb_feat[mask_side]])
242 | global_out_j = renderer(all_pcd_j)
243 | all_colored_pcd_j = Pointclouds(points=[pts_[mask_side]], features=[rgbs[mask_side]])
244 | direct_rgb_out_j = renderer(all_colored_pcd_j)
245 |
246 | global_out_list.append(global_out_j)
247 | direct_color_out_list.append(direct_rgb_out_j)
248 |
249 | direct_rgb_out = direct_color_out_list[0]
250 | pred_rgb = self.model.img_decoder(global_out_list[0].permute(0, 3, 1, 2),
251 | global_out_list[0].permute(0, 3, 1, 2),
252 | time=0)
253 |
254 | direct_rgb = direct_rgb_out[..., :3].permute(0, 3, 1, 2)
255 | acc = (global_out_list[0][..., [3]]).permute(0, 3, 1, 2)
256 | meta['acc'] = acc
257 | return pred_rgb, direct_rgb, meta
258 |
259 |
260 | def get_reprojection_mask(self, pts, R, t):
261 | pts1_ = (R.mm(pts.T) + t.unsqueeze(-1)).T
262 | mask1 = torch.ones_like(self.src_img1[:, :1].reshape(-1, 1))
263 | mask_renderer = create_pcd_renderer(self.h, self.w, self.tgt_intrinsic.squeeze()[:3, :3],
264 | radius=1.0 / min(self.h, self.w) * 4.)
265 | mask_pcd = Pointclouds(points=[pts1_], features=[mask1])
266 | mask = mask_renderer(mask_pcd).permute(0, 3, 1, 2)
267 | mask = F.max_pool2d(mask, kernel_size=7, stride=1, padding=3)
268 | return mask
269 |
270 | def get_cropping_ids(self, mask):
271 | assert mask.shape[:2] == (1, 1)
272 | mask = mask.squeeze()
273 | h, w = mask.shape
274 | mask_mean_x_axis = mask.mean(dim=0)
275 | x_valid = torch.nonzero(mask_mean_x_axis > 0.5)
276 | bad = False
277 | if len(x_valid) < 0.75 * w:
278 | left, right = 0, w - 1 # invalid
279 | bad = True
280 | else:
281 | left, right = x_valid[0][0], x_valid[-1][0]
282 | mask_mean_y_axis = mask.mean(dim=1)
283 | y_valid = torch.nonzero(mask_mean_y_axis > 0.5)
284 | if len(y_valid) < 0.75 * h:
285 | top, bottom = 0, h - 1 # invalid
286 | bad = True
287 | else:
288 | top, bottom = y_valid[0][0], y_valid[-1][0]
289 | assert 0 <= top <= h - 1 and 0 <= bottom <= h - 1 and 0 <= left <= w - 1 and 0 <= right <= w - 1
290 | return top, bottom, left, right, bad
291 |
292 | def render_depth_from_mdi(self, depth_layers, alpha_layers):
293 | '''
294 | :param depth_layers: [n_layers, 1, h, w]
295 | :param alpha_layers: [n_layers, 1, h, w]
296 | :return: rendered depth [1, 1, h, w]
297 | '''
298 | num_layers = len(depth_layers)
299 | h, w = depth_layers.shape[-2:]
300 | layer_id = torch.arange(num_layers, device=self.device).float()
301 | layer_id_maps = layer_id[..., None, None, None, None].repeat(1, 1, 1, h, w)
302 | T = torch.cumprod(1. - alpha_layers, dim=0)[:-1]
303 | T = torch.cat([torch.ones_like(T[:1]), T], dim=0)
304 | weights = alpha_layers * T
305 | depth_map = torch.sum(weights * depth_layers, dim=0)
306 | depth_map = torch.clamp(depth_map, min=1.)
307 | layer_id_map = torch.sum(weights * layer_id_maps, dim=0)
308 | return depth_map, layer_id_map
309 |
310 | def render_rgbda_layers_from_one_view(self, return_pts=False):
311 | depth_bins = get_depth_bins(depth=self.src_depth1)
312 | rgba_layers, depth_layers, mask_layers = \
313 | self.inpainter.sequential_inpainting(self.src_img1, self.src_depth1, depth_bins)
314 | coord1 = get_coord_grids_pt(self.h, self.w, device=self.device).float()
315 | src_depth1 = self.apply_scale_shift(self.src_depth1, self.scale_shift1[0], self.scale_shift1[1])
316 | pts1 = unproject_pts_pt(self.intrinsic1, coord1.reshape(-1, 2), src_depth1.flatten())
317 |
318 | featmaps = self.feature_extraction(rgba_layers, mask_layers, depth_layers)
319 | depth_layers = self.apply_scale_shift(depth_layers, self.scale_shift1[0], self.scale_shift1[1])
320 | num_layers = len(rgba_layers)
321 | all_pts = []
322 | all_rgbas = []
323 | all_feats = []
324 | all_masks = []
325 | for i in range(num_layers):
326 | alpha_i = rgba_layers[i][:, -1] > 0.5
327 | rgba_i = rgba_layers[i]
328 | mask_i = mask_layers[i]
329 | featmap = featmaps[i][None]
330 | featmap = F.interpolate(featmap, size=(self.h, self.w), mode='bilinear', align_corners=True)
331 | pts1_i = unproject_pts_pt(self.intrinsic1, coord1.reshape(-1, 2), depth_layers[i].flatten())
332 | pts1_i = pts1_i.reshape(1, self.h, self.w, 3)
333 | all_pts.append(pts1_i[alpha_i])
334 | all_rgbas.append(rgba_i.permute(0, 2, 3, 1)[alpha_i])
335 | all_feats.append(featmap.permute(0, 2, 3, 1)[alpha_i])
336 | all_masks.append(mask_i.permute(0, 2, 3, 1)[alpha_i])
337 |
338 | all_pts = torch.cat(all_pts)
339 | all_rgbas = torch.cat(all_rgbas)
340 | all_feats = torch.cat(all_feats)
341 | all_masks = torch.cat(all_masks)
342 | all_side_ids = torch.zeros_like(all_masks.squeeze(), dtype=torch.long)
343 |
344 |
345 | if return_pts:
346 | return all_pts, all_rgbas, all_feats, \
347 | all_masks, all_side_ids
348 |
349 | R = self.tgt_pose[0, :3, :3]
350 | t = self.tgt_pose[0, :3, 3]
351 |
352 | pred_img, direct_rgb_out, meta = self.render_pcd(all_pts, all_pts,
353 | all_rgbas, all_rgbas,
354 | all_feats, all_feats,
355 | all_masks, all_side_ids,
356 | R, t, 0)
357 |
358 | mask = self.get_reprojection_mask(pts1, R, t)
359 | t, b, l, r, bad = self.get_cropping_ids(mask)
360 | skip = False
361 | if not skip and not self.args.eval_mode:
362 | pred_img = pred_img[:, :, t:b, l:r]
363 | mask = mask[:, :, t:b, l:r]
364 | direct_rgb_out = direct_rgb_out[:, :, t:b, l:r]
365 | else:
366 | skip = True
367 |
368 | res_dict = {
369 | 'src_img1': self.src_img1,
370 | 'pred_img': pred_img,
371 | 'mask': mask,
372 | 'direct_rgb_out': direct_rgb_out,
373 | 'skip': skip
374 | }
375 | return res_dict
376 |
377 | def compute_scene_flow_one_side(self, coord, pose,
378 | rgb1, rgb2,
379 | rgba_layers1, rgba_layers2,
380 | featmaps1, featmaps2,
381 | pts1, pts2,
382 | depth_layers1, depth_layers2,
383 | mask_layers1, mask_layers2,
384 | flow_f, flow_b, kernel,
385 | with_inpainted=False):
386 |
387 | num_layers1 = len(rgba_layers1)
388 | pts2 = transform_pts_in_3D(pts2, pose).T.reshape(1, 3, self.h, self.w)
389 |
390 | mask_mutual_flow = self.scene_flow_estimator.get_mutual_matches(flow_f, flow_b, th=5, return_mask=True).float()
391 | mask_mutual_flow = mask_mutual_flow.unsqueeze(1)
392 |
393 | coord1_corsp = coord + flow_f
394 | coord1_corsp_normed = normalize_for_grid_sample(coord1_corsp, self.h, self.w)
395 | pts2_sampled = F.grid_sample(pts2, coord1_corsp_normed, align_corners=True,
396 | mode='nearest', padding_mode="border")
397 | depth2_sampled = pts2_sampled[:, -1:]
398 |
399 | rgb2_sampled = F.grid_sample(rgb2, coord1_corsp_normed, align_corners=True, padding_mode="border")
400 | mask_layers2_ds = F.interpolate(mask_layers2.squeeze(1), size=featmaps2.shape[-2:], mode='area')
401 | featmap2 = torch.sum(featmaps2 * mask_layers2_ds, dim=0, keepdim=True)
402 | context2 = torch.sum(mask_layers2_ds, dim=0, keepdim=True)
403 | featmap2_sampled = F.grid_sample(featmap2, coord1_corsp_normed, align_corners=True, padding_mode="border")
404 | context2_sampled = F.grid_sample(context2, coord1_corsp_normed, align_corners=True, padding_mode="border")
405 | mask2_sampled = F.grid_sample(self.src_mask2, coord1_corsp_normed, align_corners=True, padding_mode="border")
406 |
407 | featmap2_sampled = featmap2_sampled / torch.clamp(context2_sampled, min=1e-6)
408 | context2_sampled = (context2_sampled > 0.5).float()
409 | last_pts2_i = torch.zeros_like(pts2.permute(0, 2, 3, 1))
410 | last_alpha_i = torch.zeros_like(rgba_layers1[0][:, -1], dtype=torch.bool)
411 |
412 | all_pts = []
413 | all_rgbas = []
414 | all_feats = []
415 | all_rgbas_end = []
416 | all_feats_end = []
417 | all_masks = []
418 | all_pts_end = []
419 | all_optical_flows = []
420 | for i in range(num_layers1):
421 | alpha_i = (rgba_layers1[i][:, -1]*self.src_mask1.squeeze(1)*mask2_sampled.squeeze(1)) > 0.5
422 | rgba_i = rgba_layers1[i]
423 | mask_i = mask_layers1[i]
424 | mask_no_mutual_flow = mask_i * context2_sampled
425 | mask_gau_i = mask_no_mutual_flow * mask_mutual_flow
426 | mask_no_mutual_flow = erosion(mask_no_mutual_flow, kernel)
427 | mask_gau_i = erosion(mask_gau_i, kernel)
428 |
429 | featmap1 = featmaps1[i][None]
430 | featmap1 = F.interpolate(featmap1, size=(self.h, self.w), mode='bilinear', align_corners=True)
431 | pts1_i = unproject_pts_pt(self.intrinsic1, coord.reshape(-1, 2), depth_layers1[i].flatten())
432 | pts1_i = pts1_i.reshape(1, self.h, self.w, 3)
433 |
434 | flow_inpainted, mask_no_mutual_flow_ = self.masked_diffuse(flow_f.permute(0, 3, 1, 2),
435 | mask_no_mutual_flow,
436 | kernel_size=15, iter=7)
437 |
438 | coord_inpainted = coord.clone()
439 | coord_inpainted_ = coord + flow_inpainted.permute(0, 2, 3, 1)
440 | mask_no_mutual_flow_bool = (mask_no_mutual_flow_ > 1e-6).squeeze(1)
441 | coord_inpainted[mask_no_mutual_flow_bool] = coord_inpainted_[mask_no_mutual_flow_bool]
442 |
443 | depth_inpainted = depth_layers1[i].clone()
444 | depth_inpainted_, mask_gau_i_ = self.masked_diffuse(depth2_sampled, mask_gau_i,
445 | kernel_size=15, iter=7)
446 | mask_gau_i_bool = (mask_gau_i_ > 1e-6).squeeze(1)
447 | depth_inpainted.squeeze(1)[mask_gau_i_bool] = depth_inpainted_.squeeze(1)[mask_gau_i_bool]
448 | pts2_i = unproject_pts_pt(self.intrinsic2, coord_inpainted.contiguous().reshape(-1, 2),
449 | depth_inpainted.flatten()).reshape(1, self.h, self.w, 3)
450 |
451 | if i > 0:
452 | mask_wrong_ordering = (pts2_i[..., -1] <= last_pts2_i[..., -1]) * last_alpha_i
453 | pts2_i[mask_wrong_ordering] = last_pts2_i[mask_wrong_ordering] * 1.01
454 |
455 | rgba_end = mask_gau_i * torch.cat([rgb2_sampled, mask_gau_i], dim=1) + (1 - mask_gau_i) * rgba_i
456 | feat_end = mask_gau_i * featmap2_sampled + (1 - mask_gau_i) * featmap1
457 | last_alpha_i[alpha_i] = True
458 | last_pts2_i[alpha_i] = pts2_i[alpha_i]
459 |
460 | if with_inpainted:
461 | mask_keep = alpha_i
462 | else:
463 | mask_keep = mask_i.squeeze(1).bool()
464 |
465 | all_pts.append(pts1_i[mask_keep])
466 | all_rgbas.append(rgba_i.permute(0, 2, 3, 1)[mask_keep])
467 | all_feats.append(featmap1.permute(0, 2, 3, 1)[mask_keep])
468 | all_masks.append(mask_i.permute(0, 2, 3, 1)[mask_keep])
469 | all_pts_end.append(pts2_i[mask_keep])
470 | all_rgbas_end.append(rgba_end.permute(0, 2, 3, 1)[mask_keep])
471 | all_feats_end.append(feat_end.permute(0, 2, 3, 1)[mask_keep])
472 | all_optical_flows.append(flow_inpainted.permute(0, 2, 3, 1)[mask_keep])
473 |
474 | return all_pts, all_pts_end, all_rgbas, all_rgbas_end, all_feats, all_feats_end, all_masks, all_optical_flows
475 |
476 | def render_rgbda_layers_with_scene_flow(self, return_pts=False):
477 | kernel = torch.ones(5, 5, device=self.device)
478 | flow_f = self.scene_flow_estimator.compute_optical_flow(self.src_img1, self.src_img2)
479 | flow_b = self.scene_flow_estimator.compute_optical_flow(self.src_img2, self.src_img1)
480 |
481 | depth_bins1 = get_depth_bins(depth=self.src_depth1)
482 | depth_bins2 = get_depth_bins(depth=self.src_depth2)
483 |
484 | rgba_layers1, depth_layers1, mask_layers1 = \
485 | self.inpainter.sequential_inpainting(self.src_img1, self.src_depth1, depth_bins1)
486 | rgba_layers2, depth_layers2, mask_layers2 = \
487 | self.inpainter.sequential_inpainting(self.src_img2, self.src_depth2, depth_bins2)
488 | if self.args.visualize_rgbda_layers:
489 | self.save_rgbda_layers(self.src_rgb_file1, rgba_layers1, depth_layers1, mask_layers1)
490 | self.save_rgbda_layers(self.src_rgb_file2, rgba_layers2, depth_layers2, mask_layers2)
491 |
492 | featmaps1 = self.feature_extraction(rgba_layers1, mask_layers1, depth_layers1)
493 | featmaps2 = self.feature_extraction(rgba_layers2, mask_layers2, depth_layers2)
494 |
495 | depth_layers1 = self.apply_scale_shift(depth_layers1, self.scale_shift1[0], self.scale_shift1[1])
496 | depth_layers2 = self.apply_scale_shift(depth_layers2, self.scale_shift2[0], self.scale_shift2[1])
497 |
498 | processed_depth1, layer_id_map1 = self.render_depth_from_mdi(depth_layers1, rgba_layers1[:, :, -1:])
499 | processed_depth2, layer_id_map2 = self.render_depth_from_mdi(depth_layers2, rgba_layers2[:, :, -1:])
500 |
501 | assert self.src_img1.shape[-2:] == self.src_img2.shape[-2:]
502 | h, w = self.src_img1.shape[-2:]
503 | coord = get_coord_grids_pt(h, w, device=self.device).float()[None]
504 | pts1 = unproject_pts_pt(self.intrinsic1, coord.reshape(-1, 2), processed_depth1.flatten())
505 | pts2 = unproject_pts_pt(self.intrinsic2, coord.reshape(-1, 2), processed_depth2.flatten())
506 |
507 | all_pts_11, all_pts_12, all_rgbas_11, all_rgbas_12, all_feats_11, all_feats_12,\
508 | all_masks_1, all_optical_flow_1 = \
509 | self.compute_scene_flow_one_side(coord, torch.inverse(self.pose), self.src_img1, self.src_img2,
510 | rgba_layers1, rgba_layers2, featmaps1, featmaps2,
511 | pts1, pts2, depth_layers1, depth_layers2, mask_layers1, mask_layers2,
512 | flow_f, flow_b, kernel, with_inpainted=True)
513 |
514 | all_pts_22, all_pts_21, all_rgbas_22, all_rgbas_21, all_feats_22, all_feats_21,\
515 | all_masks_2, all_optical_flow_2 = \
516 | self.compute_scene_flow_one_side(coord, self.pose, self.src_img2, self.src_img1,
517 | rgba_layers2, rgba_layers1, featmaps2, featmaps1,
518 | pts2, pts1, depth_layers2, depth_layers1, mask_layers2, mask_layers1,
519 | flow_b, flow_f, kernel, with_inpainted=True)
520 |
521 | if not torch.allclose(self.pose, torch.eye(4, device=self.device)):
522 | all_pts_21 = self.transform_all_pts(all_pts_21, torch.inverse(self.pose))
523 | all_pts_22 = self.transform_all_pts(all_pts_22, torch.inverse(self.pose))
524 |
525 | all_pts = torch.cat(all_pts_11+all_pts_21)
526 | all_rgbas = torch.cat(all_rgbas_11+all_rgbas_21)
527 | all_feats = torch.cat(all_feats_11+all_feats_21)
528 | all_masks = torch.cat(all_masks_1+all_masks_2)
529 | all_pts_end = torch.cat(all_pts_12+all_pts_22)
530 | all_rgbas_end = torch.cat(all_rgbas_12+all_rgbas_22)
531 | all_feats_end = torch.cat(all_feats_12+all_feats_22)
532 | all_side_ids = torch.zeros_like(all_masks.squeeze(), dtype=torch.long)
533 | num_pts_2 = sum([len(x) for x in all_pts_21])
534 | all_side_ids[-num_pts_2:] = 1
535 | all_optical_flow = torch.cat(all_optical_flow_1+all_optical_flow_2)
536 |
537 | if return_pts:
538 | return all_pts, all_pts_end, all_rgbas, all_rgbas_end, all_feats, all_feats_end, \
539 | all_masks, all_side_ids, all_optical_flow
540 |
541 | R = self.tgt_pose[0, :3, :3]
542 | t = self.tgt_pose[0, :3, 3]
543 | pred_img, direct_rgb_out, meta = self.render_pcd(all_pts, all_pts_end,
544 | all_rgbas, all_rgbas_end,
545 | all_feats, all_feats_end,
546 | all_masks, all_side_ids,
547 | R, t, self.time)
548 | mask1 = self.get_reprojection_mask(pts1, R, t)
549 | pose2_to_tgt = self.tgt_pose.bmm(torch.inverse(self.pose))
550 | mask2 = self.get_reprojection_mask(pts2, pose2_to_tgt[0, :3, :3], pose2_to_tgt[0, :3, 3])
551 | mask = (mask1+mask2) * 0.5
552 | gt_img = self.tgt_img
553 | t, b, l, r, bad = self.get_cropping_ids(mask)
554 | skip = False
555 | if not skip and not self.args.eval_mode:
556 | pred_img = pred_img[:, :, t:b, l:r]
557 | mask = mask[:, :, t:b, l:r]
558 | direct_rgb_out = direct_rgb_out[:, :, t:b, l:r]
559 | gt_img = gt_img[:, :, t:b, l:r]
560 | else:
561 | skip = True
562 |
563 | res_dict = {
564 | 'src_img1': self.src_img1,
565 | 'src_img2': self.src_img2,
566 | 'pred_img': pred_img,
567 | 'gt_img': gt_img,
568 | 'mask': mask,
569 | 'direct_rgb_out': direct_rgb_out,
570 | 'alpha_layers1': rgba_layers1[:, :, [-1]],
571 | 'alpha_layers2': rgba_layers2[:, :, [-1]],
572 | 'mask_layers1': mask_layers1,
573 | 'mask_layers2': mask_layers2,
574 | 'skip': skip
575 | }
576 | return res_dict
577 |
578 | def dynamic_view_synthesis_with_inpainting(self):
579 | if self.is_multi_view:
580 | return self.render_rgbda_layers_from_one_view()
581 | else:
582 | return self.render_rgbda_layers_with_scene_flow()
583 |
584 | def get_prediction(self, data):
585 | # process data first
586 | self.process_data(data)
587 | return self.dynamic_view_synthesis_with_inpainting()
588 |
589 | def save_rgbda_layers(self, src_rgb_file, rgba_layers, depth_layers, mask_layers):
590 | frame_id = os.path.basename(src_rgb_file).split('.')[0]
591 | scene_id = src_rgb_file.split('/')[-3]
592 | out_dir = os.path.join(self.args.rootdir, 'out', self.args.expname, 'vis',
593 | '{}-{}'.format(scene_id, frame_id))
594 | os.makedirs(out_dir, exist_ok=True)
595 |
596 | alpha_layers = rgba_layers[:, :, [-1]]
597 | for i, rgba_layer in enumerate(rgba_layers):
598 | save_filename = os.path.join(out_dir, 'rgb_original_{}.png'.format(i))
599 | rgba_layer_ = rgba_layer * mask_layers[i]
600 | rgba_np = rgba_layer_.detach().squeeze().permute(1, 2, 0).cpu().numpy()
601 | imageio.imwrite(save_filename, float2uint8(rgba_np))
602 |
603 | for i, rgba_layer in enumerate(rgba_layers):
604 | save_filename = os.path.join(out_dir, 'rgb_{}.png'.format(i))
605 | rgba_np = rgba_layer.detach().squeeze().permute(1, 2, 0).cpu().numpy()
606 | imageio.imwrite(save_filename, float2uint8(rgba_np))
607 |
608 | for i, depth_layer in enumerate(depth_layers):
609 | save_filename = os.path.join(out_dir, 'disparity_original_{}.png'.format(i))
610 | disparity = (1. / torch.clamp(depth_layer, min=1e-6)) * alpha_layers[i]
611 | disparity = torch.cat([disparity, disparity, disparity, alpha_layers[i]*mask_layers[i]], dim=1)
612 | disparity_np = disparity.detach().squeeze().cpu().numpy().transpose(1, 2, 0)
613 | imageio.imwrite(save_filename, float2uint8(disparity_np))
614 |
615 | for i, depth_layer in enumerate(depth_layers):
616 | save_filename = os.path.join(out_dir, 'disparity_{}.png'.format(i))
617 | disparity = (1. / torch.clamp(depth_layer, min=1e-6)) * alpha_layers[i]
618 | disparity = torch.cat([disparity, disparity, disparity, alpha_layers[i]], dim=1)
619 | disparity_np = disparity.detach().squeeze().cpu().numpy().transpose(1, 2, 0)
620 | imageio.imwrite(save_filename, float2uint8(disparity_np))
621 |
622 | for i, mask_layer in enumerate(mask_layers):
623 | save_filename = os.path.join(out_dir, 'mask_{}.png'.format(i))
624 | tri_mask = 0.5 * alpha_layers[i] + 0.5 * mask_layer
625 | tri_mask_np = tri_mask.detach().squeeze().cpu().numpy()
626 | imageio.imwrite(save_filename, float2uint8(tri_mask_np))
627 |
628 |
629 |
630 |
--------------------------------------------------------------------------------
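The disparity export in `save_rgbda_layers` above reduces to inverting depth, masking by alpha, and stacking the result into a 4-channel image. Below is a minimal stand-alone sketch of that conversion with made-up tensor sizes; `float2uint8` is re-declared here to mirror the helper in `core/utils.py`.

```python
import numpy as np
import torch
import imageio


def float2uint8(x):
    # same conversion as core/utils.py
    return (255. * x).astype(np.uint8)


# hypothetical single layer: depth and alpha are [1, 1, H, W] tensors
depth_layer = torch.rand(1, 1, 64, 64) + 1.0   # keep depth >= 1 so disparity stays in [0, 1]
alpha_layer = torch.ones(1, 1, 64, 64)

# invert depth to disparity and zero out fully transparent pixels
disparity = (1. / torch.clamp(depth_layer, min=1e-6)) * alpha_layer
# replicate to three channels and append alpha, matching the PNGs written above
rgba = torch.cat([disparity, disparity, disparity, alpha_layer], dim=1)
rgba_np = rgba.squeeze(0).permute(1, 2, 0).cpu().numpy()
imageio.imwrite('disparity_0.png', float2uint8(rgba_np))
```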
/core/scene_flow.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from third_party.RAFT.core.utils.utils import InputPadder
17 | from core.utils import *
18 |
19 |
20 | class SceneFlowEstimator():
21 | def __init__(self, args, model):
22 | device = "cuda:{}".format(args.local_rank)
23 | self.device = device
24 | self.raft_model = model
25 | self.train_raft = args.train_raft
26 |
27 | def compute_optical_flow(self, img1, img2, return_np_array=False):
28 | '''
29 | :param img1: [B, 3, H, W]
30 | :param img2: [B, 3, H, W]
31 | :return: optical_flow, [B, H, W, 2]
32 | '''
33 | if not self.train_raft:
34 | with torch.no_grad():
35 | assert img1.max() <= 1 and img2.max() <= 1
36 | image1 = img1 * 255.
37 | image2 = img2 * 255.
38 | padder = InputPadder(image1.shape)
39 | image1, image2 = padder.pad(image1, image2)
40 |
41 | flow_low, flow_up = self.raft_model.module(image1, image2, iters=20, test_mode=True, padder=padder)
42 |
43 | if return_np_array:
44 | return flow_up.cpu().numpy().transpose(0, 2, 3, 1)
45 |
46 | return flow_up.permute(0, 2, 3, 1).detach() # [B, h, w, 2]
47 | else:
48 | assert img1.max() <= 1 and img2.max() <= 1
49 | image1 = img1 * 255.
50 | image2 = img2 * 255.
51 | padder = InputPadder(image1.shape)
52 | image1, image2 = padder.pad(image1, image2)
53 | flow_predictions = self.raft_model.module(image1, image2, iters=20, padder=padder)
54 | return flow_predictions[-1].permute(0, 2, 3, 1) # [B, h, w, 2]
55 |
56 | def get_mutual_matches(self, flow_f, flow_b, th=2., return_mask=False):
57 | assert flow_f.shape == flow_b.shape
58 | batch_size = flow_f.shape[0]
59 | assert flow_f.shape[1:3] == flow_b.shape[1:3]
60 | h, w = flow_f.shape[1:3]
61 | grid = get_coord_grids_pt(h, w, self.device)[None].float().repeat(batch_size, 1, 1, 1) # [B, h, w, 2]
62 | grid2 = grid + flow_f
63 | mask_boundary = (grid2[..., 0] >= 0) * (grid2[..., 0] <= w - 1) * \
64 | (grid2[..., 1] >= 0) * (grid2[..., 1] <= h - 1)
65 | grid2_normed = normalize_for_grid_sample(grid2, h, w)
66 | flow_b_sampled = F.grid_sample(flow_b.permute(0, 3, 1, 2), grid2_normed,
67 | align_corners=True).permute(0, 2, 3, 1)
68 | grid1 = grid2 + flow_b_sampled
69 | mask_boundary *= (grid1[..., 0] >= 0) * (grid1[..., 0] <= w - 1) * \
70 | (grid1[..., 1] >= 0) * (grid1[..., 1] <= h - 1)
71 |
72 | fb_map = flow_f + flow_b_sampled
73 | mask_valid = mask_boundary * (torch.norm(fb_map, dim=-1) < th)
74 | if return_mask:
75 | return mask_valid
76 | coords1 = grid[mask_valid] # [n, 2]
77 | coords2 = grid2[mask_valid] # [n, 2]
78 | return coords1, coords2
--------------------------------------------------------------------------------
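A hedged usage sketch for `SceneFlowEstimator`: it assumes `raft_model` is an already-loaded RAFT network wrapped in `DataParallel`/`DistributedDataParallel` (so `.module` exists and accepts the `padder` keyword used above) and that a CUDA device is available; the image sizes are made up.

```python
import torch
from argparse import Namespace
from core.scene_flow import SceneFlowEstimator

# raft_model is assumed to be created and wrapped elsewhere (see third_party/RAFT)
args = Namespace(local_rank=0, train_raft=False)
estimator = SceneFlowEstimator(args, raft_model)

img1 = torch.rand(1, 3, 256, 448, device='cuda')  # values in [0, 1], as asserted above
img2 = torch.rand(1, 3, 256, 448, device='cuda')

flow_f = estimator.compute_optical_flow(img1, img2)  # [1, H, W, 2] forward flow
flow_b = estimator.compute_optical_flow(img2, img1)  # [1, H, W, 2] backward flow

# keep only pixels whose forward/backward flows agree within 2 pixels
coords1, coords2 = estimator.get_mutual_matches(flow_f, flow_b, th=2.)
```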
/core/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Tuple
17 | import numpy as np
18 | import torch
19 | import torch.nn.functional as F
20 | from kornia.filters import gaussian_blur2d, box_blur, median_blur
21 | from kornia.filters.kernels import get_binary_kernel2d
22 | from kornia.morphology import erosion
23 | from scipy.interpolate import interp1d
24 | from scipy.ndimage import median_filter
25 |
26 |
27 | def float2uint8(x):
28 | return (255. * x).astype(np.uint8)
29 |
30 |
31 | def float2uint16(x):
32 | return (65535 * x).astype(np.uint16)
33 |
34 |
35 | def normalize_0_1(x):
36 | x_min, x_max = x.min(), x.max()
37 | return (x - x_min) / (x_max - x_min)
38 |
39 |
40 | def homogenize_np(coord):
41 | """
42 | append ones in the last dimension
43 | :param coord: [...., 2/3]
44 |     :return: homogeneous coordinates
45 | """
46 | return np.concatenate([coord, np.ones_like(coord[..., :1])], axis=-1)
47 |
48 |
49 | def homogenize_pt(coord):
50 | return torch.cat([coord, torch.ones_like(coord[..., :1])], dim=-1)
51 |
52 |
53 | def get_coord_grids_pt(h, w, device, homogeneous=False):
54 | """
55 | create pxiel coordinate grid
56 | :param h: height
57 | :param w: weight
58 | :param device: device
59 | :param homogeneous: if homogeneous coordinate
60 | :return: coordinates [h, w, 2]
61 | """
62 | y = torch.arange(0, h).to(device)
63 | x = torch.arange(0, w).to(device)
64 | grid_y, grid_x = torch.meshgrid(y, x)
65 | if homogeneous:
66 | return torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)
67 | return torch.stack([grid_x, grid_y], dim=-1) # [h, w, 2]
68 |
69 |
70 | def normalize_for_grid_sample(coords, h, w):
71 | device = coords.device
72 | coords_normed = coords / torch.tensor([w-1., h-1.]).to(device) * 2. - 1.
73 | return coords_normed
74 |
75 |
76 | def unproject_pts_np(intrinsics, coords, depth):
77 | if coords.shape[-1] == 2:
78 | coords = homogenize_np(coords)
79 | intrinsics = intrinsics.squeeze()[:3, :3]
80 | coords = np.linalg.inv(intrinsics).dot(coords.T) * depth.reshape(1, -1)
81 | return coords.T # [n, 3]
82 |
83 |
84 | def unproject_pts_pt(intrinsics, coords, depth, normalize_depth=False):
85 | if coords.shape[-1] == 2:
86 | coords = homogenize_pt(coords)
87 | intrinsics = intrinsics.squeeze()[:3, :3]
88 | coords = torch.inverse(intrinsics).mm(coords.T)
89 | if normalize_depth:
90 | coords = F.normalize(coords, dim=0, p=2.0)
91 |         # directions are unit-normalized, so depth below is applied as distance along each ray
92 | coords = coords * depth.reshape(1, -1)
93 | return coords.T # [n, 3]
94 |
95 |
96 | def pixel2cam(depth, pixel_coords, intrinsics, is_homogeneous=True):
97 | """Transforms coordinates in the pixel frame to the camera frame.
98 | Args:
99 | depth: [batch, height, width]
100 | pixel_coords: homogeneous pixel coordinates [batch, 3, height, width]
101 | intrinsics: camera intrinsics [batch, 3, 3]
102 | is_homogeneous: return in homogeneous coordinates
103 | Returns:
104 | Coords in the camera frame [batch, 3 (4 if homogeneous), height, width]
105 | """
106 | if depth.ndim == 4:
107 | assert depth.shape[1] == 1
108 | depth = depth.squeeze(1)
109 | batch, height, width = depth.shape
110 | depth = depth.reshape(batch, 1, -1)
111 | pixel_coords = pixel_coords.reshape(batch, 3, -1)
112 | cam_coords = torch.inverse(intrinsics).bmm(pixel_coords) * depth
113 | if is_homogeneous:
114 | ones = torch.ones_like(depth)
115 | cam_coords = torch.cat([cam_coords, ones], dim=1)
116 | cam_coords = cam_coords.reshape(batch, -1, height, width)
117 | return cam_coords
118 |
119 |
120 | def transform_pts_in_3D(pts, pose, return_homogeneous=False):
121 | '''
122 | :param pts: nx3, tensor
123 | :param pose: 4x4, tensor
124 | :return: nx3 or nx4, tensor
125 | '''
126 | pts_h = homogenize_pt(pts)
127 | pose = pose.squeeze()
128 | assert pose.shape == (4, 4)
129 | transformed_pts_h = pose.mm(pts_h.T).T # [n, 4]
130 | if return_homogeneous:
131 | return transformed_pts_h
132 | return transformed_pts_h[..., :3]
133 |
134 |
135 | def crop_boundary(x, ratio):
136 | h, w = x.shape[-2:]
137 | crop_h = int(h * ratio)
138 | crop_w = int(w * ratio)
139 | return x[:, :, crop_h:h-crop_h, crop_w:w-crop_w]
140 |
141 |
142 | def masked_smooth_filter(x, mask, kernel_size=9, sigma=1):
143 | '''
144 | :param x: [B, n, h, w]
145 | :param mask: [B, 1, h, w]
146 | :return: [B, n, h, w]
147 | '''
148 | x_ = x * mask
149 | x_ = box_blur(x_, (kernel_size, kernel_size), border_type='constant')
150 | mask_ = box_blur(mask, (kernel_size, kernel_size), border_type='constant')
151 | x_ = x_ / torch.clamp(mask_, min=1e-6)
152 | mask_bool = (mask.repeat(1, x.shape[1], 1, 1) > 1e-6).float()
153 | out = mask_bool * x + (1. - mask_bool) * x_
154 | return out, mask_
155 |
156 |
157 | def remove_noise_in_dpt_disparity(disparity, kernel_size=5):
158 | return median_filter(disparity, size=kernel_size)
159 |
160 |
161 | def _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:
162 | r"""Utility function that computes zero padding tuple."""
163 | computed: List[int] = [(k - 1) // 2 for k in kernel_size]
164 | return computed[0], computed[1]
165 |
166 |
167 | def masked_median_blur(input, mask, kernel_size=9):
168 | assert input.shape == mask.shape
169 | if not isinstance(input, torch.Tensor):
170 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
171 |
172 | if not len(input.shape) == 4:
173 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
174 |
175 | padding: Tuple[int, int] = _compute_zero_padding((kernel_size, kernel_size))
176 |
177 | # prepare kernel
178 | kernel: torch.Tensor = get_binary_kernel2d((kernel_size, kernel_size)).to(input)
179 | b, c, h, w = input.shape
180 |
181 | # map the local window to single vector
182 | features: torch.Tensor = F.conv2d(input.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1)
183 | masks: torch.Tensor = F.conv2d(mask.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1)
184 |     features = features.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2)  # BxCxHxWx(K_h * K_w)
185 | min_value, max_value = features.min(), features.max()
186 | masks = masks.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2) # BxCxHxWx(K_h * K_w)
187 | index_invalid = (1 - masks).nonzero(as_tuple=True)
188 | index_b, index_c, index_h, index_w, index_k = index_invalid
189 | features[(index_b[::2], index_c[::2], index_h[::2], index_w[::2], index_k[::2])] = min_value
190 | features[(index_b[1::2], index_c[1::2], index_h[1::2], index_w[1::2], index_k[1::2])] = max_value
191 | # compute the median along the feature axis
192 | median: torch.Tensor = torch.median(features, dim=-1)[0]
193 |
194 | return median
195 |
196 |
197 | def define_camera_path(num_frames, x, y, z, path_type='circle', return_t_only=False):
198 | generic_pose = np.eye(4)
199 | tgt_poses = []
200 | if path_type == 'straight-line':
201 | corner_points = np.array([[0, 0, 0], [(0 + x) * 0.5, (0 + y) * 0.5, (0 + z) * 0.5], [x, y, z]])
202 | corner_t = np.linspace(0, 1, len(corner_points))
203 | t = np.linspace(0, 1, num_frames)
204 | cs = interp1d(corner_t, corner_points, axis=0, kind='quadratic')
205 | spline = cs(t)
206 | xs, ys, zs = [xx.squeeze() for xx in np.split(spline, 3, 1)]
207 | elif path_type == 'double-straight-line':
208 | corner_points = np.array([[-x, -y, -z], [0, 0, 0], [x, y, z]])
209 | corner_t = np.linspace(0, 1, len(corner_points))
210 | t = np.linspace(0, 1, num_frames)
211 | cs = interp1d(corner_t, corner_points, axis=0, kind='quadratic')
212 | spline = cs(t)
213 | xs, ys, zs = [xx.squeeze() for xx in np.split(spline, 3, 1)]
214 | elif path_type == 'circle':
215 | xs, ys, zs = [], [], []
216 | for frame_id, bs_shift_val in enumerate(np.arange(-2.0, 2.0, (4./num_frames))):
217 | xs += [np.cos(bs_shift_val * np.pi) * 1 * x]
218 | ys += [np.sin(bs_shift_val * np.pi) * 1 * y]
219 | zs += [np.cos(bs_shift_val * np.pi/2.) * 1 * z]
220 | xs, ys, zs = np.array(xs), np.array(ys), np.array(zs)
221 | elif path_type == 'debug':
222 | xs = np.array([x, 0, -x, 0, 0])
223 | ys = np.array([0, y, 0, -y, 0])
224 | zs = np.array([0, 0, 0, 0, z])
225 | else:
226 | raise NotImplementedError
227 |
228 | xs, ys, zs = np.array(xs), np.array(ys), np.array(zs)
229 | if return_t_only:
230 | return np.stack([xs, ys, zs], axis=1) # [n, 3]
231 | for xx, yy, zz in zip(xs, ys, zs):
232 | tgt_poses.append(generic_pose * 1.)
233 | tgt_poses[-1][:3, -1] = np.array([xx, yy, zz])
234 | return tgt_poses
--------------------------------------------------------------------------------
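For orientation, a short sketch of `define_camera_path`, which builds the target camera poses used when rendering a trajectory; the translation magnitudes here are only illustrative.

```python
from core.utils import define_camera_path

# a circular trajectory with small x/y translation and a gentle push in z;
# each returned pose is a 4x4 numpy matrix with the translation in the last column
tgt_poses = define_camera_path(num_frames=60, x=0.05, y=0.05, z=0.10,
                               path_type='circle')
print(len(tgt_poses), tgt_poses[0].shape)   # roughly num_frames poses, each (4, 4)

# alternatively, return only the [n, 3] translations
ts = define_camera_path(60, 0.05, 0.05, 0.10, path_type='circle', return_t_only=True)
```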
/download.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # # download DPT (https://github.com/isl-org/DPT) pretrained weights into DPT/weights
17 | # wget https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt -P third_party/DPT/weights
18 |
19 | # # download RAFT (https://github.com/princeton-vl/RAFT) pretrained weights into RAFT/models/
20 | # wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
21 | # unzip models.zip -d third_party/RAFT/
22 | # rm -rf models.zip
23 |
24 | # download rgbd inpainting pretrained weights into inpainting_ckpts/
25 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/color-model.pth
26 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/depth-model.pth
27 | wget https://filebox.ece.vt.edu/~jbhuang/project/3DPhoto/model/edge-model.pth
28 | mkdir inpainting_ckpts/
29 | mv color-model.pth inpainting_ckpts/
30 | mv depth-model.pth inpainting_ckpts/
31 | mv edge-model.pth inpainting_ckpts/
32 |
33 | # download the 3D moments pretrained model:
34 | gdown https://drive.google.com/uc?id=1keqdnl2roBO2MjXhbd0VbfYaGAhyjUed -O pretrained/
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: reshader
2 | channels:
3 | - pytorch3d
4 | - pytorch
5 | - iopath
6 | - fvcore
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=main
11 | - _openmp_mutex=4.5=1_gnu
12 | - blas=1.0=mkl
13 | - ca-certificates=2021.10.8=ha878542_0
14 | - certifi=2021.10.8=py37h89c1867_0
15 | - colorama=0.4.4=pyh9f0ad1d_0
16 | - cudatoolkit=10.1.243=h6bb024c_0
17 | - freetype=2.10.4=h5ab3b9f_0
18 | - fvcore=0.1.5.post20210915=py37
19 | - intel-openmp=2021.3.0=h06a4308_3350
20 | - iopath=0.1.9=py37
21 | - jpeg=9b=h024ee3a_2
22 | - lcms2=2.12=h3be6417_0
23 | - ld_impl_linux-64=2.35.1=h7274673_9
24 | - libffi=3.3=he6710b0_2
25 | - libgcc-ng=9.3.0=h5101ec6_17
26 | - libgomp=9.3.0=h5101ec6_17
27 | - libpng=1.6.37=hbc83047_0
28 | - libstdcxx-ng=9.3.0=hd4cf53a_17
29 | - libtiff=4.2.0=h85742a9_0
30 | - libuv=1.40.0=h7b6447c_0
31 | - libwebp-base=1.2.0=h27cfd23_0
32 | - lz4-c=1.9.3=h295c915_1
33 | - mkl=2021.3.0=h06a4308_520
34 | - mkl-service=2.4.0=py37h7f8727e_0
35 | - mkl_fft=1.3.0=py37h42c9631_2
36 | - mkl_random=1.2.2=py37h51133e4_0
37 | - ncurses=6.2=he6710b0_1
38 | - ninja=1.10.2=hff7bd54_1
39 | - numpy=1.20.3=py37hf144106_0
40 | - numpy-base=1.20.3=py37h74d4b33_0
41 | - olefile=0.46=py37_0
42 | - openjpeg=2.4.0=h3ad879b_0
43 | - openssl=1.1.1l=h7f8727e_0
44 | - pillow=8.3.1=py37h2c7a002_0
45 | - portalocker=2.3.2=py37h89c1867_0
46 | - python=3.7.11=h12debd9_0
47 | - python_abi=3.7=2_cp37m
48 | - pytorch=1.7.1=py3.7_cuda10.1.243_cudnn7.6.3_0
49 | - pytorch3d=0.6.0=py37_cu101_pyt171
50 | - pyyaml=5.4.1=py37h5e8e339_0
51 | - readline=8.1=h27cfd23_0
52 | - six=1.16.0=pyhd3eb1b0_0
53 | - sqlite=3.36.0=hc218d9a_0
54 | - tabulate=0.8.9=pyhd8ed1ab_0
55 | - termcolor=1.1.0=py_2
56 | - tk=8.6.11=h1ccaba5_0
57 | - torchaudio=0.7.2=py37
58 | - torchvision=0.8.2=py37_cu101
59 | - tqdm=4.62.3=pyhd8ed1ab_0
60 | - typing_extensions=3.10.0.2=pyh06a4308_0
61 | - wheel=0.37.0=pyhd3eb1b0_1
62 | - xz=5.2.5=h7b6447c_0
63 | - yacs=0.1.6=py_0
64 | - yaml=0.2.5=h516909a_0
65 | - zlib=1.2.11=h7b6447c_3
66 | - zstd=1.4.9=haebb681_0
67 | - pip:
68 | - absl-py==0.15.0
69 | - beautifulsoup4==4.10.0
70 | - cached-property==1.5.2
71 | - cachetools==4.2.4
72 | - charset-normalizer==2.0.7
73 | - configargparse==1.5.3
74 | - cupy-cuda101==9.6.0
75 | - cycler==0.10.0
76 | - dill==0.3.4
77 | - fastrlock==0.8
78 | - filelock==3.3.0
79 | - gdown==4.4.0
80 | - google-auth==2.3.0
81 | - google-auth-oauthlib==0.4.6
82 | - grpcio==1.41.0
83 | - h5py==3.5.0
84 | - idna==3.2
85 | - imageio==2.9.0
86 | - imageio-ffmpeg==0.4.5
87 | - importlib-metadata==4.8.1
88 | - joblib==1.1.0
89 | - kiwisolver==1.3.2
90 | - kornia==0.5.11
91 | - lpips==0.1.4
92 | - markdown==3.3.4
93 | - matplotlib==3.4.3
94 | - networkx==2.6.3
95 | - oauthlib==3.1.1
96 | - opencv-python==4.5.3.56
97 | - packaging==21.0
98 | - pandas==1.3.4
99 | - pip==22.1.2
100 | - plotly==5.3.1
101 | - protobuf==3.18.1
102 | - pyasn1==0.4.8
103 | - pyasn1-modules==0.2.8
104 | - pyparsing==2.4.7
105 | - pyquaternion==0.9.9
106 | - pysocks==1.7.1
107 | - python-dateutil==2.8.2
108 | - pytz==2021.3
109 | - pywavelets==1.1.1
110 | - requests==2.26.0
111 | - requests-oauthlib==1.3.0
112 | - rsa==4.7.2
113 | - scikit-image==0.18.3
114 | - scikit-learn==1.0
115 | - scipy==1.7.1
116 | - setuptools==62.4.0
117 | - sklearn==0.0
118 | - soupsieve==2.2.1
119 | - tenacity==8.0.1
120 | - tensorboard==2.7.0
121 | - tensorboard-data-server==0.6.1
122 | - tensorboard-plugin-wit==1.8.0
123 | - tensorboardx==2.4
124 | - threadpoolctl==3.0.0
125 | - tifffile==2021.10.12
126 | - timm==0.4.5
127 | - urllib3==1.26.7
128 | - werkzeug==2.0.2
129 | - zipp==3.6.0
130 | prefix: /phoenix/S7/qw246/anaconda3/envs/3d_moments
131 |
--------------------------------------------------------------------------------
/examples/bottle/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/bottle/disp.npy
--------------------------------------------------------------------------------
/examples/bottle/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/bottle/input.jpg
--------------------------------------------------------------------------------
/examples/burger/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/burger/disp.npy
--------------------------------------------------------------------------------
/examples/burger/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/burger/input.jpg
--------------------------------------------------------------------------------
/examples/camera/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/camera/disp.npy
--------------------------------------------------------------------------------
/examples/camera/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/camera/input.jpg
--------------------------------------------------------------------------------
/examples/car/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/car/disp.npy
--------------------------------------------------------------------------------
/examples/car/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/car/input.jpg
--------------------------------------------------------------------------------
/examples/fireworks/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/fireworks/disp.npy
--------------------------------------------------------------------------------
/examples/fireworks/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/fireworks/input.jpg
--------------------------------------------------------------------------------
/examples/rook/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/rook/disp.npy
--------------------------------------------------------------------------------
/examples/rook/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/rook/input.jpg
--------------------------------------------------------------------------------
/examples/spoon/disp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/spoon/disp.npy
--------------------------------------------------------------------------------
/examples/spoon/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/examples/spoon/input.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class down(nn.Module):
7 | """
8 | A class for creating neural network blocks containing layers:
9 |
10 |     Average Pooling --> Convolution + Leaky ReLU --> Convolution + Leaky ReLU
11 |
12 |     This is used in the UNet class to create a UNet-like NN architecture.
13 |
14 | ...
15 |
16 | Methods
17 | -------
18 | forward(x)
19 | Returns output tensor after passing input `x` to the neural network
20 | block.
21 | """
22 |
23 |
24 | def __init__(self, inChannels, outChannels, filterSize):
25 | """
26 | Parameters
27 | ----------
28 | inChannels : int
29 | number of input channels for the first convolutional layer.
30 | outChannels : int
31 | number of output channels for the first convolutional layer.
32 | This is also used as input and output channels for the
33 | second convolutional layer.
34 | filterSize : int
35 |             filter size for the convolution filter. Input N would create
36 |             an N x N filter.
37 | """
38 |
39 |
40 | super(down, self).__init__()
41 | # Initialize convolutional layers.
42 | self.conv1 = nn.Conv2d(inChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2))
43 | self.conv2 = nn.Conv2d(outChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2))
44 |
45 | def forward(self, x):
46 | """
47 | Returns output tensor after passing input `x` to the neural network
48 | block.
49 |
50 | Parameters
51 | ----------
52 | x : tensor
53 | input to the NN block.
54 |
55 | Returns
56 | -------
57 | tensor
58 | output of the NN block.
59 | """
60 |
61 |
62 | # Average pooling with kernel size 2 (2 x 2).
63 | x = F.avg_pool2d(x, 2)
64 | # Convolution + Leaky ReLU
65 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
66 | # Convolution + Leaky ReLU
67 | x = F.leaky_relu(self.conv2(x), negative_slope = 0.1)
68 | return x
69 |
70 | class up(nn.Module):
71 | """
72 | A class for creating neural network blocks containing layers:
73 |
74 |     Bilinear interpolation --> Convolution + Leaky ReLU --> Convolution + Leaky ReLU
75 |
76 |     This is used in the UNet class to create a UNet-like NN architecture.
77 |
78 | ...
79 |
80 | Methods
81 | -------
82 | forward(x, skpCn)
83 | Returns output tensor after passing input `x` to the neural network
84 | block.
85 | """
86 |
87 |
88 | def __init__(self, inChannels, outChannels, vChannels=0):
89 | """
90 | Parameters
91 | ----------
92 | inChannels : int
93 | number of input channels for the first convolutional layer.
94 | outChannels : int
95 | number of output channels for the first convolutional layer.
96 | This is also used for setting input and output channels for
97 | the second convolutional layer.
98 | """
99 |
100 |
101 | super(up, self).__init__()
102 | self.vChannels = vChannels
103 | # Initialize convolutional layers.
104 | self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
105 | # (2 * outChannels) is used for accommodating skip connection.
106 | self.conv2 = nn.Conv2d(2 * outChannels+vChannels, outChannels, 3, stride=1, padding=1)
107 |
108 | def forward(self, x, skpCn, v=None):
109 | """
110 | Returns output tensor after passing input `x` to the neural network
111 | block.
112 |
113 | Parameters
114 | ----------
115 | x : tensor
116 | input to the NN block.
117 | skpCn : tensor
118 | skip connection input to the NN block.
119 |
120 | Returns
121 | -------
122 | tensor
123 | output of the NN block.
124 | """
125 |
126 | # Bilinear interpolation with scaling 2.
127 | x = F.interpolate(x, scale_factor=2, mode='bilinear')
128 | # Convolution + Leaky ReLU
129 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
130 | # Convolution + Leaky ReLU on (`x`, `skpCn`)
131 | if v is None:
132 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1)
133 | else:
134 | n, _, h, w = x.shape
135 | v = v.expand(n, self.vChannels, h, w)
136 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn, v), 1)), negative_slope = 0.1)
137 | return x
138 |
139 |
140 |
141 | class UNet(nn.Module):
142 | """
143 |     A class for creating a UNet-like architecture as specified by the
144 |     Super SloMo paper.
145 |
146 | ...
147 |
148 | Methods
149 | -------
150 | forward(x)
151 | Returns output tensor after passing input `x` to the neural network
152 | block.
153 | """
154 |
155 |
156 | def __init__(self, inChannels, outChannels, vChannels=128):
157 | """
158 | Parameters
159 | ----------
160 | inChannels : int
161 | number of input channels for the UNet.
162 | outChannels : int
163 | number of output channels for the UNet.
164 | """
165 |
166 |
167 | super(UNet, self).__init__()
168 | # Initialize neural network blocks.
169 | self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3)
170 | self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3)
171 | self.down1 = down(32, 64, 5)
172 | self.down2 = down(64, 128, 3)
173 | self.down3 = down(128, 256, 3)
174 | self.down4 = down(256, 512, 3)
175 | self.down5 = down(512, 512, 3)
176 | self.up1 = up(512, 512, vChannels)
177 | self.up2 = up(512, 256)
178 | self.up3 = up(256, 128)
179 | self.up4 = up(128, 64)
180 | self.up5 = up(64, 32)
181 | self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)
182 |
183 | def forward(self, x, v):
184 | """
185 | Returns output tensor after passing input `x` to the neural network.
186 |
187 | Parameters
188 | ----------
189 | x : tensor
190 | input to the UNet.
191 |
192 | Returns
193 | -------
194 | tensor
195 | output of the UNet.
196 | """
197 | inp = x
198 | v = v[:, :, None, None]
199 |
200 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
201 | s1 = F.leaky_relu(self.conv2(x), negative_slope = 0.1)
202 | s2 = self.down1(s1)
203 | s3 = self.down2(s2)
204 | s4 = self.down3(s3)
205 | s5 = self.down4(s4)
206 | x = self.down5(s5)
207 | x = self.up1(x, s5, v)
208 | x = self.up2(x, s4)
209 | x = self.up3(x, s3)
210 | x = self.up4(x, s2)
211 | x = self.up5(x, s1)
212 | x = torch.tanh(self.conv3(x))
213 |
214 | x = x + inp[:, :3]
215 |
216 | return x
217 |
218 | class FCN(nn.Module):
219 |
220 | def __init__(self):
221 |
222 | super(FCN, self).__init__()
223 |
224 | self.model = nn.Sequential(
225 | nn.Linear(3, 8),
226 | nn.LeakyReLU(),
227 | nn.Linear(8, 32),
228 | nn.LeakyReLU(),
229 | nn.Linear(32, 125),
230 | nn.LeakyReLU(),
231 | nn.Linear(125, 125)
232 | )
233 |
234 | def forward(self, inp):
235 |
236 | return torch.cat((self.model(inp), inp), 1)
--------------------------------------------------------------------------------
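A minimal sketch of how the two modules above fit together, with made-up channel counts and spatial sizes: the `FCN` turns a 3-vector (e.g. a viewing direction) into a 128-dimensional code (125 learned features plus the raw input), which the `UNet` injects at its bottleneck via `up1`.

```python
import torch
from model import UNet, FCN

reshader = UNet(inChannels=6, outChannels=3, vChannels=128)  # channel counts are illustrative
view_encoder = FCN()

x = torch.rand(2, 6, 256, 256)   # H and W must be divisible by 32 (five 2x downsamplings)
view_dir = torch.rand(2, 3)      # per-sample 3-vector fed to the FCN
v = view_encoder(view_dir)       # [2, 128]

out = reshader(x, v)             # [2, 3, 256, 256]; a tanh residual added to x[:, :3]
```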
/model_3dm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import torch
18 | from networks.resunet import ResUNet
19 | from networks.img_decoder import ImgDecoder
20 | from utils import de_parallel
21 |
22 |
23 | class Namespace:
24 |
25 | def __init__(self, **kwargs):
26 | for name in kwargs:
27 | setattr(self, name, kwargs[name])
28 |
29 | def __eq__(self, other):
30 | if not isinstance(other, Namespace):
31 | return NotImplemented
32 | return vars(self) == vars(other)
33 |
34 | def __contains__(self, key):
35 | return key in self.__dict__
36 |
37 |
38 |
39 | ########################################################################################################################
40 | # creation/saving/loading of the model
41 | ########################################################################################################################
42 |
43 |
44 | class SpaceTimeModel(object):
45 | def __init__(self, args):
46 | self.args = args
47 | load_opt = not args.no_load_opt
48 | load_scheduler = not args.no_load_scheduler
49 | device = torch.device('cuda:{}'.format(args.local_rank))
50 |
51 | # initialize feature extraction network
52 | feat_in_ch = 4
53 | if args.use_inpainting_mask_for_feature:
54 | feat_in_ch += 1
55 | if args.use_depth_for_feature:
56 | feat_in_ch += 1
57 | self.feature_net = ResUNet(args, in_ch=feat_in_ch, out_ch=args.feature_dim).to(device)
58 | # initialize decoder
59 | decoder_in_ch = args.feature_dim + 4
60 | decoder_out_ch = 3
61 |
62 | if args.use_depth_for_decoding:
63 | decoder_in_ch += 1
64 | if args.use_mask_for_decoding:
65 | decoder_in_ch += 1
66 |
67 | self.img_decoder = ImgDecoder(args, in_ch=decoder_in_ch, out_ch=decoder_out_ch).to(device)
68 |
69 | learnable_params = list(self.feature_net.parameters())
70 | learnable_params += list(self.img_decoder.parameters())
71 |
72 | self.learnable_params = learnable_params
73 | self.optimizer = torch.optim.Adam(learnable_params, lr=args.lr, weight_decay=1e-4, betas=(0.9, 0.999))
74 |
75 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
76 | step_size=args.lrate_decay_steps,
77 | gamma=args.lrate_decay_factor)
78 |
79 | out_folder = os.path.join(args.rootdir, 'out', args.expname)
80 | self.start_step = self.load_from_ckpt(out_folder,
81 | load_opt=load_opt,
82 | load_scheduler=load_scheduler)
83 |
84 | if args.distributed:
85 | self.feature_net = torch.nn.parallel.DistributedDataParallel(
86 | self.feature_net,
87 | device_ids=[args.local_rank],
88 | output_device=args.local_rank,
89 | )
90 |
91 | self.img_decoder = torch.nn.parallel.DistributedDataParallel(
92 | self.img_decoder,
93 | device_ids=[args.local_rank],
94 | output_device=args.local_rank,
95 | )
96 |
97 | def switch_to_eval(self):
98 | self.feature_net.eval()
99 | self.img_decoder.eval()
100 |
101 | def switch_to_train(self):
102 | self.feature_net.train()
103 | self.img_decoder.train()
104 |
105 | def save_model(self, filename):
106 | to_save = {'optimizer': self.optimizer.state_dict(),
107 | 'scheduler': self.scheduler.state_dict(),
108 | 'feature_net': de_parallel(self.feature_net).state_dict(),
109 | 'img_decoder': de_parallel(self.img_decoder).state_dict()
110 | }
111 | torch.save(to_save, filename)
112 |
113 | def load_model(self, filename, load_opt=True, load_scheduler=True):
114 | if self.args.distributed:
115 | to_load = torch.load(filename, map_location='cuda:{}'.format(self.args.local_rank))
116 | else:
117 | to_load = torch.load(filename)
118 |
119 | if load_opt:
120 | self.optimizer.load_state_dict(to_load['optimizer'])
121 | if load_scheduler:
122 | self.scheduler.load_state_dict(to_load['scheduler'])
123 |
124 | self.feature_net.load_state_dict(to_load['feature_net'])
125 | self.img_decoder.load_state_dict(to_load['img_decoder'])
126 |
127 | def load_from_ckpt(self, out_folder,
128 | load_opt=True,
129 | load_scheduler=True,
130 | force_latest_ckpt=False):
131 | '''
132 | load model from existing checkpoints and return the current step
133 | :param out_folder: the directory that stores ckpts
134 | :return: the current starting step
135 | '''
136 |
137 | # all existing ckpts
138 | ckpts = []
139 | if os.path.exists(out_folder):
140 | ckpts = [os.path.join(out_folder, f)
141 | for f in sorted(os.listdir(out_folder)) if f.endswith('.pth')]
142 |
143 | if self.args.ckpt_path is not None and not force_latest_ckpt:
144 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt
145 | ckpts = [self.args.ckpt_path]
146 |
147 | if len(ckpts) > 0 and not self.args.no_reload:
148 | fpath = ckpts[-1]
149 | self.load_model(fpath, load_opt, load_scheduler)
150 | step = int(fpath[-10:-4])
151 | print('Reloading from {}, starting at step={}'.format(fpath, step))
152 | else:
153 | print('No ckpts found, training from scratch...')
154 | step = 0
155 |
156 | return step
157 |
158 |
159 |
--------------------------------------------------------------------------------
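A hedged construction sketch for `SpaceTimeModel`. The flag names match the constructor above, but the values are placeholders (the real ones presumably come from `config.py`/`configs/render.txt`), a CUDA device is assumed, and `ResUNet`/`ImgDecoder` may read further fields from `args` that are not shown here.

```python
from model_3dm import SpaceTimeModel, Namespace

args = Namespace(
    no_load_opt=True, no_load_scheduler=True, local_rank=0, distributed=False,
    use_inpainting_mask_for_feature=False, use_depth_for_feature=False,
    use_depth_for_decoding=True, use_mask_for_decoding=False,
    feature_dim=32, lr=1e-4, lrate_decay_steps=50000, lrate_decay_factor=0.5,
    rootdir='.', expname='reshader', ckpt_path=None, no_reload=False,
)

model = SpaceTimeModel(args)   # builds the feature net and image decoder on cuda:0,
                               # then tries to resume from out/<expname>/*.pth
model.switch_to_eval()         # inference mode for rendering
```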
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/networks/__init__.py
--------------------------------------------------------------------------------
/networks/img_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | """ Parts of the U-Net model """
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 |
23 | class DoubleConv(nn.Module):
24 | """(convolution => [GN] => PReLU) * 2"""
25 |
26 | def __init__(self, in_channels, out_channels, mid_channels=None, group=None):
27 | super().__init__()
28 | if not mid_channels:
29 | mid_channels = out_channels
30 | #if not group:
31 | # group = out_channels//16
32 | self.double_conv1 = nn.Sequential(
33 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode='reflect'),
34 | nn.GroupNorm(mid_channels//32, mid_channels),
35 | nn.PReLU()
36 | )
37 | self.double_conv2 = nn.Sequential(
38 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, padding_mode='reflect'),
39 | nn.GroupNorm(out_channels//32, out_channels),
40 | nn.PReLU()
41 | )
42 |
43 | def forward(self, x):
44 | x1 = self.double_conv1(x)
45 | x2 = self.double_conv2(x1)
46 | return x1, x2
47 |
48 |
49 | class Down(nn.Module):
50 | """Downscaling with maxpool then double conv"""
51 |
52 | def __init__(self, in_channels, out_channels, group=None):
53 | super().__init__()
54 | self.maxpool_conv = nn.Sequential(
55 | nn.MaxPool2d(2),
56 | DoubleConv(in_channels, out_channels, group)
57 | )
58 |
59 | def forward(self, x):
60 | return self.maxpool_conv(x)
61 |
62 |
63 | class ConcatDoubleConv(nn.Module):
64 | """(convolution => [GN] => PReLU) * 2"""
65 |
66 | def __init__(self, in_channels, out_channels, mid_channels=None, group=None):
67 | super().__init__()
68 | if not mid_channels:
69 | mid_channels = out_channels
70 | #if not group:
71 | # group = out_channels//16
72 | self.double_conv1 = nn.Sequential(
73 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode='reflect'),
74 | nn.GroupNorm(mid_channels//32, mid_channels),
75 | nn.PReLU()
76 | )
77 | self.double_conv2 = nn.Sequential(
78 | nn.Conv2d(mid_channels*2, out_channels, kernel_size=3, padding=1, padding_mode='reflect'),
79 | nn.GroupNorm(out_channels//32, out_channels),
80 | nn.PReLU()
81 | )
82 |
83 | def forward(self, x, xc1):
84 | x1 = self.double_conv1(x)
85 | x2 = self.double_conv2(torch.cat([xc1, x1], dim=1))
86 | return x2
87 |
88 |
89 | class Up(nn.Module):
90 | """Upscaling then double conv"""
91 | def __init__(self, in_channels, out_channels, mid_channels, bilinear=True):
92 | super().__init__()
93 |
94 | # if bilinear, use the normal convolutions to reduce the number of channels
95 | if bilinear:
96 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
97 | self.conv = ConcatDoubleConv(in_channels, out_channels, mid_channels)
98 | else:
99 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
100 | self.conv = ConcatDoubleConv(in_channels, out_channels, mid_channels)
101 |
102 | def forward(self, x, xc1, xc2):
103 | x1 = self.up(x)
104 | # input is CHW
105 | diffY = xc1.size()[2] - x1.size()[2]
106 | diffX = xc1.size()[3] - x1.size()[3]
107 |
108 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
109 | diffY // 2, diffY - diffY // 2], mode='reflect')
110 | # if you have padding issues, see
111 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
112 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
113 | x = torch.cat([xc1, x1], dim=1)
114 | return self.conv(x, xc2)
115 |
116 |
117 | class OutConv(nn.Module):
118 | def __init__(self, in_channels, out_channels):
119 | super(OutConv, self).__init__()
120 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
121 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
122 | #self.gn = nn.GroupNorm(1, out_channels)
123 | self.act = nn.Sigmoid()
124 |
125 | def forward(self, x, xc1):
126 | x1 = self.up(x)
127 | x2 = torch.cat([xc1, x1], dim=1)
128 | #return self.act(self.gn(self.conv(x2)))
129 | return self.act(self.conv(x2))
130 |
131 |
132 | class ImgDecoder(nn.Module):
133 | def __init__(self, args, in_ch, out_ch):
134 | super(ImgDecoder, self).__init__()
135 | self.firstconv = nn.Conv2d(in_ch, 32, kernel_size=7, stride=1, padding=3, bias=False) #256x256
136 | self.firstprelu = nn.PReLU()
137 | self.down1 = Down(32, 64) #128x128
138 | self.down2 = Down(64, 128) #64x64
139 | self.down3 = Down(128, 256) #32x32
140 | self.down4 = Down(256, 512) #16x16
141 | self.conv5 = DoubleConv(512, 512) #16x16
142 |
143 | self.up1 = Up(1024, 256, 512)
144 | self.up2 = Up(512, 128, 256)
145 | self.up3 = Up(256, 64, 128)
146 | self.up4 = Up(128, 32, 64)
147 | self.outc = OutConv(64, out_ch)
148 | self.alpha = nn.Parameter(torch.tensor(1.))
149 |
150 | def compute_weight_for_two_frame_blending(self, time, disp1, disp2, alpha0, alpha1):
151 | weight1 = (1 - time) * torch.exp(self.alpha*disp1) * alpha0
152 | weight2 = time * torch.exp(self.alpha*disp2) * alpha1
153 | sum_weight = torch.clamp(weight1 + weight2, min=1e-6)
154 | out_weight1 = weight1 / sum_weight
155 | out_weight2 = weight2 / sum_weight
156 | return out_weight1, out_weight2
157 |
158 | def forward(self, x0, x1, time):
159 | disp0 = x0[:, [-1]]
160 | disp1 = x1[:, [-1]]
161 | alpha0 = x0[:, [3]]
162 | alpha1 = x1[:, [3]]
163 | w0, w1 = self.compute_weight_for_two_frame_blending(time, disp0, disp1, alpha0, alpha1)
164 | x = w0 * x0 + w1 * x1
165 |
166 | x0 = self.firstprelu(self.firstconv(x))
167 | x20, x21 = self.down1(x0)
168 | x30, x31 = self.down2(x21)
169 | x40, x41 = self.down3(x31)
170 | x50, x51 = self.down4(x41)
171 | x60, x61 = self.conv5(x51)
172 |
173 | xt1 = self.up1(x61, x51, x50)
174 | xt2 = self.up2(xt1, x41, x40)
175 | xt3 = self.up3(xt2, x31, x30)
176 | xt4 = self.up4(xt3, x21, x20)
177 | target_img = self.outc(xt4, x0)
178 |
179 | return target_img
--------------------------------------------------------------------------------
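A toy forward pass through `ImgDecoder`, with a placeholder channel layout (RGBA first, then feature channels, disparity last) chosen to match how `forward` above reads alpha from channel 3 and disparity from the last channel; `args` is not used by the constructor shown here.

```python
import torch
from networks.img_decoder import ImgDecoder

in_ch, out_ch = 37, 3                 # e.g. RGBA (alpha at index 3) + 32 features + disparity
decoder = ImgDecoder(args=None, in_ch=in_ch, out_ch=out_ch)

x0 = torch.rand(1, in_ch, 256, 256)   # rendering from the first source frame
x1 = torch.rand(1, in_ch, 256, 256)   # rendering from the second source frame
time = 0.5                            # blending weight between the two frames

rgb = decoder(x0, x1, time)           # [1, 3, 256, 256], sigmoid-activated output
```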
/networks/inpainting_nets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | import torch
18 | import torch.nn as nn
19 | import numpy as np
20 | import torch.nn.functional as F
21 |
22 |
23 | class BaseNetwork(nn.Module):
24 | def __init__(self):
25 | super(BaseNetwork, self).__init__()
26 |
27 | def init_weights(self, init_type='normal', gain=0.02):
28 | '''
29 | initialize network's weights
30 | init_type: normal | xavier | kaiming | orthogonal
31 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
32 | '''
33 |
34 | def init_func(m):
35 | classname = m.__class__.__name__
36 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
37 | if init_type == 'normal':
38 | nn.init.normal_(m.weight.data, 0.0, gain)
39 | elif init_type == 'xavier':
40 | nn.init.xavier_normal_(m.weight.data, gain=gain)
41 | elif init_type == 'kaiming':
42 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
43 | elif init_type == 'orthogonal':
44 | nn.init.orthogonal_(m.weight.data, gain=gain)
45 |
46 | if hasattr(m, 'bias') and m.bias is not None:
47 | nn.init.constant_(m.bias.data, 0.0)
48 |
49 | elif classname.find('BatchNorm2d') != -1:
50 | nn.init.normal_(m.weight.data, 1.0, gain)
51 | nn.init.constant_(m.bias.data, 0.0)
52 |
53 | self.apply(init_func)
54 |
55 |
56 | def weights_init(init_type='gaussian'):
57 | def init_fun(m):
58 | classname = m.__class__.__name__
59 | if (classname.find('Conv') == 0 or classname.find(
60 | 'Linear') == 0) and hasattr(m, 'weight'):
61 | if init_type == 'gaussian':
62 | nn.init.normal_(m.weight, 0.0, 0.02)
63 | elif init_type == 'xavier':
64 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
65 | elif init_type == 'kaiming':
66 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
67 | elif init_type == 'orthogonal':
68 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
69 | elif init_type == 'default':
70 | pass
71 | else:
72 | assert 0, "Unsupported initialization: {}".format(init_type)
73 | if hasattr(m, 'bias') and m.bias is not None:
74 | nn.init.constant_(m.bias, 0.0)
75 |
76 | return init_fun
77 |
78 |
79 | class PartialConv(nn.Module):
80 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
81 | padding=0, dilation=1, groups=1, bias=True):
82 | super().__init__()
83 | self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
84 | stride, padding, dilation, groups, bias)
85 | self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
86 | stride, padding, dilation, groups, False)
87 | self.input_conv.apply(weights_init('kaiming'))
88 | self.slide_winsize = in_channels * kernel_size * kernel_size
89 |
90 | torch.nn.init.constant_(self.mask_conv.weight, 1.0)
91 |
92 | # mask is not updated
93 | for param in self.mask_conv.parameters():
94 | param.requires_grad = False
95 |
96 | def forward(self, input, mask):
97 | # http://masc.cs.gmu.edu/wiki/partialconv
98 | # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
99 | # W^T* (M .* X) / sum(M) + b = [C(M .* X) – C(0)] / D(M) + C(0)
100 | output = self.input_conv(input * mask)
101 | if self.input_conv.bias is not None:
102 | output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
103 | output)
104 | else:
105 | output_bias = torch.zeros_like(output)
106 |
107 | with torch.no_grad():
108 | output_mask = self.mask_conv(mask)
109 |
110 | no_update_holes = output_mask == 0
111 |
112 | mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
113 |
114 | output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
115 | output = output_pre.masked_fill_(no_update_holes, 0.0)
116 |
117 | new_mask = torch.ones_like(output)
118 | new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
119 |
120 | return output, new_mask
121 |
122 |
123 | class PCBActiv(nn.Module):
124 | def __init__(self, in_ch, out_ch, bn=True, no_tracking_stats=False, sample='none-3', activ='relu',
125 | conv_bias=False):
126 | super().__init__()
127 | if sample == 'down-5':
128 | self.conv = PartialConv(in_ch, out_ch, 5, 2, 2, bias=conv_bias)
129 | elif sample == 'down-7':
130 | self.conv = PartialConv(in_ch, out_ch, 7, 2, 3, bias=conv_bias)
131 | elif sample == 'down-3':
132 | self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias)
133 | else:
134 | self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias)
135 |
136 | if bn:
137 | if no_tracking_stats:
138 | self.bn = nn.BatchNorm2d(out_ch, track_running_stats=False, affine=True)
139 | else:
140 | self.bn = nn.BatchNorm2d(out_ch)
141 | if activ == 'relu':
142 | self.activation = nn.ReLU()
143 | elif activ == 'leaky':
144 | self.activation = nn.LeakyReLU(negative_slope=0.2)
145 |
146 | def forward(self, input, input_mask):
147 | h, h_mask = self.conv(input, input_mask)
148 | if hasattr(self, 'bn'):
149 | h = self.bn(h)
150 | if hasattr(self, 'activation'):
151 | h = self.activation(h)
152 | return h, h_mask
153 |
154 |
155 | class Inpaint_Depth_Net(nn.Module):
156 | def __init__(self, layer_size=7, upsampling_mode='nearest'):
157 | super().__init__()
158 | in_channels = 4
159 | out_channels = 1
160 | self.freeze_enc_bn = False
161 | self.upsampling_mode = upsampling_mode
162 | self.layer_size = layer_size
163 | self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
164 | self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
165 | self.enc_3 = PCBActiv(128, 256, sample='down-5')
166 | self.enc_4 = PCBActiv(256, 512, sample='down-3')
167 | for i in range(4, self.layer_size):
168 | name = 'enc_{:d}'.format(i + 1)
169 | setattr(self, name, PCBActiv(512, 512, sample='down-3'))
170 |
171 | for i in range(4, self.layer_size):
172 | name = 'dec_{:d}'.format(i + 1)
173 | setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
174 | self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
175 | self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
176 | self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
177 | self.dec_1 = PCBActiv(64 + in_channels, out_channels,
178 | bn=False, activ=None, conv_bias=True)
179 |
180 | def add_border(self, input, mask_flag, PCONV=True):
181 | with torch.no_grad():
182 | h = input.shape[-2]
183 | w = input.shape[-1]
184 | require_len_unit = 2 ** self.layer_size
185 | residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
186 | residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
187 | enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
188 | if mask_flag:
189 | if PCONV is False:
190 | enlarge_input += 1.0
191 | enlarge_input = enlarge_input.clamp(0.0, 1.0)
192 | else:
193 | enlarge_input[:, 2, ...] = 0.0
194 | anchor_h = residual_h//2
195 | anchor_w = residual_w//2
196 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
197 |
198 | return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w]
199 |
200 | def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
201 | input = torch.cat((depth, edge, context, mask), dim=1)
202 | n, c, h, w = input.shape
203 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
204 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
205 | anchor_h = residual_h//2
206 | anchor_w = residual_w//2
207 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
208 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
209 | # enlarge_input[:, 3] = 1. - enlarge_input[:, 3]
210 | depth_output = self.forward(enlarge_input)
211 | depth_output = depth_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
212 | # import pdb; pdb.set_trace()
213 |
214 | return depth_output
215 |
216 | def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
217 | input = input_feat
218 | input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1)
219 |
220 | vis_input = input.cpu().data.numpy()
221 | vis_input_mask = input_mask.cpu().data.numpy()
222 | H, W = input.shape[-2:]
223 | if refine_border is True:
224 | input, anchor = self.add_border(input, mask_flag=False)
225 | input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV)
226 | h_dict = {} # for the output of enc_N
227 | h_mask_dict = {} # for the output of enc_N
228 | h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask
229 |
230 | h_key_prev = 'h_0'
231 | for i in range(1, self.layer_size + 1):
232 | l_key = 'enc_{:d}'.format(i)
233 | h_key = 'h_{:d}'.format(i)
234 | h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
235 | h_dict[h_key_prev], h_mask_dict[h_key_prev])
236 | h_key_prev = h_key
237 |
238 | h_key = 'h_{:d}'.format(self.layer_size)
239 | h, h_mask = h_dict[h_key], h_mask_dict[h_key]
240 |
241 | for i in range(self.layer_size, 0, -1):
242 | enc_h_key = 'h_{:d}'.format(i - 1)
243 | dec_l_key = 'dec_{:d}'.format(i)
244 |
245 | h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
246 | h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')
247 |
248 | h = torch.cat([h, h_dict[enc_h_key]], dim=1)
249 | h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
250 | h, h_mask = getattr(self, dec_l_key)(h, h_mask)
251 | output = h
252 | if refine_border is True:
253 | h_mask = h_mask[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
254 | output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
255 |
256 | return output
257 |
258 |
259 | class Inpaint_Edge_Net(BaseNetwork):
260 | def __init__(self, residual_blocks=8, init_weights=True):
261 | super(Inpaint_Edge_Net, self).__init__()
262 | in_channels = 7
263 | out_channels = 1
264 | self.encoder = []
265 | # 0
266 | self.encoder_0 = nn.Sequential(
267 | nn.ReflectionPad2d(3),
268 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
269 | nn.InstanceNorm2d(64, track_running_stats=False),
270 | nn.ReLU(True))
271 | # 1
272 | self.encoder_1 = nn.Sequential(
273 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
274 | nn.InstanceNorm2d(128, track_running_stats=False),
275 | nn.ReLU(True))
276 | # 2
277 | self.encoder_2 = nn.Sequential(
278 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
279 | nn.InstanceNorm2d(256, track_running_stats=False),
280 | nn.ReLU(True))
281 | # 3
282 | blocks = []
283 | for _ in range(residual_blocks):
284 | block = ResnetBlock(256, 2)
285 | blocks.append(block)
286 |
287 | self.middle = nn.Sequential(*blocks)
288 | # + 3
289 | self.decoder_0 = nn.Sequential(
290 | spectral_norm(nn.ConvTranspose2d(in_channels=256+256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
291 | nn.InstanceNorm2d(128, track_running_stats=False),
292 | nn.ReLU(True))
293 | # + 2
294 | self.decoder_1 = nn.Sequential(
295 | spectral_norm(nn.ConvTranspose2d(in_channels=128+128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
296 | nn.InstanceNorm2d(64, track_running_stats=False),
297 | nn.ReLU(True))
298 | # + 1
299 | self.decoder_2 = nn.Sequential(
300 | nn.ReflectionPad2d(3),
301 | nn.Conv2d(in_channels=64+64, out_channels=out_channels, kernel_size=7, padding=0),
302 | )
303 |
304 | if init_weights:
305 | self.init_weights()
306 |
307 | def add_border(self, input, channel_pad_1=None):
308 | h = input.shape[-2]
309 | w = input.shape[-1]
310 | require_len_unit = 16
311 | residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
312 | residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
313 | enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
314 | if channel_pad_1 is not None:
315 | for channel in channel_pad_1:
316 | enlarge_input[:, channel] = 1
317 | anchor_h = residual_h//2
318 | anchor_w = residual_w//2
319 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
320 |
321 | return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w]
322 |
323 | def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
324 | input = torch.cat((rgb, disp/disp.max(), edge, context, mask), dim=1)
325 | n, c, h, w = input.shape
326 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
327 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
328 | anchor_h = residual_h//2
329 | anchor_w = residual_w//2
330 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
331 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
332 | edge_output = self.forward(enlarge_input)
333 | edge_output = edge_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
334 | return edge_output
335 |
336 | def forward(self, x, refine_border=False):
337 | if refine_border:
338 | x, anchor = self.add_border(x, [5])
339 | x1 = self.encoder_0(x)
340 | x2 = self.encoder_1(x1)
341 | x3 = self.encoder_2(x2)
342 | x4 = self.middle(x3)
343 | x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
344 | x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
345 | x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
346 | x = torch.sigmoid(x7)
347 | if refine_border:
348 | x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
349 |
350 | return x
351 |
352 |
353 | class Inpaint_Color_Net(nn.Module):
354 | def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False):
355 | super().__init__()
356 | self.freeze_enc_bn = False
357 | self.upsampling_mode = upsampling_mode
358 | self.layer_size = layer_size
359 | in_channels = 6
360 | self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
361 | self.enc_2 = PCBActiv(64, 128, sample='down-5')
362 | self.enc_3 = PCBActiv(128, 256, sample='down-5')
363 | self.enc_4 = PCBActiv(256, 512, sample='down-3')
364 | self.enc_5 = PCBActiv(512, 512, sample='down-3')
365 | self.enc_6 = PCBActiv(512, 512, sample='down-3')
366 | self.enc_7 = PCBActiv(512, 512, sample='down-3')
367 |
368 | self.dec_7 = PCBActiv(512+512, 512, activ='leaky')
369 | self.dec_6 = PCBActiv(512+512, 512, activ='leaky')
370 |
371 | self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
372 | self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
373 | self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
374 | self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
375 | self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)
376 | '''
377 | self.dec_5B = PCBActiv(512 + 512, 512, activ='leaky')
378 | self.dec_4B = PCBActiv(512 + 256, 256, activ='leaky')
379 | self.dec_3B = PCBActiv(256 + 128, 128, activ='leaky')
380 | self.dec_2B = PCBActiv(128 + 64, 64, activ='leaky')
381 | self.dec_1B = PCBActiv(64 + 4, 1, bn=False, activ=None, conv_bias=True)
382 | '''
383 | def cat(self, A, B):
384 | return torch.cat((A, B), dim=1)
385 |
386 | def upsample(self, feat, mask):
387 | feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
388 | mask = F.interpolate(mask, scale_factor=2, mode='nearest')
389 |
390 | return feat, mask
391 |
392 | def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
393 | input = torch.cat((rgb, edge, context, mask), dim=1)
394 | n, c, h, w = input.shape
395 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) # + 128
396 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) # + 256
397 | anchor_h = residual_h//2
398 | anchor_w = residual_w//2
399 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
400 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
401 | # enlarge_input[:, 3] = 1. - enlarge_input[:, 3]
402 | enlarge_input = enlarge_input.to(cuda)
403 | rgb_output = self.forward(enlarge_input)
404 | rgb_output = rgb_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
405 | return rgb_output
406 |
407 | def forward(self, input, add_border=False):
408 | input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
409 | H, W = input.shape[-2:]
410 | f_0, h_0 = input, input_mask.repeat((1,input.shape[1],1,1))
411 | f_1, h_1 = self.enc_1(f_0, h_0)
412 | f_2, h_2 = self.enc_2(f_1, h_1)
413 | f_3, h_3 = self.enc_3(f_2, h_2)
414 | f_4, h_4 = self.enc_4(f_3, h_3)
415 | f_5, h_5 = self.enc_5(f_4, h_4)
416 | f_6, h_6 = self.enc_6(f_5, h_5)
417 | f_7, h_7 = self.enc_7(f_6, h_6)
418 |
419 | o_7, k_7 = self.upsample(f_7, h_7)
420 | o_6, k_6 = self.dec_7(self.cat(o_7, f_6), self.cat(k_7, h_6))
421 | o_6, k_6 = self.upsample(o_6, k_6)
422 | o_5, k_5 = self.dec_6(self.cat(o_6, f_5), self.cat(k_6, h_5))
423 | o_5, k_5 = self.upsample(o_5, k_5)
424 | o_5A, k_5A = o_5, k_5
425 |         o_5B, k_5B = o_5, k_5  # unused: reserved for the commented-out B decoder branch in __init__
426 | ###############
427 | o_4A, k_4A = self.dec_5A(self.cat(o_5A, f_4), self.cat(k_5A, h_4))
428 | o_4A, k_4A = self.upsample(o_4A, k_4A)
429 | o_3A, k_3A = self.dec_4A(self.cat(o_4A, f_3), self.cat(k_4A, h_3))
430 | o_3A, k_3A = self.upsample(o_3A, k_3A)
431 | o_2A, k_2A = self.dec_3A(self.cat(o_3A, f_2), self.cat(k_3A, h_2))
432 | o_2A, k_2A = self.upsample(o_2A, k_2A)
433 | o_1A, k_1A = self.dec_2A(self.cat(o_2A, f_1), self.cat(k_2A, h_1))
434 | o_1A, k_1A = self.upsample(o_1A, k_1A)
435 | o_0A, k_0A = self.dec_1A(self.cat(o_1A, f_0), self.cat(k_1A, h_0))
436 |
437 | return torch.sigmoid(o_0A)
438 |
439 | def train(self, mode=True):
440 | """
441 | Override the default train() to freeze the BN parameters
442 | """
443 | super().train(mode)
444 | if self.freeze_enc_bn:
445 | for name, module in self.named_modules():
446 | if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
447 | module.eval()
448 |
449 |
450 | class Discriminator(BaseNetwork):
451 | def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
452 | super(Discriminator, self).__init__()
453 | self.use_sigmoid = use_sigmoid
454 | self.conv1 = self.features = nn.Sequential(
455 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
456 | nn.LeakyReLU(0.2, inplace=True),
457 | )
458 |
459 | self.conv2 = nn.Sequential(
460 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
461 | nn.LeakyReLU(0.2, inplace=True),
462 | )
463 |
464 | self.conv3 = nn.Sequential(
465 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
466 | nn.LeakyReLU(0.2, inplace=True),
467 | )
468 |
469 | self.conv4 = nn.Sequential(
470 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
471 | nn.LeakyReLU(0.2, inplace=True),
472 | )
473 |
474 | self.conv5 = nn.Sequential(
475 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
476 | )
477 |
478 | if init_weights:
479 | self.init_weights()
480 |
481 | def forward(self, x):
482 | conv1 = self.conv1(x)
483 | conv2 = self.conv2(conv1)
484 | conv3 = self.conv3(conv2)
485 | conv4 = self.conv4(conv3)
486 | conv5 = self.conv5(conv4)
487 |
488 | outputs = conv5
489 | if self.use_sigmoid:
490 | outputs = torch.sigmoid(conv5)
491 |
492 | return outputs, [conv1, conv2, conv3, conv4, conv5]
493 |
494 |
495 | class ResnetBlock(nn.Module):
496 | def __init__(self, dim, dilation=1):
497 | super(ResnetBlock, self).__init__()
498 | self.conv_block = nn.Sequential(
499 | nn.ReflectionPad2d(dilation),
500 |             spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=False), True),
501 | nn.InstanceNorm2d(dim, track_running_stats=False),
502 | nn.LeakyReLU(negative_slope=0.2),
503 |
504 | nn.ReflectionPad2d(1),
505 |             spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=False), True),
506 | nn.InstanceNorm2d(dim, track_running_stats=False),
507 | )
508 |
509 | def forward(self, x):
510 | out = x + self.conv_block(x)
511 |
512 | # Remove ReLU at the end of the residual block
513 | # http://torch.ch/blog/2016/02/04/resnets.html
514 |
515 | return out
516 |
517 |
518 | def spectral_norm(module, mode=True):
519 | if mode:
520 | return nn.utils.spectral_norm(module)
521 |
522 | return module
523 |
--------------------------------------------------------------------------------
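
The `spectral_norm` helper at the end of this file is just a conditional wrapper around `torch.nn.utils.spectral_norm`; the GAN blocks above pass `bias=not use_spectral_norm` so a convolution keeps its bias only when the normalization is disabled. A minimal, self-contained sketch of that pattern (plain PyTorch only, not importing this module):

```python
import torch
import torch.nn as nn

def spectral_norm(module, mode=True):
    # Apply spectral normalization only when the toggle is on.
    return nn.utils.spectral_norm(module) if mode else module

use_sn = True
conv = spectral_norm(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=not use_sn),
    use_sn)
x = torch.randn(1, 3, 64, 64)
print(conv(x).shape)  # torch.Size([1, 64, 32, 32])
```
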
/networks/resunet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import importlib
20 |
21 |
22 | def class_for_name(module_name, class_name):
23 | # load the module, will raise ImportError if module cannot be loaded
24 | m = importlib.import_module(module_name)
25 | return getattr(m, class_name)
26 |
27 |
28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
29 | """3x3 convolution with padding"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect')
32 |
33 |
34 | def conv1x1(in_planes, out_planes, stride=1):
35 | """1x1 convolution"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, padding_mode='reflect')
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
43 | base_width=64, dilation=1, norm_layer=None):
44 | super(BasicBlock, self).__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes, track_running_stats=False, affine=True)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes, track_running_stats=False, affine=True)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
81 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
82 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
83 | # This variant is also known as ResNet V1.5 and improves accuracy according to
84 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
85 |
86 | expansion = 4
87 |
88 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
89 | base_width=64, dilation=1, norm_layer=None):
90 | super(Bottleneck, self).__init__()
91 | if norm_layer is None:
92 | norm_layer = nn.BatchNorm2d
93 | width = int(planes * (base_width / 64.)) * groups
94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
95 | self.conv1 = conv1x1(inplanes, width)
96 | self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
97 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
98 | self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
99 | self.conv3 = conv1x1(width, planes * self.expansion)
100 | self.bn3 = norm_layer(planes * self.expansion, track_running_stats=False, affine=True)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.downsample = downsample
103 | self.stride = stride
104 |
105 | def forward(self, x):
106 | identity = x
107 |
108 | out = self.conv1(x)
109 | out = self.bn1(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv2(out)
113 | out = self.bn2(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv3(out)
117 | out = self.bn3(out)
118 |
119 | if self.downsample is not None:
120 | identity = self.downsample(x)
121 |
122 | out += identity
123 | out = self.relu(out)
124 |
125 | return out
126 |
127 |
128 | class conv(nn.Module):
129 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
130 | super(conv, self).__init__()
131 | self.kernel_size = kernel_size
132 | self.conv = nn.Conv2d(num_in_layers,
133 | num_out_layers,
134 | kernel_size=kernel_size,
135 | stride=stride,
136 | padding=(self.kernel_size - 1) // 2,
137 | padding_mode='reflect')
138 | self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True)
139 |
140 | def forward(self, x):
141 | return F.elu(self.bn(self.conv(x)), inplace=True)
142 |
143 |
144 | class upconv(nn.Module):
145 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
146 | super(upconv, self).__init__()
147 | self.scale = scale
148 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)
149 |
150 | def forward(self, x):
151 | x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear')
152 | return self.conv(x)
153 |
154 |
155 | class ResUNet(nn.Module):
156 | def __init__(self, args,
157 | encoder='resnet34',
158 | in_ch=8,
159 | out_ch=32,
160 | norm_layer=None,
161 | ):
162 |
163 | super(ResUNet, self).__init__()
164 | assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type"
165 | if encoder in ['resnet18', 'resnet34']:
166 | filters = [64, 128, 256, 512]
167 | else:
168 | filters = [256, 512, 1024, 2048]
169 | # resnet = class_for_name("torchvision.models", encoder)(pretrained=True).to("cuda:{}".format(args.local_rank))
170 |
171 | # original
172 | layers = [3, 4, 6, 3]
173 | if norm_layer is None:
174 | norm_layer = nn.BatchNorm2d
175 | # norm_layer = nn.InstanceNorm2d
176 | self._norm_layer = norm_layer
177 | self.dilation = 1
178 | block = BasicBlock
179 | replace_stride_with_dilation = [False, False, False]
180 | self.inplanes = 64
181 | self.groups = 1
182 | self.base_width = 64
183 | self.conv1 = nn.Conv2d(in_ch, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False,
184 | padding_mode='reflect')
185 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
186 | self.relu = nn.ReLU(inplace=True)
187 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
188 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
189 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
190 | dilate=replace_stride_with_dilation[0])
191 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
192 | dilate=replace_stride_with_dilation[1])
193 |
194 | # if in_ch != 3: # Number of input channels
195 | # self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
196 | # else:
197 | # self.conv1 = resnet.conv1 # H/2
198 | # self.bn1 = resnet.bn1
199 | # self.relu = resnet.relu
200 | # self.maxpool = resnet.maxpool # H/4
201 | #
202 | # # encoder
203 | # self.layer1 = resnet.layer1 # H/4
204 | # self.layer2 = resnet.layer2 # H/8
205 | # self.layer3 = resnet.layer3 # H/16
206 |
207 | # decoder
208 | self.upconv3 = upconv(filters[2], 128, 3, 2)
209 | self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
210 | self.upconv2 = upconv(128, 64, 3, 2)
211 | self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)
212 |
213 | # fine-level conv
214 | self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)
215 |
216 | for m in self.modules():
217 | if isinstance(m, nn.Conv2d):
218 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
219 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
220 | nn.init.constant_(m.weight, 1)
221 | nn.init.constant_(m.bias, 0)
222 |
223 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
224 | norm_layer = self._norm_layer
225 | downsample = None
226 | previous_dilation = self.dilation
227 | if dilate:
228 | self.dilation *= stride
229 | stride = 1
230 | if stride != 1 or self.inplanes != planes * block.expansion:
231 | downsample = nn.Sequential(
232 | conv1x1(self.inplanes, planes * block.expansion, stride),
233 | norm_layer(planes * block.expansion, track_running_stats=False, affine=True),
234 | )
235 |
236 | layers = []
237 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
238 | self.base_width, previous_dilation, norm_layer))
239 | self.inplanes = planes * block.expansion
240 | for _ in range(1, blocks):
241 | layers.append(block(self.inplanes, planes, groups=self.groups,
242 | base_width=self.base_width, dilation=self.dilation,
243 | norm_layer=norm_layer))
244 |
245 | return nn.Sequential(*layers)
246 |
247 | def skipconnect(self, x1, x2):
248 | diffY = x2.size()[2] - x1.size()[2]
249 | diffX = x2.size()[3] - x1.size()[3]
250 |
251 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
252 | diffY // 2, diffY - diffY // 2))
253 |
254 | # for padding issues, see
255 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
256 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
257 |
258 | x = torch.cat([x2, x1], dim=1)
259 | return x
260 |
261 | def forward(self, x):
262 | x = self.relu(self.bn1(self.conv1(x)))
263 | x = self.maxpool(x)
264 |
265 | x1 = self.layer1(x)
266 | x2 = self.layer2(x1)
267 | x3 = self.layer3(x2)
268 |
269 | x = self.upconv3(x3)
270 | x = self.skipconnect(x2, x)
271 | x = self.iconv3(x)
272 |
273 | x = self.upconv2(x)
274 | x = self.skipconnect(x1, x)
275 | x = self.iconv2(x)
276 |
277 | x_out = self.out_conv(x)
278 |
279 | return x_out
280 |
--------------------------------------------------------------------------------
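
`ResUNet.forward` downsamples the input by 4x (stride-2 conv plus max-pool), runs three residual stages, then upsamples twice with skip connections, so the output feature map has `out_ch` channels at 1/4 of the input resolution. A quick sanity-check sketch (assumes the repo root is on `PYTHONPATH`; the `args` argument is only referenced by commented-out code here, so `None` is passed):

```python
import torch
from networks.resunet import ResUNet

net = ResUNet(args=None, in_ch=8, out_ch=32).eval()
x = torch.randn(1, 8, 256, 320)   # (B, in_ch, H, W)
with torch.no_grad():
    feats = net(x)
print(feats.shape)                # torch.Size([1, 32, 64, 80]) -- H/4 x W/4
```
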
/posenc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Embedder:
4 |
5 | def __init__(self, **kwargs):
6 |
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 |
12 | embed_fns = []
13 | d = self.kwargs['input_dims']
14 | out_dim = 0
15 | if self.kwargs['include_input']:
16 | embed_fns.append(lambda x: x)
17 | out_dim += d
18 |
19 | max_freq = self.kwargs['max_freq_log2']
20 | N_freqs = self.kwargs['num_freqs']
21 |
22 | if self.kwargs['log_sampling']:
23 | freq_bands = 2.**np.linspace(0., max_freq, N_freqs)
24 | else:
25 | freq_bands = np.linspace(2.**0., 2.**max_freq, N_freqs)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs['periodic_fns']:
29 | embed_fns.append(lambda x, p_fn=p_fn,
30 | freq=freq: p_fn(x * freq))
31 | out_dim += d
32 |
33 | self.embed_fns = embed_fns
34 | self.out_dim = out_dim
35 |
36 | def embed(self, inputs):
37 |
38 | # print([fn(inputs) for fn in self.embed_fns])
39 | # exit()
40 | return np.stack([fn(inputs) for fn in self.embed_fns], -1)
41 |
42 |
43 | def get_embedder(multires, i=0):
44 |
45 |     if i == -1:
46 |         return lambda x: x  # identity encoding; keeps the return type consistent with the default path below
47 |
48 | embed_kwargs = {
49 | 'include_input': True,
50 | 'input_dims': 3,
51 | 'max_freq_log2': multires-1,
52 | 'num_freqs': multires,
53 | 'log_sampling': True,
54 | 'periodic_fns': [np.sin, np.cos],
55 | }
56 |
57 | embedder_obj = Embedder(**embed_kwargs)
58 | def embed(x, eo=embedder_obj): return eo.embed(x)
59 | return embed
--------------------------------------------------------------------------------
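
`get_embedder` builds a standard NeRF-style positional encoding: the identity plus sine/cosine at `multires` log-spaced frequencies, stacked along a new last axis, giving 1 + 2*multires channels per input value. A small usage sketch; with `args.pos_enc_freq = 5` the encoded disparity contributes 11 channels, consistent with the 3 RGB channels and the `UNet(14, 3)` input in `renderer.py`:

```python
import numpy as np
from posenc import get_embedder

embed_fn = get_embedder(multires=5)   # identity + 5 frequencies x (sin, cos) = 11 functions
x = np.random.rand(4, 4)              # applied elementwise, so any array shape works
enc = embed_fn(x)
print(enc.shape)                      # (4, 4, 11)
```
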
/renderer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import glob
16 | import time
17 | import imageio
18 | import cv2
19 | from PIL import ImageFile
20 |
21 | import config
22 | import math
23 | import torchvision
24 | from torchvision.transforms import Pad
25 | import torch.utils.data.distributed
26 | from tqdm import tqdm
27 | from utils import *
28 | from model_3dm import SpaceTimeModel
29 | from core.utils import *
30 | from core.renderer import ImgRenderer
31 | from core.inpainter import Inpainter
32 | from model import UNet, FCN
33 | from posenc import get_embedder
34 | ImageFile.LOAD_TRUNCATED_IMAGES = True
35 |
36 |
37 | def process_boundary_mask(mask):
38 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
39 | closing = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=5)
40 | dilation = cv2.dilate(closing, kernel, iterations=1, borderType=cv2.BORDER_CONSTANT, borderValue=0.)
41 | return dilation
42 |
43 |
44 | def get_input_data(args, ds_factor=1):
45 | to_tensor = torchvision.transforms.ToTensor()
46 | input_dir = args.input_dir
47 | img_file = (sorted(glob.glob(os.path.join(input_dir, 'input.png'))) + \
48 | sorted(glob.glob(os.path.join(input_dir, 'input.jpg'))))[0]
49 | src_img = imageio.imread(img_file) / 255.
50 | h1, w1 = src_img.shape[:2]
51 |
52 | disparity = (np.abs(np.load(os.path.join(input_dir, 'disp.npy'))))
53 | disparity = np.maximum(disparity, 1e-3)
54 | src_depth = (65535/disparity) / args.dscl
55 |
56 | fy = w1 / (2 * np.tan(np.deg2rad(args.fov) / 2))
57 | fx = fy
58 |
59 | intrinsic = np.array([[fx, 0, w1 // 2],
60 | [0, fy, h1 // 2],
61 | [0, 0, 1]])
62 |
63 | pose = np.eye(4)
64 | return {
65 | 'src_img1': to_tensor(src_img).float()[None],
66 | 'src_depth1': to_tensor(src_depth).float()[None],
67 | 'intrinsic1': torch.from_numpy(intrinsic).float()[None],
68 | 'tgt_intrinsic': torch.from_numpy(intrinsic).float()[None],
69 | 'pose': torch.from_numpy(pose).float()[None],
70 | 'scale_shift1': torch.tensor([1., 0.]).float()[None],
71 | 'src_rgb_file1': [img_file],
72 | 'multi_view': [False]
73 | }
74 |
75 | def reshading(reshader, dir_model, rgbd, xyz):
76 |
77 | xyz = xyz[None] * torch.tensor([-1, 1, 1]).cuda()
78 | pos_feat = dir_model(xyz)
79 | pred = reshader(rgbd, pos_feat)
80 | pred = torch.clamp(pred, 0, 1)
81 |
82 | return pred
83 |
84 | def render(args):
85 | to_tensor = torchvision.transforms.ToTensor()
86 | device = "cuda:{}".format(args.local_rank)
87 |
88 | print('========================= Reshading Init...=========================')
89 | runpath = "runs/"
90 |
91 | reshader = UNet(14, 3).cuda()
92 | dir_model = FCN().cuda()
93 | reshader.load_state_dict(torch.load(f'{runpath}{args.model_dir}/model{args.ckpt}.pth'))
94 | dir_model.load_state_dict(torch.load(f'{runpath}{args.model_dir}/dir_model{args.ckpt}.pth'))
95 | embed_fn = get_embedder(args.pos_enc_freq)
96 |
97 |
98 | print('=========================run 3D Moments...=========================')
99 |
100 | data = get_input_data(args)
101 | rgb_file1 = data['src_rgb_file1'][0]
102 | frame_id1 = os.path.basename(rgb_file1).split('.')[0]
103 | scene_id = rgb_file1.split('/')[-3]
104 |
105 | video_out_folder = os.path.join(args.input_dir, 'out')
106 | os.makedirs(video_out_folder, exist_ok=True)
107 |
108 | im_h, im_w = data['src_img1'].shape[2:]
109 | pad_h, pad_w = (32 * math.ceil(im_h / 32) - im_h), (32 * math.ceil(im_w / 32) - im_w)
110 | padder = Pad(padding=(0, 0, pad_w, pad_h))
111 |
112 | model = SpaceTimeModel(args)
113 | if model.start_step == 0:
114 | raise Exception('no pretrained model found! please check the model path.')
115 |
116 | inpainter = Inpainter(args)
117 | renderer = ImgRenderer(args, model, None, inpainter, device)
118 |
119 | model.switch_to_eval()
120 | with torch.no_grad():
121 | renderer.process_data_single(data)
122 |
123 | pts1, rgb1, feat1, mask, side_ids = \
124 | renderer.render_rgbda_layers_from_one_view(return_pts=True)
125 |
126 |     num_frames = 60  # [60, 60, 60, 90]
127 |     video_paths = ['circle']  # ['up-down', 'zoom-in', 'side', 'circle']
128 | Ts = [
129 | # define_camera_path(num_frames[0], 0., -0.08, 0., path_type='double-straight-line', return_t_only=True),
130 | # define_camera_path(num_frames[1], 0., 0., -0.24, path_type='straight-line', return_t_only=True),
131 | # define_camera_path(num_frames[2], -0.09, 0, -0, path_type='double-straight-line', return_t_only=True),
132 | # define_camera_path(num_frames, -0.15, -0.15, -0.15, path_type='circle', return_t_only=True),
133 | # define_camera_path(num_frames, -0.14, -0.14, -0.14, path_type='circle', return_t_only=True),
134 | define_camera_path(num_frames, -0.14, -0.14, -0.14, path_type='circle', return_t_only=True),
135 | ]
136 | crop = 32
137 |     # the maximum magnitude of the relative camera translations above should not exceed 0.3 (the max used during training)
138 |
139 | ref_input = data['src_img1']
140 | for j, T in enumerate(Ts):
141 | print(video_paths[j])
142 | T = torch.from_numpy(T).float().to(renderer.device)
143 | time_steps = np.linspace(0, 1, num_frames)
144 | frames = []
145 | reshaded_frames = []
146 |
147 | for i, t_step in tqdm(enumerate(time_steps), total=len(time_steps),
148 | desc='generating video of {} camera trajectory'.format(video_paths[j])):
149 |
150 | ######### RESHADING INSERT#############
151 | disparity = 1 / data['src_depth1']
152 |             # disparity is divided by 4 to match training; the scaled values should not exceed 0.25 (the training max)
153 |             disparity = disparity / 4
154 | disparity = torch.Tensor(embed_fn(disparity[0])[0]).permute(2, 0, 1)[None]
155 |
156 | rgbd = padder(torch.cat((ref_input, disparity), 1).cuda())
157 | reshaded = reshading(reshader, dir_model, rgbd, T[i])[:, :, :im_h, :im_w]
158 | reshaded_frames.append((255. * reshaded.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8))
159 |
160 | data['src_img1'] = reshaded
161 |
162 | renderer.process_data_single(data)
163 |
164 | pts1, rgb1, feat1, mask, side_ids = \
165 | renderer.render_rgbda_layers_from_one_view(return_pts=True)
166 | #################################
167 |
168 | pred_img, _, meta = renderer.render_pcd_single(pts1, rgb1,
169 | feat1, mask, side_ids,
170 | t=T[i], R=None, time=0)
171 | frame = (255. * pred_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8)
172 | # mask out fuzzy image boundaries due to no outpainting
173 | img_boundary_mask = (meta['acc'] > 0.5).detach().cpu().squeeze().numpy().astype(np.uint8)
174 | img_boundary_mask_cleaned = process_boundary_mask(img_boundary_mask)
175 | frame = frame * img_boundary_mask_cleaned[..., None]
176 | frame = frame[crop:-crop, crop:-crop]
177 | frames.append(frame)
178 |
179 | video_out_file = os.path.join(video_out_folder, '{}_{}-{}.mp4'.format(
180 | video_paths[j], scene_id, frame_id1))
181 | imageio.mimwrite(video_out_file, frames, fps=25, quality=8)
182 | imageio.mimwrite(f"{video_out_folder}/out_{args.model_dir}_model{args.ckpt}.mp4", reshaded_frames, fps=25, quality=8)
183 |
184 |
185 | print('output videos have been saved in {}.'.format(video_out_folder))
186 |
187 | if __name__ == '__main__':
188 | args = config.config_parser()
189 |
190 | render(args)
--------------------------------------------------------------------------------
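
To render a different camera trajectory, edit the `Ts` list in `render()`; the commented-out entries show the expected pattern. The sketch below builds a straight-line dolly path mirroring one of those commented examples, and assumes `define_camera_path` lives in `core/utils.py` (it is star-imported by `renderer.py` and is not defined in the root `utils.py`). Translation magnitudes must stay below the 0.3 limit used during training.

```python
from core.utils import define_camera_path  # same helper renderer.py star-imports

num_frames = 60
# Dolly along -z; keep all magnitudes well under the 0.3 training limit.
T = define_camera_path(num_frames, 0., 0., -0.24,
                       path_type='straight-line', return_t_only=True)
print(len(T))  # one relative translation per output frame
```

To use it in `render()`, replace the single entry in `Ts` with this path and set `video_paths = ['zoom-in']` so the output video is named accordingly.
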
/runs/reshader/dir_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/runs/reshader/dir_model.pth
--------------------------------------------------------------------------------
/runs/reshader/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avinashpaliwal/ReShader/7a138860cdc5075f5f92eacc67e520298a692957/runs/reshader/model.pth
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import numpy as np
18 | import torch
19 | from datetime import datetime
20 | import shutil
21 |
22 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
23 |
24 |
25 | def de_parallel(model):
26 | return model.module if hasattr(model, 'module') else model
27 |
28 |
29 | def cycle(iterable):
30 | while True:
31 | for x in iterable:
32 | yield x
33 |
34 |
35 | def dict_to_device(dict_):
36 | for k in dict_.keys():
37 | if type(dict_[k]) == torch.Tensor:
38 | dict_[k] = dict_[k].cuda()
39 |
40 | return dict_
41 |
42 |
43 | def save_current_code(outdir):
44 | now = datetime.now() # current date and time
45 | date_time = now.strftime("%m_%d-%H:%M:%S")
46 | src_dir = '.'
47 | code_out_dir = os.path.join(outdir, 'code')
48 | os.makedirs(code_out_dir, exist_ok=True)
49 | dst_dir = os.path.join(code_out_dir, '{}'.format(date_time))
50 | shutil.copytree(src_dir, dst_dir,
51 | ignore=shutil.ignore_patterns('pretrained*', '*logs*', 'out*', '*.png', '*.mp4', 'eval*',
52 | '*__pycache__*', '*.git*', '*.idea*', '*.zip', '*.jpg'))
53 |
54 |
55 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
56 | # assert isinstance(input, torch.Tensor)
57 | if posinf is None:
58 | posinf = torch.finfo(input.dtype).max
59 | if neginf is None:
60 | neginf = torch.finfo(input.dtype).min
61 | assert nan == 0
62 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
63 |
64 |
65 | def img2mse(x, y, mask=None):
66 | '''
67 | :param x: img 1, [(...), 3]
68 | :param y: img 2, [(...), 3]
69 | :param mask: optional, [(...)]
70 | :return: mse score
71 | '''
72 | if mask is None:
73 | return torch.mean((x - y) * (x - y))
74 | else:
75 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER)
76 |
77 |
78 | mse2psnr = lambda x: -10. * np.log(x+TINY_NUMBER) / np.log(10.)
79 |
80 |
81 | def img2psnr(x, y, mask=None):
82 | return mse2psnr(img2mse(x, y, mask).item())
83 |
84 |
--------------------------------------------------------------------------------
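
`img2psnr` composes `img2mse` with the `mse2psnr` lambda; the optional mask restricts the squared-error average to the masked pixels. A quick sketch (run from the repo root so `utils.py` is importable):

```python
import torch
from utils import img2mse, img2psnr

pred = torch.rand(8, 8, 3)
gt = pred + 0.01 * torch.randn(8, 8, 3)   # add noise with std 0.01
mask = torch.ones(8, 8)                   # weight every pixel equally

print(img2mse(pred, gt).item())           # ~1e-4
print(img2psnr(pred, gt, mask))           # ~40 dB
```
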