├── .gitignore ├── LICENSE ├── README.md ├── conf ├── default.conf ├── default_mv.conf └── exp │ ├── dtu.conf │ ├── multi_obj.conf │ ├── sn64.conf │ ├── sn64_unseen.conf │ └── srn.conf ├── environment.yml ├── eval ├── calc_metrics.py ├── eval.py ├── eval_approx.py ├── eval_real.py └── gen_video.py ├── expconf.conf ├── input ├── model3.png ├── model3_normalize.png ├── police.jpg ├── police_normalize.png ├── toyota.jpg └── toyota_normalize.png ├── readme-img └── paper_teaser.jpg ├── requirements.txt ├── scripts ├── README.md ├── detectron2 │ ├── LICENSE │ ├── configs │ │ └── Base-RCNN-FPN.yaml │ └── projects │ │ └── PointRend │ │ ├── configs │ │ ├── InstanceSegmentation │ │ │ ├── Base-PointRend-RCNN-FPN.yaml │ │ │ ├── pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml │ │ │ ├── pointrend_rcnn_R_50_FPN_1x_coco.yaml │ │ │ └── pointrend_rcnn_R_50_FPN_3x_coco.yaml │ │ └── SemanticSegmentation │ │ │ ├── Base-PointRend-Semantic-FPN.yaml │ │ │ └── pointrend_semantic_R_101_FPN_1x_cityscapes.yaml │ │ └── point_rend │ │ ├── __init__.py │ │ ├── coarse_mask_head.py │ │ ├── color_augmentation.py │ │ ├── config.py │ │ ├── point_features.py │ │ ├── point_head.py │ │ ├── roi_heads.py │ │ └── semantic_seg.py ├── preproc.py └── render_shapenet.py ├── src ├── data │ ├── DVRDataset.py │ ├── MultiObjectDataset.py │ ├── SRNDataset.py │ ├── __init__.py │ └── data_util.py ├── model │ ├── __init__.py │ ├── code.py │ ├── custom_encoder.py │ ├── encoder.py │ ├── loss.py │ ├── mlp.py │ ├── model_util.py │ ├── models.py │ └── resnetfc.py ├── render │ ├── __init__.py │ └── nerf.py └── util │ ├── __init__.py │ ├── args.py │ ├── recon.py │ └── util.py ├── train ├── train.py └── trainlib │ ├── __init__.py │ └── trainer.py └── viewlist ├── 2obj_eval_views.txt ├── src_dvr.txt ├── src_gen.txt └── srn_eval_views.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Editors 2 | .vscode/ 3 | .idea/ 4 | 5 | # Vagrant 6 | .vagrant/ 7 | 8 | # Mac/OSX 9 | .DS_Store 10 | 11 | # Windows 12 | Thumbs.db 13 | 14 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | *.swp 20 | 21 | # Custom 22 | checkpoints/ 23 | visuals/ 24 | logs/ 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, Alex Yu 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /conf/default.conf: -------------------------------------------------------------------------------- 1 | # Single-view only base model 2 | # (Not used in experiments; resnet_fine_mv.conf inherits) 3 | model { 4 | # Condition on local encoder 5 | use_encoder = True 6 | 7 | # Condition also on a global encoder? 8 | use_global_encoder = False 9 | 10 | # Use xyz input instead of just z 11 | # (didn't ablate) 12 | use_xyz = True 13 | 14 | # Canonical space xyz (default view space) 15 | canon_xyz = False 16 | 17 | # Positional encoding 18 | use_code = True 19 | code { 20 | num_freqs = 6 21 | freq_factor = 1.5 22 | include_input = True 23 | } 24 | 25 | # View directions 26 | use_viewdirs = True 27 | # Apply pos. enc. to viewdirs? 28 | use_code_viewdirs = False 29 | 30 | # MLP architecture 31 | mlp_coarse { 32 | type = resnet # Can change to mlp 33 | n_blocks = 3 34 | d_hidden = 512 35 | } 36 | mlp_fine { 37 | type = resnet 38 | n_blocks = 3 39 | d_hidden = 512 40 | } 41 | 42 | # Encoder architecture 43 | encoder { 44 | backbone = resnet34 45 | pretrained = True 46 | num_layers = 4 47 | } 48 | } 49 | renderer { 50 | n_coarse = 64 51 | n_fine = 32 52 | # Try using expected depth sample 53 | n_fine_depth = 16 54 | # Noise to add to depth sample 55 | depth_std = 0.01 56 | # Decay schedule, not used 57 | sched = [] 58 | # White background color (false : black) 59 | white_bkgd = True 60 | } 61 | loss { 62 | # RGB losses coarse/fine 63 | rgb { 64 | use_l1 = False 65 | } 66 | rgb_fine { 67 | use_l1 = False 68 | } 69 | # Alpha regularization (disabled in final version) 70 | alpha { 71 | # lambda_alpha = 0.0001 72 | lambda_alpha = 0.0 73 | clamp_alpha = 100 74 | init_epoch = 5 75 | } 76 | # Coarse/fine weighting (nerf = equal) 77 | lambda_coarse = 1.0 # loss = lambda_coarse * loss_coarse + loss_fine 78 | lambda_fine = 1.0 # loss = lambda_coarse * loss_coarse + loss_fine 79 | } 80 | train { 81 | # Training 82 | print_interval = 2 83 | save_interval = 50 84 | vis_interval = 100 85 | eval_interval = 50 86 | 87 | # Accumulating gradients. Not really recommended. 
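# With accu_grad = k > 1, gradients are presumably accumulated over k consecutive
# batches before each optimizer step, giving an effective batch size of roughly
# k times the configured batch size (see train/trainlib/trainer.py for the exact behavior).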
88 | # 1 = disable 89 | accu_grad = 1 90 | 91 | # Number of times to repeat dataset per 'epoch' 92 | # Useful if dataset is extremely small, like DTU 93 | num_epoch_repeats = 1 94 | } 95 | -------------------------------------------------------------------------------- /conf/default_mv.conf: -------------------------------------------------------------------------------- 1 | # Main multiview supported config 2 | include required("default.conf") 3 | model { 4 | # MLP architecture 5 | # Adapted for multiview 6 | # Possibly too big 7 | mlp_coarse { 8 | type = resnet 9 | n_blocks = 5 10 | d_hidden = 512 11 | # Combine after 3rd layer by average 12 | combine_layer = 3 13 | combine_type = average 14 | } 15 | mlp_fine { 16 | type = resnet 17 | n_blocks = 5 18 | d_hidden = 512 19 | combine_layer = 3 20 | combine_type = average 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /conf/exp/dtu.conf: -------------------------------------------------------------------------------- 1 | # DTU config 2 | include required("../default_mv.conf") 3 | train { 4 | num_epoch_repeats = 32 5 | vis_interval = 200 6 | } 7 | renderer { 8 | white_bkgd = False 9 | } 10 | data { 11 | format = dvr_dtu 12 | # ban_views = [3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 36, 37, 38, 39] 13 | } 14 | -------------------------------------------------------------------------------- /conf/exp/multi_obj.conf: -------------------------------------------------------------------------------- 1 | # Multi object config 2 | include required("../default_mv.conf") 3 | data { 4 | format = multi_obj 5 | } 6 | -------------------------------------------------------------------------------- /conf/exp/sn64.conf: -------------------------------------------------------------------------------- 1 | # Config for 64x64 images (NMR-SoftRas-DVR ShapeNet) 2 | # - Category agnostic 3 | include required("../default_mv.conf") 4 | model { 5 | encoder { 6 | # Skip first pooling layer to avoid reducing size too much 7 | use_first_pool=False 8 | } 9 | } 10 | data { 11 | format = dvr 12 | } 13 | -------------------------------------------------------------------------------- /conf/exp/sn64_unseen.conf: -------------------------------------------------------------------------------- 1 | include required("sn64.conf") 2 | data { 3 | format = dvr_gen 4 | } 5 | -------------------------------------------------------------------------------- /conf/exp/srn.conf: -------------------------------------------------------------------------------- 1 | # SRN experiments config 2 | include required("../default_mv.conf") 3 | data { 4 | format = srn 5 | } 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # run: conda env create -f environment.yml 2 | name: pixelnerf 3 | channels: 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python>=3.8 8 | - pip 9 | - pip: 10 | - pyhocon 11 | - opencv-python 12 | - dotmap 13 | - tensorboard 14 | - imageio 15 | - imageio-ffmpeg 16 | - ipdb 17 | - pretrainedmodels 18 | - lpips 19 | - scipy 20 | - numpy 21 | - matplotlib 22 | - pytorch==1.6.0 23 | - torchvision==0.7.0 24 | - scikit-image==0.17.2 25 | - tqdm 26 | -------------------------------------------------------------------------------- /eval/calc_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute metrics on rendered images (after eval.py). 
3 | First computes per-object metric then reduces them. If --multicat is used 4 | then also summarized per-categority metrics. Use --reduce_only to skip the 5 | per-object computation step. 6 | 7 | Note eval.py already outputs PSNR/SSIM. 8 | This also computes LPIPS and is useful for double-checking metric is correct. 9 | """ 10 | 11 | import os 12 | import os.path as osp 13 | import argparse 14 | import skimage.measure 15 | from tqdm import tqdm 16 | import warnings 17 | import lpips 18 | import numpy as np 19 | import torch 20 | import imageio 21 | import json 22 | 23 | parser = argparse.ArgumentParser(description="Calculate PSNR for rendered images.") 24 | parser.add_argument( 25 | "--datadir", 26 | "-D", 27 | type=str, 28 | default="/home/group/chairs_test", 29 | help="Dataset directory; note: different from usual, directly use the thing please", 30 | ) 31 | parser.add_argument( 32 | "--output", 33 | "-O", 34 | type=str, 35 | default="eval", 36 | help="Root path of rendered output (our format, from eval.py)", 37 | ) 38 | parser.add_argument( 39 | "--dataset_format", 40 | "-F", 41 | type=str, 42 | default="dvr", 43 | help="Dataset format, nerf | srn | dvr", 44 | ) 45 | parser.add_argument( 46 | "--list_name", type=str, default="softras_test", help="Filter list prefix for DVR", 47 | ) 48 | parser.add_argument( 49 | "--gpu_id", 50 | type=int, 51 | default=0, 52 | help="GPU id. Only single GPU supported for this script.", 53 | ) 54 | parser.add_argument( 55 | "--overwrite", action="store_true", help="overwriting existing metrics.txt", 56 | ) 57 | parser.add_argument( 58 | "--exclude_dtu_bad", action="store_true", help="exclude hardcoded DTU bad views", 59 | ) 60 | parser.add_argument( 61 | "--multicat", 62 | action="store_true", 63 | help="Prepend category id to object id. Specify if model fits multiple categories.", 64 | ) 65 | 66 | parser.add_argument( 67 | "--viewlist", 68 | "-L", 69 | type=str, 70 | default="", 71 | help="Path to source view list e.g. 
src_dvr.txt; if specified, excludes the source view from evaluation", 72 | ) 73 | parser.add_argument( 74 | "--eval_view_list", type=str, default=None, help="Path to eval view list" 75 | ) 76 | parser.add_argument( 77 | "--primary", "-P", type=str, default="", help="List of views to exclude" 78 | ) 79 | parser.add_argument( 80 | "--lpips_batch_size", type=int, default=32, help="Batch size for LPIPS", 81 | ) 82 | 83 | parser.add_argument( 84 | "--reduce_only", 85 | "-R", 86 | action="store_true", 87 | help="skip the map (per-obj metric computation)", 88 | ) 89 | parser.add_argument( 90 | "--metadata", 91 | type=str, 92 | default="metadata.yaml", 93 | help="Path to dataset metadata under datadir, used for getting category names if --multicat", 94 | ) 95 | parser.add_argument( 96 | "--dtu_sort", action="store_true", help="Sort using DTU scene order instead of lex" 97 | ) 98 | args = parser.parse_args() 99 | 100 | 101 | if args.dataset_format == "dvr": 102 | list_name = args.list_name + ".lst" 103 | img_dir_name = "image" 104 | elif args.dataset_format == "srn": 105 | list_name = "" 106 | img_dir_name = "rgb" 107 | elif args.dataset_format == "nerf": 108 | warnings.warn("test split not implemented for NeRF synthetic data format") 109 | list_name = "" 110 | img_dir_name = "" 111 | else: 112 | raise NotImplementedError("Not supported data format " + args.dataset_format) 113 | 114 | 115 | data_root = args.datadir 116 | render_root = args.output 117 | 118 | 119 | def run_map(): 120 | if args.multicat: 121 | cats = os.listdir(data_root) 122 | 123 | def fmt_obj_name(c, x): 124 | return c + "_" + x 125 | 126 | else: 127 | cats = ["."] 128 | 129 | def fmt_obj_name(c, x): 130 | return x 131 | 132 | use_exclude_lut = len(args.viewlist) > 0 133 | if use_exclude_lut: 134 | print("Excluding views from list", args.viewlist) 135 | with open(args.viewlist, "r") as f: 136 | tmp = [x.strip().split() for x in f.readlines()] 137 | exclude_lut = { 138 | x[0] + "/" + x[1]: torch.tensor(list(map(int, x[2:])), dtype=torch.long) 139 | for x in tmp 140 | } 141 | base_exclude_views = list(map(int, args.primary.split())) 142 | if args.exclude_dtu_bad: 143 | base_exclude_views.extend( 144 | [3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 36, 37, 38, 39] 145 | ) 146 | 147 | if args.eval_view_list is not None: 148 | with open(args.eval_view_list, "r") as f: 149 | eval_views = list(map(int, f.readline().split())) 150 | print("Only using views", eval_views) 151 | else: 152 | eval_views = None 153 | 154 | all_objs = [] 155 | 156 | print("CATEGORICAL SUMMARY") 157 | total_objs = 0 158 | 159 | for cat in cats: 160 | cat_root = osp.join(data_root, cat) 161 | if not osp.isdir(cat_root): 162 | continue 163 | 164 | objs = sorted([x for x in os.listdir(cat_root)]) 165 | 166 | if len(list_name) > 0: 167 | list_path = osp.join(cat_root, list_name) 168 | with open(list_path, "r") as f: 169 | split = set([x.strip() for x in f.readlines()]) 170 | objs = [x for x in objs if x in split] 171 | 172 | objs_rend = [osp.join(render_root, fmt_obj_name(cat, x)) for x in objs] 173 | 174 | objs = [osp.join(cat_root, x) for x in objs] 175 | objs = [x for x in objs if osp.isdir(x)] 176 | 177 | objs = list(zip(objs, objs_rend)) 178 | objs_avail = [x for x in objs if osp.exists(x[1])] 179 | print(cat, "TOTAL", len(objs), "AVAILABLE", len(objs_avail)) 180 | # assert len(objs) == len(objs_avail) 181 | total_objs += len(objs) 182 | all_objs.extend(objs_avail) 183 | print(">>> USING", len(all_objs), "OF", total_objs, "OBJECTS") 184 | 185 | cuda = "cuda:" + 
str(args.gpu_id) 186 | lpips_vgg = lpips.LPIPS(net="vgg").to(device=cuda) 187 | 188 | def get_metrics(rgb, gt): 189 | ssim = skimage.measure.compare_ssim(rgb, gt, multichannel=True, data_range=1) 190 | psnr = skimage.measure.compare_psnr(rgb, gt, data_range=1) 191 | return psnr, ssim 192 | 193 | def isimage(path): 194 | ext = osp.splitext(path)[1] 195 | return ext == ".jpg" or ext == ".png" 196 | 197 | def process_obj(path, rend_path): 198 | if len(img_dir_name) > 0: 199 | im_root = osp.join(path, img_dir_name) 200 | else: 201 | im_root = path 202 | out_path = osp.join(rend_path, "metrics.txt") 203 | if osp.exists(out_path) and not args.overwrite: 204 | return 205 | ims = [x for x in sorted(os.listdir(im_root)) if isimage(x)] 206 | psnr_avg = 0.0 207 | ssim_avg = 0.0 208 | gts = [] 209 | preds = [] 210 | num_ims = 0 211 | if use_exclude_lut: 212 | lut_key = osp.basename(rend_path).replace("_", "/") 213 | exclude_views = exclude_lut[lut_key] 214 | else: 215 | exclude_views = [] 216 | exclude_views.extend(base_exclude_views) 217 | 218 | for im_name in ims: 219 | im_path = osp.join(im_root, im_name) 220 | im_name_id = int(osp.splitext(im_name)[0]) 221 | im_name_out = "{:06}.png".format(im_name_id) 222 | im_rend_path = osp.join(rend_path, im_name_out) 223 | if osp.exists(im_rend_path) and im_name_id not in exclude_views: 224 | if eval_views is not None and im_name_id not in eval_views: 225 | continue 226 | gt = imageio.imread(im_path).astype(np.float32)[..., :3] / 255.0 227 | pred = imageio.imread(im_rend_path).astype(np.float32) / 255.0 228 | 229 | psnr, ssim = get_metrics(pred, gt) 230 | psnr_avg += psnr 231 | ssim_avg += ssim 232 | gts.append(torch.from_numpy(gt).permute(2, 0, 1) * 2.0 - 1.0) 233 | preds.append(torch.from_numpy(pred).permute(2, 0, 1) * 2.0 - 1.0) 234 | num_ims += 1 235 | gts = torch.stack(gts) 236 | preds = torch.stack(preds) 237 | 238 | lpips_all = [] 239 | preds_spl = torch.split(preds, args.lpips_batch_size, dim=0) 240 | gts_spl = torch.split(gts, args.lpips_batch_size, dim=0) 241 | with torch.no_grad(): 242 | for predi, gti in zip(preds_spl, gts_spl): 243 | lpips_i = lpips_vgg(predi.to(device=cuda), gti.to(device=cuda)) 244 | lpips_all.append(lpips_i) 245 | lpips = torch.cat(lpips_all) 246 | lpips = lpips.mean().item() 247 | psnr_avg /= num_ims 248 | ssim_avg /= num_ims 249 | out_txt = "psnr {}\nssim {}\nlpips {}".format(psnr_avg, ssim_avg, lpips) 250 | with open(out_path, "w") as f: 251 | f.write(out_txt) 252 | 253 | for obj_path, obj_rend_path in tqdm(all_objs): 254 | process_obj(obj_path, obj_rend_path) 255 | 256 | 257 | def run_reduce(): 258 | if args.multicat: 259 | meta = json.load(open(osp.join(args.datadir, args.metadata), "r")) 260 | cats = sorted(list(meta.keys())) 261 | cat_description = {cat: meta[cat]["name"].split(",")[0] for cat in cats} 262 | 263 | all_objs = [] 264 | objs = [x for x in os.listdir(render_root)] 265 | objs = [osp.join(render_root, x) for x in objs if x[0] != "_"] 266 | objs = [x for x in objs if osp.isdir(x)] 267 | if args.dtu_sort: 268 | objs = sorted(objs, key=lambda x: int(x[x.rindex("/") + 5 :])) 269 | else: 270 | objs = sorted(objs) 271 | all_objs.extend(objs) 272 | 273 | print(">>> PROCESSING", len(all_objs), "OBJECTS") 274 | 275 | METRIC_NAMES = ["psnr", "ssim", "lpips"] 276 | 277 | out_metrics_path = osp.join(render_root, "all_metrics.txt") 278 | 279 | if args.multicat: 280 | cat_sz = {} 281 | for cat in cats: 282 | cat_sz[cat] = 0 283 | 284 | all_metrics = {} 285 | for name in METRIC_NAMES: 286 | if args.multicat: 287 | for cat 
in cats: 288 | all_metrics[cat + "." + name] = 0.0 289 | all_metrics[name] = 0.0 290 | 291 | should_print_all_objs = len(all_objs) < 100 292 | 293 | for obj_root in tqdm(all_objs): 294 | metrics_path = osp.join(obj_root, "metrics.txt") 295 | with open(metrics_path, "r") as f: 296 | metrics = [line.split() for line in f.readlines()] 297 | if args.multicat: 298 | cat_name = osp.basename(obj_root).split("_")[0] 299 | cat_sz[cat_name] += 1 300 | for metric, val in metrics: 301 | all_metrics[cat_name + "." + metric] += float(val) 302 | 303 | for metric, val in metrics: 304 | all_metrics[metric] += float(val) 305 | if should_print_all_objs: 306 | print(obj_root, end=" ") 307 | for metric, val in metrics: 308 | print(val, end=" ") 309 | print() 310 | 311 | for name in METRIC_NAMES: 312 | if args.multicat: 313 | for cat in cats: 314 | if cat_sz[cat] > 0: 315 | all_metrics[cat + "." + name] /= cat_sz[cat] 316 | all_metrics[name] /= len(all_objs) 317 | print(name, all_metrics[name]) 318 | 319 | metrics_txt = [] 320 | if args.multicat: 321 | for cat in cats: 322 | if cat_sz[cat] > 0: 323 | cat_txt = "{:12s}".format(cat_description[cat]) 324 | for name in METRIC_NAMES: 325 | cat_txt += " {}: {:.6f}".format(name, all_metrics[cat + "." + name]) 326 | cat_txt += " n_inst: {}".format(cat_sz[cat]) 327 | metrics_txt.append(cat_txt) 328 | 329 | total_txt = "---\n{:12s}".format("total") 330 | else: 331 | total_txt = "" 332 | for name in METRIC_NAMES: 333 | total_txt += " {}: {:.6f}".format(name, all_metrics[name]) 334 | metrics_txt.append(total_txt) 335 | 336 | metrics_txt = "\n".join(metrics_txt) 337 | with open(out_metrics_path, "w") as f: 338 | f.write(metrics_txt) 339 | print("WROTE", out_metrics_path) 340 | print(metrics_txt) 341 | 342 | 343 | if __name__ == "__main__": 344 | if not args.reduce_only: 345 | print(">>> Compute") 346 | run_map() 347 | print(">>> Reduce") 348 | run_reduce() 349 | -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full evaluation script, including PSNR+SSIM evaluation with multi-GPU support. 3 | 4 | python eval.py --gpu_id= -n -c -D /home/group/data/chairs -F srn 5 | """ 6 | import sys 7 | import os 8 | 9 | sys.path.insert( 10 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")) 11 | ) 12 | 13 | import torch 14 | import numpy as np 15 | import imageio 16 | import skimage.measure 17 | import util 18 | from data import get_split_dataset 19 | from model import make_model 20 | from render import NeRFRenderer 21 | import cv2 22 | import tqdm 23 | import ipdb 24 | import warnings 25 | 26 | # from pytorch_memlab import set_target_gpu 27 | # set_target_gpu(9) 28 | 29 | 30 | def extra_args(parser): 31 | parser.add_argument( 32 | "--split", 33 | type=str, 34 | default="test", 35 | help="Split of data to use train | val | test", 36 | ) 37 | parser.add_argument( 38 | "--source", 39 | "-P", 40 | type=str, 41 | default="64", 42 | help="Source view(s) for each object. 
Alternatively, specify -L to viewlist file and leave this blank.", 43 | ) 44 | parser.add_argument( 45 | "--eval_view_list", type=str, default=None, help="Path to eval view list" 46 | ) 47 | parser.add_argument("--coarse", action="store_true", help="Coarse network as fine") 48 | parser.add_argument( 49 | "--no_compare_gt", 50 | action="store_true", 51 | help="Skip GT comparison (metric won't be computed) and only render images", 52 | ) 53 | parser.add_argument( 54 | "--multicat", 55 | action="store_true", 56 | help="Prepend category id to object id. Specify if model fits multiple categories.", 57 | ) 58 | parser.add_argument( 59 | "--viewlist", 60 | "-L", 61 | type=str, 62 | default="", 63 | help="Path to source view list e.g. src_dvr.txt; if specified, overrides source/P", 64 | ) 65 | 66 | parser.add_argument( 67 | "--output", 68 | "-O", 69 | type=str, 70 | default="eval", 71 | help="If specified, saves generated images to directory", 72 | ) 73 | parser.add_argument( 74 | "--include_src", action="store_true", help="Include source views in calculation" 75 | ) 76 | parser.add_argument( 77 | "--scale", type=float, default=1.0, help="Video scale relative to input size" 78 | ) 79 | parser.add_argument("--write_depth", action="store_true", help="Write depth image") 80 | parser.add_argument( 81 | "--write_compare", action="store_true", help="Write GT comparison image" 82 | ) 83 | parser.add_argument( 84 | "--free_pose", 85 | action="store_true", 86 | help="Set to indicate poses may change between objects. In most of our datasets, the test set has fixed poses.", 87 | ) 88 | return parser 89 | 90 | 91 | args, conf = util.args.parse_args( 92 | extra_args, default_conf="conf/resnet_fine_mv.conf", default_expname="shapenet", 93 | ) 94 | args.resume = True 95 | 96 | device = util.get_cuda(args.gpu_id[0]) 97 | 98 | dset = get_split_dataset( 99 | args.dataset_format, args.datadir, want_split=args.split, training=False 100 | ) 101 | data_loader = torch.utils.data.DataLoader( 102 | dset, batch_size=1, shuffle=False, num_workers=8, pin_memory=False 103 | ) 104 | 105 | output_dir = args.output.strip() 106 | has_output = len(output_dir) > 0 107 | 108 | total_psnr = 0.0 109 | total_ssim = 0.0 110 | cnt = 0 111 | 112 | if has_output: 113 | finish_path = os.path.join(output_dir, "finish.txt") 114 | os.makedirs(output_dir, exist_ok=True) 115 | if os.path.exists(finish_path): 116 | with open(finish_path, "r") as f: 117 | lines = [x.strip().split() for x in f.readlines()] 118 | lines = [x for x in lines if len(x) == 4] 119 | finished = set([x[0] for x in lines]) 120 | total_psnr = sum((float(x[1]) for x in lines)) 121 | total_ssim = sum((float(x[2]) for x in lines)) 122 | cnt = sum((int(x[3]) for x in lines)) 123 | if cnt > 0: 124 | print("resume psnr", total_psnr / cnt, "ssim", total_ssim / cnt) 125 | else: 126 | total_psnr = 0.0 127 | total_ssim = 0.0 128 | else: 129 | finished = set() 130 | 131 | finish_file = open(finish_path, "a", buffering=1) 132 | print("Writing images to", output_dir) 133 | 134 | 135 | net = make_model(conf["model"]).to(device=device).load_weights(args) 136 | renderer = NeRFRenderer.from_conf( 137 | conf["renderer"], lindisp=dset.lindisp, eval_batch_size=args.ray_batch_size 138 | ).to(device=device) 139 | if args.coarse: 140 | net.mlp_fine = None 141 | 142 | if renderer.n_coarse < 64: 143 | # Ensure decent sampling resolution 144 | renderer.n_coarse = 64 145 | if args.coarse: 146 | renderer.n_coarse = 64 147 | renderer.n_fine = 128 148 | renderer.using_fine = True 149 | 150 | render_par = 
renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval() 151 | 152 | z_near = dset.z_near 153 | z_far = dset.z_far 154 | 155 | use_source_lut = len(args.viewlist) > 0 156 | if use_source_lut: 157 | print("Using views from list", args.viewlist) 158 | with open(args.viewlist, "r") as f: 159 | tmp = [x.strip().split() for x in f.readlines()] 160 | source_lut = { 161 | x[0] + "/" + x[1]: torch.tensor(list(map(int, x[2:])), dtype=torch.long) 162 | for x in tmp 163 | } 164 | else: 165 | source = torch.tensor(sorted(list(map(int, args.source.split()))), dtype=torch.long) 166 | 167 | NV = dset[0]["images"].shape[0] 168 | 169 | if args.eval_view_list is not None: 170 | with open(args.eval_view_list, "r") as f: 171 | eval_views = torch.tensor(list(map(int, f.readline().split()))) 172 | target_view_mask = torch.zeros(NV, dtype=torch.bool) 173 | target_view_mask[eval_views] = 1 174 | else: 175 | target_view_mask = torch.ones(NV, dtype=torch.bool) 176 | target_view_mask_init = target_view_mask 177 | 178 | all_rays = None 179 | rays_spl = [] 180 | 181 | src_view_mask = None 182 | total_objs = len(data_loader) 183 | 184 | with torch.no_grad(): 185 | for obj_idx, data in enumerate(data_loader): 186 | print( 187 | "OBJECT", 188 | obj_idx, 189 | "OF", 190 | total_objs, 191 | "PROGRESS", 192 | obj_idx / total_objs * 100.0, 193 | "%", 194 | data["path"][0], 195 | ) 196 | dpath = data["path"][0] 197 | obj_basename = os.path.basename(dpath) 198 | cat_name = os.path.basename(os.path.dirname(dpath)) 199 | obj_name = cat_name + "_" + obj_basename if args.multicat else obj_basename 200 | if has_output and obj_name in finished: 201 | print("(skip)") 202 | continue 203 | images = data["images"][0] # (NV, 3, H, W) 204 | 205 | NV, _, H, W = images.shape 206 | 207 | if args.scale != 1.0: 208 | Ht = int(H * args.scale) 209 | Wt = int(W * args.scale) 210 | if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10: 211 | warnings.warn( 212 | "Inexact scaling, please check {} times ({}, {}) is integral".format( 213 | args.scale, H, W 214 | ) 215 | ) 216 | H, W = Ht, Wt 217 | 218 | if all_rays is None or use_source_lut or args.free_pose: 219 | if use_source_lut: 220 | obj_id = cat_name + "/" + obj_basename 221 | source = source_lut[obj_id] 222 | 223 | NS = len(source) 224 | src_view_mask = torch.zeros(NV, dtype=torch.bool) 225 | src_view_mask[source] = 1 226 | 227 | focal = data["focal"][0] 228 | if isinstance(focal, float): 229 | focal = torch.tensor(focal, dtype=torch.float32) 230 | focal = focal[None] 231 | 232 | c = data.get("c") 233 | if c is not None: 234 | c = c[0].to(device=device).unsqueeze(0) 235 | 236 | poses = data["poses"][0] # (NV, 4, 4) 237 | src_poses = poses[src_view_mask].to(device=device) # (NS, 4, 4) 238 | 239 | target_view_mask = target_view_mask_init.clone() 240 | if not args.include_src: 241 | target_view_mask *= ~src_view_mask 242 | 243 | novel_view_idxs = target_view_mask.nonzero(as_tuple=False).reshape(-1) 244 | 245 | poses = poses[target_view_mask] # (NV[-NS], 4, 4) 246 | 247 | all_rays = ( 248 | util.gen_rays( 249 | poses.reshape(-1, 4, 4), 250 | W, 251 | H, 252 | focal * args.scale, 253 | z_near, 254 | z_far, 255 | c=c * args.scale if c is not None else None, 256 | ) 257 | .reshape(-1, 8) 258 | .to(device=device) 259 | ) # ((NV[-NS])*H*W, 8) 260 | 261 | poses = None 262 | focal = focal.to(device=device) 263 | 264 | rays_spl = torch.split(all_rays, args.ray_batch_size, dim=0) # Creates views 265 | 266 | n_gen_views = len(novel_view_idxs) 267 | 268 | net.encode( 269 | 
images[src_view_mask].to(device=device).unsqueeze(0), 270 | src_poses.unsqueeze(0), 271 | focal, 272 | c=c, 273 | ) 274 | 275 | all_rgb, all_depth = [], [] 276 | for rays in tqdm.tqdm(rays_spl): 277 | rgb, depth = render_par(rays[None]) 278 | rgb = rgb[0].cpu() 279 | depth = depth[0].cpu() 280 | all_rgb.append(rgb) 281 | all_depth.append(depth) 282 | 283 | all_rgb = torch.cat(all_rgb, dim=0) 284 | all_depth = torch.cat(all_depth, dim=0) 285 | all_depth = (all_depth - z_near) / (z_far - z_near) 286 | all_depth = all_depth.reshape(n_gen_views, H, W).numpy() 287 | 288 | all_rgb = torch.clamp( 289 | all_rgb.reshape(n_gen_views, H, W, 3), 0.0, 1.0 290 | ).numpy() # (NV-NS, H, W, 3) 291 | if has_output: 292 | obj_out_dir = os.path.join(output_dir, obj_name) 293 | os.makedirs(obj_out_dir, exist_ok=True) 294 | for i in range(n_gen_views): 295 | out_file = os.path.join( 296 | obj_out_dir, "{:06}.png".format(novel_view_idxs[i].item()) 297 | ) 298 | imageio.imwrite(out_file, (all_rgb[i] * 255).astype(np.uint8)) 299 | 300 | if args.write_depth: 301 | out_depth_file = os.path.join( 302 | obj_out_dir, "{:06}_depth.exr".format(novel_view_idxs[i].item()) 303 | ) 304 | out_depth_norm_file = os.path.join( 305 | obj_out_dir, 306 | "{:06}_depth_norm.png".format(novel_view_idxs[i].item()), 307 | ) 308 | depth_cmap_norm = util.cmap(all_depth[i]) 309 | cv2.imwrite(out_depth_file, all_depth[i]) 310 | imageio.imwrite(out_depth_norm_file, depth_cmap_norm) 311 | 312 | curr_ssim = 0.0 313 | curr_psnr = 0.0 314 | if not args.no_compare_gt: 315 | images_0to1 = images * 0.5 + 0.5 # (NV, 3, H, W) 316 | images_gt = images_0to1[target_view_mask] 317 | rgb_gt_all = ( 318 | images_gt.permute(0, 2, 3, 1).contiguous().numpy() 319 | ) # (NV-NS, H, W, 3) 320 | for view_idx in range(n_gen_views): 321 | ssim = skimage.measure.compare_ssim( 322 | all_rgb[view_idx], 323 | rgb_gt_all[view_idx], 324 | multichannel=True, 325 | data_range=1, 326 | ) 327 | psnr = skimage.measure.compare_psnr( 328 | all_rgb[view_idx], rgb_gt_all[view_idx], data_range=1 329 | ) 330 | curr_ssim += ssim 331 | curr_psnr += psnr 332 | 333 | if args.write_compare: 334 | out_file = os.path.join( 335 | obj_out_dir, 336 | "{:06}_compare.png".format(novel_view_idxs[view_idx].item()), 337 | ) 338 | out_im = np.hstack((all_rgb[view_idx], rgb_gt_all[view_idx])) 339 | imageio.imwrite(out_file, (out_im * 255).astype(np.uint8)) 340 | curr_psnr /= n_gen_views 341 | curr_ssim /= n_gen_views 342 | curr_cnt = 1 343 | total_psnr += curr_psnr 344 | total_ssim += curr_ssim 345 | cnt += curr_cnt 346 | if not args.no_compare_gt: 347 | print( 348 | "curr psnr", 349 | curr_psnr, 350 | "ssim", 351 | curr_ssim, 352 | "running psnr", 353 | total_psnr / cnt, 354 | "running ssim", 355 | total_ssim / cnt, 356 | ) 357 | finish_file.write( 358 | "{} {} {} {}\n".format(obj_name, curr_psnr, curr_ssim, curr_cnt) 359 | ) 360 | print("final psnr", total_psnr / cnt, "ssim", total_ssim / cnt) 361 | -------------------------------------------------------------------------------- /eval/eval_approx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate PSNR+SSIM evaluation for use during development, since eval.py is too slow. 3 | Evaluates using only 1 random target view per object. You can try different --seed. 4 | 5 | python eval_approx.py --gpu_id= -n -c -D -F 6 | Add --seed to set random seed 7 | 8 | May not work for DTU. 
9 | """ 10 | import sys 11 | import os 12 | 13 | sys.path.insert( 14 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")) 15 | ) 16 | 17 | import torch 18 | import numpy as np 19 | import imageio 20 | import skimage.measure 21 | import util 22 | from data import get_split_dataset 23 | from render import NeRFRenderer 24 | from model import make_model 25 | import tqdm 26 | 27 | 28 | def extra_args(parser): 29 | parser.add_argument( 30 | "--split", 31 | type=str, 32 | default="val", 33 | help="Split of data to use train | val | test", 34 | ) 35 | 36 | parser.add_argument( 37 | "--source", 38 | "-P", 39 | type=str, 40 | default="64", 41 | help="Source view(s) in image, in increasing order. -1 to use random 1 view.", 42 | ) 43 | 44 | parser.add_argument("--batch_size", type=int, default=4, help="Batch size") 45 | parser.add_argument( 46 | "--seed", 47 | type=int, 48 | default=1234, 49 | help="Random seed for selecting target views of each object", 50 | ) 51 | parser.add_argument("--coarse", action="store_true", help="Coarse network as fine") 52 | return parser 53 | 54 | 55 | args, conf = util.args.parse_args(extra_args) 56 | args.resume = True 57 | 58 | device = util.get_cuda(args.gpu_id[0]) 59 | net = make_model(conf["model"]).to(device=device) 60 | net.load_weights(args) 61 | 62 | if args.coarse: 63 | net.mlp_fine = None 64 | 65 | dset = get_split_dataset( 66 | args.dataset_format, args.datadir, want_split=args.split, training=False 67 | ) 68 | data_loader = torch.utils.data.DataLoader( 69 | dset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=False 70 | ) 71 | 72 | renderer = NeRFRenderer.from_conf( 73 | conf["renderer"], eval_batch_size=args.ray_batch_size 74 | ).to(device=device) 75 | 76 | if renderer.n_coarse < 64: 77 | # Ensure decent sampling resolution 78 | renderer.n_coarse = 64 79 | if args.coarse: 80 | renderer.n_coarse = 64 81 | renderer.n_fine = 128 82 | renderer.using_fine = True 83 | 84 | render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval() 85 | 86 | z_near = dset.z_near 87 | z_far = dset.z_far 88 | 89 | torch.random.manual_seed(args.seed) 90 | 91 | total_psnr = 0.0 92 | total_ssim = 0.0 93 | cnt = 0 94 | 95 | 96 | source = torch.tensor(list(map(int, args.source.split())), dtype=torch.long) 97 | NS = len(source) 98 | random_source = NS == 1 and source[0] == -1 99 | 100 | with torch.no_grad(): 101 | for data in tqdm.tqdm(data_loader, total=len(data_loader)): 102 | images = data["images"] # (SB, NV, 3, H, W) 103 | masks = data["masks"] # (SB, NV, 1, H, W) 104 | poses = data["poses"] # (SB, NV, 4, 4) 105 | focal = data["focal"][0] 106 | 107 | images_0to1 = images * 0.5 + 0.5 # (B, 3, H, W) 108 | 109 | SB, NV, _, H, W = images.shape 110 | 111 | if random_source: 112 | src_view = torch.randint(0, NV, (SB, 1)) 113 | else: 114 | src_view = source.unsqueeze(0).expand(SB, -1) 115 | 116 | dest_view = torch.randint(0, NV - NS, (SB, 1)) 117 | for i in range(NS): 118 | dest_view += dest_view >= src_view[:, i : i + 1] 119 | 120 | dest_poses = util.batched_index_select_nd(poses, dest_view) 121 | all_rays = util.gen_rays( 122 | dest_poses.reshape(-1, 4, 4), W, H, focal, z_near, z_far 123 | ).reshape(SB, -1, 8) 124 | 125 | pri_images = util.batched_index_select_nd(images, src_view) # (SB, NS, 3, H, W) 126 | pri_poses = util.batched_index_select_nd(poses, src_view) # (SB, NS, 4, 4) 127 | 128 | net.encode( 129 | pri_images.to(device=device), 130 | pri_poses.to(device=device), 131 | focal.to(device=device), 132 | ) 133 | 134 | 
rgb_fine, _depth = render_par(all_rays.to(device=device)) 135 | _depth = None 136 | rgb_fine = rgb_fine.reshape(SB, H, W, 3).cpu().numpy() 137 | images_gt = util.batched_index_select_nd(images_0to1, dest_view).reshape( 138 | SB, 3, H, W 139 | ) 140 | rgb_gt_all = images_gt.permute(0, 2, 3, 1).contiguous().numpy() 141 | 142 | for sb in range(SB): 143 | ssim = skimage.measure.compare_ssim( 144 | rgb_fine[sb], rgb_gt_all[sb], multichannel=True, data_range=1 145 | ) 146 | psnr = skimage.measure.compare_psnr( 147 | rgb_fine[sb], rgb_gt_all[sb], data_range=1 148 | ) 149 | total_ssim += ssim 150 | total_psnr += psnr 151 | cnt += SB 152 | print("curr psnr", total_psnr / cnt, "ssim", total_ssim / cnt) 153 | print("final psnr", total_psnr / cnt, "ssim", total_ssim / cnt) 154 | -------------------------------------------------------------------------------- /eval/eval_real.py: -------------------------------------------------------------------------------- 1 | """ 2 | Eval on real images from input/*_normalize.png, output to output/ 3 | """ 4 | import sys 5 | import os 6 | 7 | ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 8 | sys.path.insert(0, os.path.join(ROOT_DIR, "src")) 9 | 10 | import util 11 | import torch 12 | import numpy as np 13 | from model import make_model 14 | from render import NeRFRenderer 15 | import torchvision.transforms as T 16 | import tqdm 17 | import imageio 18 | from PIL import Image 19 | 20 | 21 | def extra_args(parser): 22 | parser.add_argument( 23 | "--input", 24 | "-I", 25 | type=str, 26 | default=os.path.join(ROOT_DIR, "input"), 27 | help="Image directory", 28 | ) 29 | parser.add_argument( 30 | "--output", 31 | "-O", 32 | type=str, 33 | default=os.path.join(ROOT_DIR, "output"), 34 | help="Output directory", 35 | ) 36 | parser.add_argument("--size", type=int, default=128, help="Input image maxdim") 37 | parser.add_argument( 38 | "--out_size", 39 | type=str, 40 | default="128", 41 | help="Output image size, either 1 or 2 number (w h)", 42 | ) 43 | 44 | parser.add_argument("--focal", type=float, default=131.25, help="Focal length") 45 | 46 | parser.add_argument("--radius", type=float, default=1.3, help="Camera distance") 47 | parser.add_argument("--z_near", type=float, default=0.8) 48 | parser.add_argument("--z_far", type=float, default=1.8) 49 | 50 | parser.add_argument( 51 | "--elevation", 52 | "-e", 53 | type=float, 54 | default=0.0, 55 | help="Elevation angle (negative is above)", 56 | ) 57 | parser.add_argument( 58 | "--num_views", 59 | type=int, 60 | default=24, 61 | help="Number of video frames (rotated views)", 62 | ) 63 | parser.add_argument("--fps", type=int, default=15, help="FPS of video") 64 | parser.add_argument("--gif", action="store_true", help="Store gif instead of mp4") 65 | parser.add_argument( 66 | "--no_vid", 67 | action="store_true", 68 | help="Do not store video (only image frames will be written)", 69 | ) 70 | return parser 71 | 72 | 73 | args, conf = util.args.parse_args( 74 | extra_args, default_expname="srn_car", default_data_format="srn", 75 | ) 76 | args.resume = True 77 | 78 | device = util.get_cuda(args.gpu_id[0]) 79 | net = make_model(conf["model"]).to(device=device).load_weights(args) 80 | renderer = NeRFRenderer.from_conf( 81 | conf["renderer"], eval_batch_size=args.ray_batch_size 82 | ).to(device=device) 83 | render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval() 84 | 85 | z_near, z_far = args.z_near, args.z_far 86 | focal = torch.tensor(args.focal, dtype=torch.float32, device=device) 
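# Note: the default --focal of 131.25 px corresponds to the 128x128 SRN car/chair
# renderings; under a pinhole model the focal length scales linearly with resolution,
# so --focal should generally be adjusted when rendering at a different size.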
87 | 88 | in_sz = args.size 89 | sz = list(map(int, args.out_size.split())) 90 | if len(sz) == 1: 91 | H = W = sz[0] 92 | else: 93 | assert len(sz) == 2 94 | W, H = sz 95 | 96 | _coord_to_blender = util.coord_to_blender() 97 | _coord_from_blender = util.coord_from_blender() 98 | 99 | print("Generating rays") 100 | render_poses = torch.stack( 101 | [ 102 | _coord_from_blender @ util.pose_spherical(angle, args.elevation, args.radius) 103 | # util.pose_spherical(angle, args.elevation, args.radius) 104 | for angle in np.linspace(-180, 180, args.num_views + 1)[:-1] 105 | ], 106 | 0, 107 | ) # (NV, 4, 4) 108 | 109 | render_rays = util.gen_rays(render_poses, W, H, focal, z_near, z_far).to(device=device) 110 | 111 | 112 | inputs_all = os.listdir(args.input) 113 | inputs = [ 114 | os.path.join(args.input, x) for x in inputs_all if x.endswith("_normalize.png") 115 | ] 116 | os.makedirs(args.output, exist_ok=True) 117 | 118 | if len(inputs) == 0: 119 | if len(inputs_all) == 0: 120 | print("No input images found, please place an image into ./input") 121 | else: 122 | print("No processed input images found, did you run 'scripts/preproc.py'?") 123 | import sys 124 | 125 | sys.exit(1) 126 | 127 | cam_pose = torch.eye(4, device=device) 128 | cam_pose[2, -1] = args.radius 129 | print("SET DUMMY CAMERA") 130 | print(cam_pose) 131 | 132 | image_to_tensor = util.get_image_to_tensor_balanced() 133 | 134 | with torch.no_grad(): 135 | for i, image_path in enumerate(inputs): 136 | print("IMAGE", i + 1, "of", len(inputs), "@", image_path) 137 | image = Image.open(image_path).convert("RGB") 138 | image = T.Resize(in_sz)(image) 139 | image = image_to_tensor(image).to(device=device) 140 | 141 | net.encode( 142 | image.unsqueeze(0), cam_pose.unsqueeze(0), focal, 143 | ) 144 | print("Rendering", args.num_views * H * W, "rays") 145 | all_rgb_fine = [] 146 | for rays in tqdm.tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)): 147 | rgb, _depth = render_par(rays[None]) 148 | all_rgb_fine.append(rgb[0]) 149 | _depth = None 150 | rgb_fine = torch.cat(all_rgb_fine) 151 | frames = (rgb_fine.view(args.num_views, H, W, 3).cpu().numpy() * 255).astype( 152 | np.uint8 153 | ) 154 | 155 | im_name = os.path.basename(os.path.splitext(image_path)[0]) 156 | 157 | frames_dir_name = os.path.join(args.output, im_name + "_frames") 158 | os.makedirs(frames_dir_name, exist_ok=True) 159 | 160 | for i in range(args.num_views): 161 | frm_path = os.path.join(frames_dir_name, "{:04}.png".format(i)) 162 | imageio.imwrite(frm_path, frames[i]) 163 | 164 | if not args.no_vid: 165 | if args.gif: 166 | vid_path = os.path.join(args.output, im_name + "_vid.gif") 167 | imageio.mimwrite(vid_path, frames, fps=args.fps) 168 | else: 169 | vid_path = os.path.join(args.output, im_name + "_vid.mp4") 170 | imageio.mimwrite(vid_path, frames, fps=args.fps, quality=8) 171 | print("Wrote to", vid_path) 172 | -------------------------------------------------------------------------------- /eval/gen_video.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert( 5 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")) 6 | ) 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import imageio 12 | import util 13 | import warnings 14 | from data import get_split_dataset 15 | from render import NeRFRenderer 16 | from model import make_model 17 | from scipy.interpolate import CubicSpline 18 | import tqdm 19 | 20 | 21 | def 
extra_args(parser): 22 | parser.add_argument( 23 | "--subset", "-S", type=int, default=0, help="Subset in data to use" 24 | ) 25 | parser.add_argument( 26 | "--split", 27 | type=str, 28 | default="train", 29 | help="Split of data to use train | val | test", 30 | ) 31 | parser.add_argument( 32 | "--source", 33 | "-P", 34 | type=str, 35 | default="64", 36 | help="Source view(s) in image, in increasing order. -1 to do random", 37 | ) 38 | parser.add_argument( 39 | "--num_views", 40 | type=int, 41 | default=40, 42 | help="Number of video frames (rotated views)", 43 | ) 44 | parser.add_argument( 45 | "--elevation", 46 | type=float, 47 | default=-10.0, 48 | help="Elevation angle (negative is above)", 49 | ) 50 | parser.add_argument( 51 | "--scale", type=float, default=1.0, help="Video scale relative to input size" 52 | ) 53 | parser.add_argument( 54 | "--radius", 55 | type=float, 56 | default=0.0, 57 | help="Distance of camera from origin, default is average of z_far, z_near of dataset (only for non-DTU)", 58 | ) 59 | parser.add_argument("--fps", type=int, default=30, help="FPS of video") 60 | return parser 61 | 62 | 63 | args, conf = util.args.parse_args(extra_args) 64 | args.resume = True 65 | 66 | device = util.get_cuda(args.gpu_id[0]) 67 | 68 | dset = get_split_dataset( 69 | args.dataset_format, args.datadir, want_split=args.split, training=False 70 | ) 71 | 72 | data = dset[args.subset] 73 | data_path = data["path"] 74 | print("Data instance loaded:", data_path) 75 | 76 | images = data["images"] # (NV, 3, H, W) 77 | 78 | poses = data["poses"] # (NV, 4, 4) 79 | focal = data["focal"] 80 | if isinstance(focal, float): 81 | # Dataset implementations are not consistent about 82 | # returning float or scalar tensor in case of fx=fy 83 | focal = torch.tensor(focal, dtype=torch.float32) 84 | focal = focal[None] 85 | 86 | c = data.get("c") 87 | if c is not None: 88 | c = c.to(device=device).unsqueeze(0) 89 | 90 | NV, _, H, W = images.shape 91 | 92 | if args.scale != 1.0: 93 | Ht = int(H * args.scale) 94 | Wt = int(W * args.scale) 95 | if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10: 96 | warnings.warn( 97 | "Inexact scaling, please check {} times ({}, {}) is integral".format( 98 | args.scale, H, W 99 | ) 100 | ) 101 | H, W = Ht, Wt 102 | 103 | net = make_model(conf["model"]).to(device=device) 104 | net.load_weights(args) 105 | 106 | renderer = NeRFRenderer.from_conf( 107 | conf["renderer"], lindisp=dset.lindisp, eval_batch_size=args.ray_batch_size, 108 | ).to(device=device) 109 | 110 | render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval() 111 | 112 | # Get the distance from camera to origin 113 | z_near = dset.z_near 114 | z_far = dset.z_far 115 | 116 | print("Generating rays") 117 | 118 | dtu_format = hasattr(dset, "sub_format") and dset.sub_format == "dtu" 119 | 120 | if dtu_format: 121 | print("Using DTU camera trajectory") 122 | # Use hard-coded pose interpolation from IDR for DTU 123 | 124 | t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32) 125 | pose_quat = torch.tensor( 126 | [ 127 | [0.9698, 0.2121, 0.1203, -0.0039], 128 | [0.7020, 0.1578, 0.4525, 0.5268], 129 | [0.6766, 0.3176, 0.5179, 0.4161], 130 | [0.9085, 0.4020, 0.1139, -0.0025], 131 | [0.9698, 0.2121, 0.1203, -0.0039], 132 | ] 133 | ) 134 | n_inter = args.num_views // 5 135 | args.num_views = n_inter * 5 136 | t_out = np.linspace(t_in[0], t_in[-1], n_inter * int(t_in[-1])).astype(np.float32) 137 | scales = np.array([2.0, 2.0, 2.0, 2.0, 2.0]).astype(np.float32) 138 | 139 | s_new = 
CubicSpline(t_in, scales, bc_type="periodic") 140 | s_new = s_new(t_out) 141 | 142 | q_new = CubicSpline(t_in, pose_quat.detach().cpu().numpy(), bc_type="periodic") 143 | q_new = q_new(t_out) 144 | q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None] 145 | q_new = torch.from_numpy(q_new).float() 146 | 147 | render_poses = [] 148 | for i, (new_q, scale) in enumerate(zip(q_new, s_new)): 149 | new_q = new_q.unsqueeze(0) 150 | R = util.quat_to_rot(new_q) 151 | t = R[:, :, 2] * scale 152 | new_pose = torch.eye(4, dtype=torch.float32).unsqueeze(0) 153 | new_pose[:, :3, :3] = R 154 | new_pose[:, :3, 3] = t 155 | render_poses.append(new_pose) 156 | render_poses = torch.cat(render_poses, dim=0) 157 | else: 158 | print("Using default (360 loop) camera trajectory") 159 | if args.radius == 0.0: 160 | radius = (z_near + z_far) * 0.5 161 | print("> Using default camera radius", radius) 162 | else: 163 | radius = args.radius 164 | 165 | # Use 360 pose sequence from NeRF 166 | render_poses = torch.stack( 167 | [ 168 | util.pose_spherical(angle, args.elevation, radius) 169 | for angle in np.linspace(-180, 180, args.num_views + 1)[:-1] 170 | ], 171 | 0, 172 | ) # (NV, 4, 4) 173 | 174 | render_rays = util.gen_rays( 175 | render_poses, 176 | W, 177 | H, 178 | focal * args.scale, 179 | z_near, 180 | z_far, 181 | c=c * args.scale if c is not None else None, 182 | ).to(device=device) 183 | # (NV, H, W, 8) 184 | 185 | focal = focal.to(device=device) 186 | 187 | source = torch.tensor(list(map(int, args.source.split())), dtype=torch.long) 188 | NS = len(source) 189 | random_source = NS == 1 and source[0] == -1 190 | assert not (source >= NV).any() 191 | 192 | if renderer.n_coarse < 64: 193 | # Ensure decent sampling resolution 194 | renderer.n_coarse = 64 195 | renderer.n_fine = 128 196 | 197 | with torch.no_grad(): 198 | print("Encoding source view(s)") 199 | if random_source: 200 | src_view = torch.randint(0, NV, (1,)) 201 | else: 202 | src_view = source 203 | 204 | net.encode( 205 | images[src_view].unsqueeze(0), 206 | poses[src_view].unsqueeze(0).to(device=device), 207 | focal, 208 | c=c, 209 | ) 210 | 211 | print("Rendering", args.num_views * H * W, "rays") 212 | all_rgb_fine = [] 213 | for rays in tqdm.tqdm( 214 | torch.split(render_rays.view(-1, 8), args.ray_batch_size, dim=0) 215 | ): 216 | rgb, _depth = render_par(rays[None]) 217 | all_rgb_fine.append(rgb[0]) 218 | _depth = None 219 | rgb_fine = torch.cat(all_rgb_fine) 220 | # rgb_fine (V*H*W, 3) 221 | 222 | frames = rgb_fine.view(-1, H, W, 3) 223 | 224 | print("Writing video") 225 | vid_name = "{:04}".format(args.subset) 226 | if args.split == "test": 227 | vid_name = "t" + vid_name 228 | elif args.split == "val": 229 | vid_name = "v" + vid_name 230 | vid_name += "_v" + "_".join(map(lambda x: "{:03}".format(x), source)) 231 | vid_path = os.path.join(args.visual_path, args.name, "video" + vid_name + ".mp4") 232 | viewimg_path = os.path.join( 233 | args.visual_path, args.name, "video" + vid_name + "_view.jpg" 234 | ) 235 | imageio.mimwrite( 236 | vid_path, (frames.cpu().numpy() * 255).astype(np.uint8), fps=args.fps, quality=8 237 | ) 238 | 239 | img_np = (data["images"][src_view].permute(0, 2, 3, 1) * 0.5 + 0.5).numpy() 240 | img_np = (img_np * 255).astype(np.uint8) 241 | img_np = np.hstack((*img_np,)) 242 | imageio.imwrite(viewimg_path, img_np) 243 | 244 | print("Wrote to", vid_path, "view:", viewimg_path) 245 | -------------------------------------------------------------------------------- /expconf.conf: 
-------------------------------------------------------------------------------- 1 | # To save typing, this file offers a way to associate 2 | # default config files (-c) and datase directories (-D) with expnames 3 | config { 4 | # expname = config_file_path 5 | # path should be under 6 | sn64 = conf/exp/sn64.conf 7 | sn64_unseen = conf/exp/sn64_unseen.conf 8 | srn_chair = conf/exp/srn.conf 9 | srn_car = conf/exp/srn.conf 10 | dtu = conf/exp/dtu.conf 11 | multi_obj = conf/exp/multi_obj.conf 12 | } 13 | 14 | datadir { 15 | # expname = data directory 16 | } 17 | -------------------------------------------------------------------------------- /input/model3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/model3.png -------------------------------------------------------------------------------- /input/model3_normalize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/model3_normalize.png -------------------------------------------------------------------------------- /input/police.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/police.jpg -------------------------------------------------------------------------------- /input/police_normalize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/police_normalize.png -------------------------------------------------------------------------------- /input/toyota.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/toyota.jpg -------------------------------------------------------------------------------- /input/toyota_normalize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/input/toyota_normalize.png -------------------------------------------------------------------------------- /readme-img/paper_teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/pixel-nerf/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/readme-img/paper_teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pretrainedmodels 4 | pyhocon 5 | imageio 6 | opencv-python 7 | imageio-ffmpeg 8 | tensorboard 9 | dotmap 10 | numpy 11 | scipy 12 | scikit-image 13 | ipdb 14 | matplotlib 15 | tqdm 16 | lpips 17 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Rendering multiple object ShapeNet scenes 2 | The `render_shapenet.py` script is used to render ShapeNet scenes composed of multiple object instances, given an object class. 
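For example, once the `bpy` alias from the Installing Blender section below is set up, a training-split render could be launched roughly as follows (all paths here are placeholders; see [Render Flags](#render-flags) for the full flag list):
```
bpy render_shapenet.py -- --out_dir /path/to/renders --src_model_dir /path/to/ShapeNetCore \
    --object chair --split train --n_objects 2 --n_views 20 --res 128
```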
3 | This script will render different splits (train/test/val) of the ShapeNet models; see [Render Flags](#render-flags) for more information. 4 | 5 | ## Installing Blender 6 | 7 | 1. Download and untar Blender 8 | ``` 9 | wget https://mirror.clarkson.edu/blender/release/Blender2.90/blender-2.90.1-linux64.tar.xz 10 | tar -xvf blender-2.90.1-linux64.tar.xz 11 | ``` 12 | 13 | 2. Install other Python dependencies in the Blender bundled Python 14 | ``` 15 | cd $INSTALL_PATH/blender-2.90.1-linux64/2.90/python/bin/ 16 | ./python3.7m -m ensurepip 17 | ./pip3 install numpy scipy 18 | ``` 19 | 20 | 3. In your `.bash_aliases` file, add 21 | ``` 22 | alias bpy="blender --background -noaudio --python" 23 | ``` 24 | This allows you to call 25 | ``` 26 | bpy render_shapenet.py -- 27 | ``` 28 | Unless debugging, it is recommended to redirect Blender's stdout to /dev/null and stderr to stdout so that the script's own logging remains visible. 29 | ``` 30 | bpy render_shapenet.py -- 2>&1 >/dev/null 31 | ``` 32 | 33 | ## Render Flags 34 | - `--out_dir` (required) -- Parent directory to write rendered images. Instances will be rendered by ID in child subdirectories. 35 | - `--src_model_dir` (required) -- Location of the ShapeNet model directory with all object classes and instances. 36 | - `--object` (default: chair) -- Name of object class to render. 37 | - `--val_frac` (default: 0.2) -- When generating a split of object instances, what fraction of all instances to use as validation. The resulting split is written in the object class directory as `val_split_{n_val}.txt`. 38 | - `--test_frac` (default: 0.2) -- When generating a split of object instances, what fraction of all instances to use as test. The resulting split is written in the object class directory as `test_split_{n_test}.txt`. 39 | - `--split` (choice of `[train, val, test]`) -- Which split to render. `val/test` splits use a specific camera trajectory (Archimedean spiral, from SRN). 40 | - `--n_views` (default: 20) -- Number of views to render per instance. 41 | - `--res` (default: 128) -- Output resolution of images (default 128x128). 42 | - `--start_idx` (default: 0) -- If rendering a subset of the object instances, provide the starting index. 43 | - `--n_objects` (default: 2) -- The number of objects to include per scene. 44 | - `--end_idx` (default: -1) -- If rendering a subset of the object instances, provide the ending index. 45 | - `--use_pbr` -- Whether to use Cycles to render with physically based rendering. Slower, but more photorealistic. 46 | - `--light_env` -- If `--use_pbr`, you can use an HDRI light map. Pass the path of the HDRI here. 47 | - `--light_strength` -- The strength of the light map in the scene, if using an HDRI light map. You can easily get HDRIs from websites like https://hdrihaven.com/. 48 | - `--render_alpha` -- Render the object masks. 49 | - `--render_depth` -- Render the scene depth map. 50 | - `--render_bg` -- Render the scene background (only useful if using PBR + HDRI light maps). 51 | - `--pool` -- Render in parallel. Faster. 52 | 53 | 54 | ## Rendering with Blender EEVEE 55 | 56 | By default, it is only possible to render headless using Blender's PBR engine Cycles. 57 | While Cycles is photorealistic, rendering with Eevee is much faster. 58 | To enable headless rendering using Eevee, you also need the following dependencies: 59 | 60 | ### OpenGL 61 | OpenGL is necessary for Virtual GL. Normally OpenGL can be installed through apt.
62 | ```sudo apt-get install freeglut3-dev mesa-utils``` 63 | 64 | ### Virtual GL 65 | Install VGL with [this tutorial](https://virtualgl.org/vgldoc/2_2_1/#hd004001). 66 | 67 | 68 | ### TurboVNC 69 | Install TurboVNC with [this tutorial](https://cdn.rawgit.com/TurboVNC/turbovnc/2.1.1/doc/index.html#hd005001). 70 | 71 | ### X11 utilities 72 | ``` 73 | sudo apt install x11-xserver-utils libxrandr-dev 74 | ``` 75 | 76 | ### Emulating the Virtual Display 77 | First configure your X server to be compatible with your graphics card. 78 | ``` 79 | sudo nvidia-xconfig -a --use-display-device=None --virtual=1280x1024 80 | ``` 81 | You can also further edit this configuration at `/etc/X11/xorg.conf`. 82 | Now start your X server, labeled with an arbitrary server number, in this case 7 83 | ``` 84 | sudo nohup Xorg :7 & 85 | ``` 86 | Run an auxiliary remote VNC server to create a virtual display. Label it with a separate remote server number, in this case 8. 87 | ``` 88 | /opt/TurboVNC/bin/vncserver :8 89 | ``` 90 | To test, run `glxinfo` on Xserver 7, device 0 (GPU 0 on your machine). 91 | ``` 92 | DISPLAY=:8 vglrun -d :7.0 glxinfo 93 | ``` 94 | If all is well, proceed to run headless rendering with Eevee with 95 | ``` 96 | DISPLAY=:8 vglrun -d :7.0 blender --background -noaudio --python render_shapenet.py -- 97 | ``` 98 | -------------------------------------------------------------------------------- /scripts/detectron2/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /scripts/detectron2/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../../../configs/Base-RCNN-FPN.yaml" 2 | MODEL: 3 | ROI_HEADS: 4 | NAME: "PointRendROIHeads" 5 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 6 | ROI_BOX_HEAD: 7 | TRAIN_ON_PRED_BOXES: True 8 | ROI_MASK_HEAD: 9 | NAME: "CoarseMaskHead" 10 | FC_DIM: 1024 11 | NUM_FC: 2 12 | OUTPUT_SIDE_RESOLUTION: 7 13 | IN_FEATURES: ["p2"] 14 | POINT_HEAD_ON: True 15 | POINT_HEAD: 16 | FC_DIM: 256 17 | NUM_FC: 3 18 | IN_FEATURES: ["p2"] 19 | INPUT: 20 | # PointRend for instance segmenation does not work with "polygon" mask_format. 
21 | MASK_FORMAT: "bitmask" 22 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | ROI_HEADS: 8 | NUM_CLASSES: 8 9 | POINT_HEAD: 10 | NUM_CLASSES: 8 11 | DATASETS: 12 | TEST: ("cityscapes_fine_instance_seg_val",) 13 | TRAIN: ("cityscapes_fine_instance_seg_train",) 14 | SOLVER: 15 | BASE_LR: 0.01 16 | IMS_PER_BATCH: 8 17 | MAX_ITER: 24000 18 | STEPS: (18000,) 19 | INPUT: 20 | MAX_SIZE_TEST: 2048 21 | MAX_SIZE_TRAIN: 2048 22 | MIN_SIZE_TEST: 1024 23 | MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024) 24 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | # To add COCO AP evaluation against the higher-quality LVIS annotations. 8 | # DATASETS: 9 | # TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") 10 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | SOLVER: 8 | STEPS: (210000, 250000) 9 | MAX_ITER: 270000 10 | # To add COCO AP evaluation against the higher-quality LVIS annotations. 
11 | # DATASETS: 12 | # TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") 13 | 14 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../../../configs/Base-RCNN-FPN.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | BACKBONE: 5 | FREEZE_AT: 0 6 | SEM_SEG_HEAD: 7 | NAME: "PointRendSemSegHead" 8 | POINT_HEAD: 9 | NUM_CLASSES: 54 10 | FC_DIM: 256 11 | NUM_FC: 3 12 | IN_FEATURES: ["p2"] 13 | TRAIN_NUM_POINTS: 1024 14 | SUBDIVISION_STEPS: 2 15 | SUBDIVISION_NUM_POINTS: 8192 16 | COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead" 17 | COARSE_PRED_EACH_LAYER: False 18 | DATASETS: 19 | TRAIN: ("coco_2017_train_panoptic_stuffonly",) 20 | TEST: ("coco_2017_val_panoptic_stuffonly",) 21 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-Semantic-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl 4 | RESNETS: 5 | DEPTH: 101 6 | SEM_SEG_HEAD: 7 | NUM_CLASSES: 19 8 | POINT_HEAD: 9 | NUM_CLASSES: 19 10 | TRAIN_NUM_POINTS: 2048 11 | SUBDIVISION_NUM_POINTS: 8192 12 | DATASETS: 13 | TRAIN: ("cityscapes_fine_sem_seg_train",) 14 | TEST: ("cityscapes_fine_sem_seg_val",) 15 | SOLVER: 16 | BASE_LR: 0.01 17 | STEPS: (40000, 55000) 18 | MAX_ITER: 65000 19 | IMS_PER_BATCH: 32 20 | INPUT: 21 | MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048) 22 | MIN_SIZE_TRAIN_SAMPLING: "choice" 23 | MIN_SIZE_TEST: 1024 24 | MAX_SIZE_TRAIN: 4096 25 | MAX_SIZE_TEST: 2048 26 | CROP: 27 | ENABLED: True 28 | TYPE: "absolute" 29 | SIZE: (512, 1024) 30 | SINGLE_CATEGORY_MAX_AREA: 0.75 31 | COLOR_AUG_SSD: True 32 | DATALOADER: 33 | NUM_WORKERS: 10 34 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .config import add_pointrend_config 3 | from .coarse_mask_head import CoarseMaskHead 4 | from .roi_heads import PointRendROIHeads 5 | from .semantic_seg import PointRendSemSegHead 6 | from .color_augmentation import ColorAugSSDTransform 7 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/coarse_mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import Conv2d, ShapeSpec 8 | from detectron2.modeling import ROI_MASK_HEAD_REGISTRY 9 | 10 | 11 | @ROI_MASK_HEAD_REGISTRY.register() 12 | class CoarseMaskHead(nn.Module): 13 | """ 14 | A mask head with fully connected layers. Given pooled features it first reduces channels and 15 | spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously 16 | to the standard box head. 
17 | """ 18 | 19 | def __init__(self, cfg, input_shape: ShapeSpec): 20 | """ 21 | The following attributes are parsed from config: 22 | conv_dim: the output dimension of the conv layers 23 | fc_dim: the feature dimenstion of the FC layers 24 | num_fc: the number of FC layers 25 | output_side_resolution: side resolution of the output square mask prediction 26 | """ 27 | super(CoarseMaskHead, self).__init__() 28 | 29 | # fmt: off 30 | self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 31 | conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM 32 | self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM 33 | num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC 34 | self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION 35 | self.input_channels = input_shape.channels 36 | self.input_h = input_shape.height 37 | self.input_w = input_shape.width 38 | # fmt: on 39 | 40 | self.conv_layers = [] 41 | if self.input_channels > conv_dim: 42 | self.reduce_channel_dim_conv = Conv2d( 43 | self.input_channels, 44 | conv_dim, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=True, 49 | activation=F.relu, 50 | ) 51 | self.conv_layers.append(self.reduce_channel_dim_conv) 52 | 53 | self.reduce_spatial_dim_conv = Conv2d( 54 | conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu 55 | ) 56 | self.conv_layers.append(self.reduce_spatial_dim_conv) 57 | 58 | input_dim = conv_dim * self.input_h * self.input_w 59 | input_dim //= 4 60 | 61 | self.fcs = [] 62 | for k in range(num_fc): 63 | fc = nn.Linear(input_dim, self.fc_dim) 64 | self.add_module("coarse_mask_fc{}".format(k + 1), fc) 65 | self.fcs.append(fc) 66 | input_dim = self.fc_dim 67 | 68 | output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution 69 | 70 | self.prediction = nn.Linear(self.fc_dim, output_dim) 71 | # use normal distribution initialization for mask prediction layer 72 | nn.init.normal_(self.prediction.weight, std=0.001) 73 | nn.init.constant_(self.prediction.bias, 0) 74 | 75 | for layer in self.conv_layers: 76 | weight_init.c2_msra_fill(layer) 77 | for layer in self.fcs: 78 | weight_init.c2_xavier_fill(layer) 79 | 80 | def forward(self, x): 81 | # unlike BaseMaskRCNNHead, this head only outputs intermediate 82 | # features, because the features will be used later by PointHead. 83 | N = x.shape[0] 84 | x = x.view(N, self.input_channels, self.input_h, self.input_w) 85 | for layer in self.conv_layers: 86 | x = layer(x) 87 | x = torch.flatten(x, start_dim=1) 88 | for layer in self.fcs: 89 | x = F.relu(layer(x)) 90 | return self.prediction(x).view( 91 | N, self.num_classes, self.output_side_resolution, self.output_side_resolution 92 | ) 93 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/color_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | import random 4 | import cv2 5 | from fvcore.transforms.transform import Transform 6 | 7 | 8 | class ColorAugSSDTransform(Transform): 9 | """ 10 | A color related data augmentation used in Single Shot Multibox Detector (SSD). 11 | 12 | Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, 13 | Scott Reed, Cheng-Yang Fu, Alexander C. Berg. 14 | SSD: Single Shot MultiBox Detector. ECCV 2016. 
15 | 16 | Implementation based on: 17 | 18 | https://github.com/weiliu89/caffe/blob 19 | /4817bf8b4200b35ada8ed0dc378dceaf38c539e4 20 | /src/caffe/util/im_transforms.cpp 21 | 22 | https://github.com/chainer/chainercv/blob 23 | /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv 24 | /links/model/ssd/transforms.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | img_format, 30 | brightness_delta=32, 31 | contrast_low=0.5, 32 | contrast_high=1.5, 33 | saturation_low=0.5, 34 | saturation_high=1.5, 35 | hue_delta=18, 36 | ): 37 | super().__init__() 38 | assert img_format in ["BGR", "RGB"] 39 | self.is_rgb = img_format == "RGB" 40 | del img_format 41 | self._set_attributes(locals()) 42 | 43 | def apply_coords(self, coords): 44 | return coords 45 | 46 | def apply_segmentation(self, segmentation): 47 | return segmentation 48 | 49 | def apply_image(self, img, interp=None): 50 | if self.is_rgb: 51 | img = img[:, :, [2, 1, 0]] 52 | img = self.brightness(img) 53 | if random.randrange(2): 54 | img = self.contrast(img) 55 | img = self.saturation(img) 56 | img = self.hue(img) 57 | else: 58 | img = self.saturation(img) 59 | img = self.hue(img) 60 | img = self.contrast(img) 61 | if self.is_rgb: 62 | img = img[:, :, [2, 1, 0]] 63 | return img 64 | 65 | def convert(self, img, alpha=1, beta=0): 66 | img = img.astype(np.float32) * alpha + beta 67 | img = np.clip(img, 0, 255) 68 | return img.astype(np.uint8) 69 | 70 | def brightness(self, img): 71 | if random.randrange(2): 72 | return self.convert( 73 | img, beta=random.uniform(-self.brightness_delta, self.brightness_delta) 74 | ) 75 | return img 76 | 77 | def contrast(self, img): 78 | if random.randrange(2): 79 | return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high)) 80 | return img 81 | 82 | def saturation(self, img): 83 | if random.randrange(2): 84 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 85 | img[:, :, 1] = self.convert( 86 | img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high) 87 | ) 88 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 89 | return img 90 | 91 | def hue(self, img): 92 | if random.randrange(2): 93 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 94 | img[:, :, 0] = ( 95 | img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta) 96 | ) % 180 97 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 98 | return img 99 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_pointrend_config(cfg): 8 | """ 9 | Add config for PointRend. 10 | """ 11 | # We retry random cropping until no single category in semantic segmentation GT occupies more 12 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 13 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 14 | # Color augmentatition from SSD paper for semantic segmentation model during training. 15 | cfg.INPUT.COLOR_AUG_SSD = False 16 | 17 | # Names of the input feature maps to be used by a coarse mask head. 18 | cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",) 19 | cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024 20 | cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2 21 | # The side size of a coarse mask head prediction. 22 | cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7 23 | # True if point head is used. 
24 | cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False 25 | 26 | cfg.MODEL.POINT_HEAD = CN() 27 | cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead" 28 | cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80 29 | # Names of the input feature maps to be used by a mask point head. 30 | cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",) 31 | # Number of points sampled during training for a mask point head. 32 | cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14 33 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the 34 | # original paper. 35 | cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3 36 | # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in 37 | # the original paper. 38 | cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75 39 | # Number of subdivision steps during inference. 40 | cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5 41 | # Maximum number of points selected at each subdivision step (N). 42 | cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28 43 | cfg.MODEL.POINT_HEAD.FC_DIM = 256 44 | cfg.MODEL.POINT_HEAD.NUM_FC = 3 45 | cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False 46 | # If True, then coarse prediction features are used as inout for each layer in PointRend's MLP. 47 | cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True 48 | cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead" 49 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/point_features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.layers import cat 6 | from detectron2.structures import Boxes 7 | 8 | 9 | """ 10 | Shape shorthand in this module: 11 | 12 | N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the 13 | number of images for semantic segmenation. 14 | R: number of ROIs, combined over all images, in the minibatch 15 | P: number of points 16 | """ 17 | 18 | 19 | def point_sample(input, point_coords, **kwargs): 20 | """ 21 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 22 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 23 | [0, 1] x [0, 1] square. 24 | 25 | Args: 26 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 27 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 28 | [0, 1] x [0, 1] normalized point coordinates. 29 | 30 | Returns: 31 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 32 | features for points in `point_coords`. The features are obtained via bilinear 33 | interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 34 | """ 35 | add_dim = False 36 | if point_coords.dim() == 3: 37 | add_dim = True 38 | point_coords = point_coords.unsqueeze(2) 39 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 40 | if add_dim: 41 | output = output.squeeze(3) 42 | return output 43 | 44 | 45 | def generate_regular_grid_point_coords(R, side_size, device): 46 | """ 47 | Generate regular square grid of points in [0, 1] x [0, 1] coordinate space. 48 | 49 | Args: 50 | R (int): The number of grids to sample, one for each region. 51 | side_size (int): The side size of the regular grid. 
52 | device (torch.device): Desired device of returned tensor. 53 | 54 | Returns: 55 | (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates 56 | for the regular grids. 57 | """ 58 | aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device) 59 | r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False) 60 | return r.view(1, -1, 2).expand(R, -1, -1) 61 | 62 | 63 | def get_uncertain_point_coords_with_randomness( 64 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 65 | ): 66 | """ 67 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties 68 | are calculated for each point using 'uncertainty_func' function that takes point's logit 69 | prediction as input. 70 | See PointRend paper for details. 71 | 72 | Args: 73 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 74 | class-specific or class-agnostic prediction. 75 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 76 | contains logit predictions for P points and returns their uncertainties as a Tensor of 77 | shape (N, 1, P). 78 | num_points (int): The number of points P to sample. 79 | oversample_ratio (int): Oversampling parameter. 80 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 81 | 82 | Returns: 83 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 84 | sampled points. 85 | """ 86 | assert oversample_ratio >= 1 87 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 88 | num_boxes = coarse_logits.shape[0] 89 | num_sampled = int(num_points * oversample_ratio) 90 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 91 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 92 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 93 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 94 | # to incorrect results. 95 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 96 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 97 | # However, if we calculate uncertainties for the coarse predictions first, 98 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 99 | point_uncertainties = uncertainty_func(point_logits) 100 | num_uncertain_points = int(importance_sample_ratio * num_points) 101 | num_random_points = num_points - num_uncertain_points 102 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 103 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 104 | idx += shift[:, None] 105 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 106 | num_boxes, num_uncertain_points, 2 107 | ) 108 | if num_random_points > 0: 109 | point_coords = cat( 110 | [ 111 | point_coords, 112 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 113 | ], 114 | dim=1, 115 | ) 116 | return point_coords 117 | 118 | 119 | def get_uncertain_point_coords_on_grid(uncertainty_map, num_points): 120 | """ 121 | Find `num_points` most uncertain points from `uncertainty_map` grid. 
122 | 123 | Args: 124 | uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty 125 | values for a set of points on a regular H x W grid. 126 | num_points (int): The number of points P to select. 127 | 128 | Returns: 129 | point_indices (Tensor): A tensor of shape (N, P) that contains indices from 130 | [0, H x W) of the most uncertain points. 131 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized 132 | coordinates of the most uncertain points from the H x W grid. 133 | """ 134 | R, _, H, W = uncertainty_map.shape 135 | h_step = 1.0 / float(H) 136 | w_step = 1.0 / float(W) 137 | 138 | num_points = min(H * W, num_points) 139 | point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1] 140 | point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device) 141 | point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step 142 | point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step 143 | return point_indices, point_coords 144 | 145 | 146 | def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords): 147 | """ 148 | Get features from feature maps in `features_list` that correspond to specific point coordinates 149 | inside each bounding box from `boxes`. 150 | 151 | Args: 152 | features_list (list[Tensor]): A list of feature map tensors to get features from. 153 | feature_scales (list[float]): A list of scales for tensors in `features_list`. 154 | boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all 155 | together. 156 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 157 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 158 | 159 | Returns: 160 | point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled 161 | from all features maps in feature_list for P sampled points for all R boxes in `boxes`. 162 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level 163 | coordinates of P points. 164 | """ 165 | cat_boxes = Boxes.cat(boxes) 166 | num_boxes = [len(b) for b in boxes] 167 | 168 | point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords) 169 | split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes) 170 | 171 | point_features = [] 172 | for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image): 173 | point_features_per_image = [] 174 | for idx_feature, feature_map in enumerate(features_list): 175 | h, w = feature_map.shape[-2:] 176 | scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature] 177 | point_coords_scaled = point_coords_wrt_image_per_image / scale 178 | point_features_per_image.append( 179 | point_sample( 180 | feature_map[idx_img].unsqueeze(0), 181 | point_coords_scaled.unsqueeze(0), 182 | align_corners=False, 183 | ) 184 | .squeeze(0) 185 | .transpose(1, 0) 186 | ) 187 | point_features.append(cat(point_features_per_image, dim=1)) 188 | 189 | return cat(point_features, dim=0), point_coords_wrt_image 190 | 191 | 192 | def get_point_coords_wrt_image(boxes_coords, point_coords): 193 | """ 194 | Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates. 195 | 196 | Args: 197 | boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes. 198 | coordinates. 
199 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 200 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 201 | 202 | Returns: 203 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains 204 | image-normalized coordinates of P sampled points. 205 | """ 206 | with torch.no_grad(): 207 | point_coords_wrt_image = point_coords.clone() 208 | point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * ( 209 | boxes_coords[:, None, 2] - boxes_coords[:, None, 0] 210 | ) 211 | point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * ( 212 | boxes_coords[:, None, 3] - boxes_coords[:, None, 1] 213 | ) 214 | point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0] 215 | point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1] 216 | return point_coords_wrt_image 217 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/point_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import ShapeSpec, cat 8 | from detectron2.structures import BitMasks 9 | from detectron2.utils.events import get_event_storage 10 | from detectron2.utils.registry import Registry 11 | 12 | from .point_features import point_sample 13 | 14 | POINT_HEAD_REGISTRY = Registry("POINT_HEAD") 15 | POINT_HEAD_REGISTRY.__doc__ = """ 16 | Registry for point heads, which makes prediction for a given set of per-point features. 17 | 18 | The registered object will be called with `obj(cfg, input_shape)`. 19 | """ 20 | 21 | 22 | def roi_mask_point_loss(mask_logits, instances, points_coord): 23 | """ 24 | Compute the point-based loss for instance segmentation mask predictions. 25 | 26 | Args: 27 | mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or 28 | class-agnostic, where R is the total number of predicted masks in all images, C is the 29 | number of foreground classes, and P is the number of points sampled for each mask. 30 | The values are logits. 31 | instances (list[Instances]): A list of N Instances, where N is the number of images 32 | in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, i_th 33 | elememt of the list contains R_i objects and R_1 + ... + R_N is equal to R. 34 | The ground-truth labels (class, box, mask, ...) associated with each instance are stored 35 | in fields. 36 | points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of 37 | predicted masks and P is the number of points for each mask. The coordinates are in 38 | the image pixel coordinate space, i.e. [0, H] x [0, W]. 39 | Returns: 40 | point_loss (Tensor): A scalar tensor containing the loss. 41 | """ 42 | with torch.no_grad(): 43 | cls_agnostic_mask = mask_logits.size(1) == 1 44 | total_num_masks = mask_logits.size(0) 45 | 46 | gt_classes = [] 47 | gt_mask_logits = [] 48 | idx = 0 49 | for instances_per_image in instances: 50 | if len(instances_per_image) == 0: 51 | continue 52 | assert isinstance( 53 | instances_per_image.gt_masks, BitMasks 54 | ), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'." 
55 | 56 | if not cls_agnostic_mask: 57 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 58 | gt_classes.append(gt_classes_per_image) 59 | 60 | gt_bit_masks = instances_per_image.gt_masks.tensor 61 | h, w = instances_per_image.gt_masks.image_size 62 | scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device) 63 | points_coord_grid_sample_format = ( 64 | points_coord[idx : idx + len(instances_per_image)] / scale 65 | ) 66 | idx += len(instances_per_image) 67 | gt_mask_logits.append( 68 | point_sample( 69 | gt_bit_masks.to(torch.float32).unsqueeze(1), 70 | points_coord_grid_sample_format, 71 | align_corners=False, 72 | ).squeeze(1) 73 | ) 74 | 75 | if len(gt_mask_logits) == 0: 76 | return mask_logits.sum() * 0 77 | 78 | gt_mask_logits = cat(gt_mask_logits) 79 | assert gt_mask_logits.numel() > 0, gt_mask_logits.shape 80 | 81 | if cls_agnostic_mask: 82 | mask_logits = mask_logits[:, 0] 83 | else: 84 | indices = torch.arange(total_num_masks) 85 | gt_classes = cat(gt_classes, dim=0) 86 | mask_logits = mask_logits[indices, gt_classes] 87 | 88 | # Log the training accuracy (using gt classes and 0.0 threshold for the logits) 89 | mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8) 90 | mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel() 91 | get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy) 92 | 93 | point_loss = F.binary_cross_entropy_with_logits( 94 | mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean" 95 | ) 96 | return point_loss 97 | 98 | 99 | @POINT_HEAD_REGISTRY.register() 100 | class StandardPointHead(nn.Module): 101 | """ 102 | A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head 103 | takes both fine-grained and coarse prediction features as its input. 
104 | """ 105 | 106 | def __init__(self, cfg, input_shape: ShapeSpec): 107 | """ 108 | The following attributes are parsed from config: 109 | fc_dim: the output dimension of each FC layers 110 | num_fc: the number of FC layers 111 | coarse_pred_each_layer: if True, coarse prediction features are concatenated to each 112 | layer's input 113 | """ 114 | super(StandardPointHead, self).__init__() 115 | # fmt: off 116 | num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES 117 | fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM 118 | num_fc = cfg.MODEL.POINT_HEAD.NUM_FC 119 | cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK 120 | self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER 121 | input_channels = input_shape.channels 122 | # fmt: on 123 | 124 | fc_dim_in = input_channels + num_classes 125 | self.fc_layers = [] 126 | for k in range(num_fc): 127 | fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True) 128 | self.add_module("fc{}".format(k + 1), fc) 129 | self.fc_layers.append(fc) 130 | fc_dim_in = fc_dim 131 | fc_dim_in += num_classes if self.coarse_pred_each_layer else 0 132 | 133 | num_mask_classes = 1 if cls_agnostic_mask else num_classes 134 | self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0) 135 | 136 | for layer in self.fc_layers: 137 | weight_init.c2_msra_fill(layer) 138 | # use normal distribution initialization for mask prediction layer 139 | nn.init.normal_(self.predictor.weight, std=0.001) 140 | if self.predictor.bias is not None: 141 | nn.init.constant_(self.predictor.bias, 0) 142 | 143 | def forward(self, fine_grained_features, coarse_features): 144 | x = torch.cat((fine_grained_features, coarse_features), dim=1) 145 | for layer in self.fc_layers: 146 | x = F.relu(layer(x)) 147 | if self.coarse_pred_each_layer: 148 | x = cat((x, coarse_features), dim=1) 149 | return self.predictor(x) 150 | 151 | 152 | def build_point_head(cfg, input_channels): 153 | """ 154 | Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`. 155 | """ 156 | head_name = cfg.MODEL.POINT_HEAD.NAME 157 | return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels) 158 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/roi_heads.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import numpy as np 4 | import torch 5 | 6 | from detectron2.layers import ShapeSpec, cat, interpolate 7 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 8 | from detectron2.modeling.roi_heads.mask_head import ( 9 | build_mask_head, 10 | mask_rcnn_inference, 11 | mask_rcnn_loss, 12 | ) 13 | from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals 14 | 15 | from .point_features import ( 16 | generate_regular_grid_point_coords, 17 | get_uncertain_point_coords_on_grid, 18 | get_uncertain_point_coords_with_randomness, 19 | point_sample, 20 | point_sample_fine_grained_features, 21 | ) 22 | from .point_head import build_point_head, roi_mask_point_loss 23 | 24 | 25 | def calculate_uncertainty(logits, classes): 26 | """ 27 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 28 | foreground class in `classes`. 29 | 30 | Args: 31 | logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) 
for class-specific or 32 | class-agnostic, where R is the total number of predicted masks in all images and C is 33 | the number of foreground classes. The values are logits. 34 | classes (list): A list of length R that contains either predicted of ground truth class 35 | for eash predicted mask. 36 | 37 | Returns: 38 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 39 | the most uncertain locations having the highest uncertainty score. 40 | """ 41 | if logits.shape[1] == 1: 42 | gt_class_logits = logits.clone() 43 | else: 44 | gt_class_logits = logits[ 45 | torch.arange(logits.shape[0], device=logits.device), classes 46 | ].unsqueeze(1) 47 | return -(torch.abs(gt_class_logits)) 48 | 49 | 50 | @ROI_HEADS_REGISTRY.register() 51 | class PointRendROIHeads(StandardROIHeads): 52 | """ 53 | The RoI heads class for PointRend instance segmentation models. 54 | 55 | In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact. 56 | To avoid namespace conflict with other heads we use names starting from `mask_` for all 57 | variables that correspond to the mask head in the class's namespace. 58 | """ 59 | 60 | def __init__(self, cfg, input_shape): 61 | # TODO use explicit args style 62 | super().__init__(cfg, input_shape) 63 | self._init_mask_head(cfg, input_shape) 64 | 65 | def _init_mask_head(self, cfg, input_shape): 66 | # fmt: off 67 | self.mask_on = cfg.MODEL.MASK_ON 68 | if not self.mask_on: 69 | return 70 | self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES 71 | self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 72 | self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()} 73 | # fmt: on 74 | 75 | in_channels = np.sum([input_shape[f].channels for f in self.mask_coarse_in_features]) 76 | self.mask_coarse_head = build_mask_head( 77 | cfg, 78 | ShapeSpec( 79 | channels=in_channels, 80 | width=self.mask_coarse_side_size, 81 | height=self.mask_coarse_side_size, 82 | ), 83 | ) 84 | self._init_point_head(cfg, input_shape) 85 | 86 | def _init_point_head(self, cfg, input_shape): 87 | # fmt: off 88 | self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON 89 | if not self.mask_point_on: 90 | return 91 | assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES 92 | self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES 93 | self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS 94 | self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO 95 | self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO 96 | # next two parameters are use in the adaptive subdivions inference procedure 97 | self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS 98 | self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS 99 | # fmt: on 100 | 101 | in_channels = np.sum([input_shape[f].channels for f in self.mask_point_in_features]) 102 | self.mask_point_head = build_point_head( 103 | cfg, ShapeSpec(channels=in_channels, width=1, height=1) 104 | ) 105 | 106 | def _forward_mask(self, features, instances): 107 | """ 108 | Forward logic of the mask prediction branch. 109 | 110 | Args: 111 | features (dict[str, Tensor]): #level input features for mask prediction 112 | instances (list[Instances]): the per-image instances to train/predict masks. 113 | In training, they can be the proposals. 114 | In inference, they can be the predicted boxes. 
115 | 116 | Returns: 117 | In training, a dict of losses. 118 | In inference, update `instances` with new fields "pred_masks" and return it. 119 | """ 120 | if not self.mask_on: 121 | return {} if self.training else instances 122 | 123 | if self.training: 124 | proposals, _ = select_foreground_proposals(instances, self.num_classes) 125 | proposal_boxes = [x.proposal_boxes for x in proposals] 126 | mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes) 127 | 128 | losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)} 129 | losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals)) 130 | return losses 131 | else: 132 | pred_boxes = [x.pred_boxes for x in instances] 133 | mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes) 134 | 135 | mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances) 136 | mask_rcnn_inference(mask_logits, instances) 137 | return instances 138 | 139 | def _forward_mask_coarse(self, features, boxes): 140 | """ 141 | Forward logic of the coarse mask head. 142 | """ 143 | point_coords = generate_regular_grid_point_coords( 144 | np.sum(len(x) for x in boxes), self.mask_coarse_side_size, boxes[0].device 145 | ) 146 | mask_coarse_features_list = [features[k] for k in self.mask_coarse_in_features] 147 | features_scales = [self._feature_scales[k] for k in self.mask_coarse_in_features] 148 | # For regular grids of points, this function is equivalent to `len(features_list)' calls 149 | # of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results. 150 | mask_features, _ = point_sample_fine_grained_features( 151 | mask_coarse_features_list, features_scales, boxes, point_coords 152 | ) 153 | return self.mask_coarse_head(mask_features) 154 | 155 | def _forward_mask_point(self, features, mask_coarse_logits, instances): 156 | """ 157 | Forward logic of the mask point head. 
158 | """ 159 | if not self.mask_point_on: 160 | return {} if self.training else mask_coarse_logits 161 | 162 | mask_features_list = [features[k] for k in self.mask_point_in_features] 163 | features_scales = [self._feature_scales[k] for k in self.mask_point_in_features] 164 | 165 | if self.training: 166 | proposal_boxes = [x.proposal_boxes for x in instances] 167 | gt_classes = cat([x.gt_classes for x in instances]) 168 | with torch.no_grad(): 169 | point_coords = get_uncertain_point_coords_with_randomness( 170 | mask_coarse_logits, 171 | lambda logits: calculate_uncertainty(logits, gt_classes), 172 | self.mask_point_train_num_points, 173 | self.mask_point_oversample_ratio, 174 | self.mask_point_importance_sample_ratio, 175 | ) 176 | 177 | fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features( 178 | mask_features_list, features_scales, proposal_boxes, point_coords 179 | ) 180 | coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False) 181 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 182 | return { 183 | "loss_mask_point": roi_mask_point_loss( 184 | point_logits, instances, point_coords_wrt_image 185 | ) 186 | } 187 | else: 188 | pred_boxes = [x.pred_boxes for x in instances] 189 | pred_classes = cat([x.pred_classes for x in instances]) 190 | # The subdivision code will fail with the empty list of boxes 191 | if len(pred_classes) == 0: 192 | return mask_coarse_logits 193 | 194 | mask_logits = mask_coarse_logits.clone() 195 | for subdivions_step in range(self.mask_point_subdivision_steps): 196 | mask_logits = interpolate( 197 | mask_logits, scale_factor=2, mode="bilinear", align_corners=False 198 | ) 199 | # If `mask_point_subdivision_num_points` is larger or equal to the 200 | # resolution of the next step, then we can skip this step 201 | H, W = mask_logits.shape[-2:] 202 | if ( 203 | self.mask_point_subdivision_num_points >= 4 * H * W 204 | and subdivions_step < self.mask_point_subdivision_steps - 1 205 | ): 206 | continue 207 | uncertainty_map = calculate_uncertainty(mask_logits, pred_classes) 208 | point_indices, point_coords = get_uncertain_point_coords_on_grid( 209 | uncertainty_map, self.mask_point_subdivision_num_points 210 | ) 211 | fine_grained_features, _ = point_sample_fine_grained_features( 212 | mask_features_list, features_scales, pred_boxes, point_coords 213 | ) 214 | coarse_features = point_sample( 215 | mask_coarse_logits, point_coords, align_corners=False 216 | ) 217 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 218 | 219 | # put mask point predictions to the right places on the upsampled grid. 220 | R, C, H, W = mask_logits.shape 221 | point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) 222 | mask_logits = ( 223 | mask_logits.reshape(R, C, H * W) 224 | .scatter_(2, point_indices, point_logits) 225 | .view(R, C, H, W) 226 | ) 227 | return mask_logits 228 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/semantic_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import numpy as np 3 | from typing import Dict 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from detectron2.layers import ShapeSpec, cat 9 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 10 | 11 | from .point_features import ( 12 | get_uncertain_point_coords_on_grid, 13 | get_uncertain_point_coords_with_randomness, 14 | point_sample, 15 | ) 16 | from .point_head import build_point_head 17 | 18 | 19 | def calculate_uncertainty(sem_seg_logits): 20 | """ 21 | For each location of the prediction `sem_seg_logits` we estimate uncerainty as the 22 | difference between top first and top second predicted logits. 23 | 24 | Args: 25 | mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and 26 | C is the number of foreground classes. The values are logits. 27 | 28 | Returns: 29 | scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with 30 | the most uncertain locations having the highest uncertainty score. 31 | """ 32 | top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0] 33 | return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) 34 | 35 | 36 | @SEM_SEG_HEADS_REGISTRY.register() 37 | class PointRendSemSegHead(nn.Module): 38 | """ 39 | A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME` 40 | and a point head set in `MODEL.POINT_HEAD.NAME`. 41 | """ 42 | 43 | def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): 44 | super().__init__() 45 | 46 | self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE 47 | 48 | self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get( 49 | cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME 50 | )(cfg, input_shape) 51 | self._init_point_head(cfg, input_shape) 52 | 53 | def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]): 54 | # fmt: off 55 | assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES 56 | feature_channels = {k: v.channels for k, v in input_shape.items()} 57 | self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES 58 | self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS 59 | self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO 60 | self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO 61 | self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS 62 | self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS 63 | # fmt: on 64 | 65 | in_channels = np.sum([feature_channels[f] for f in self.in_features]) 66 | self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1)) 67 | 68 | def forward(self, features, targets=None): 69 | coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features) 70 | 71 | if self.training: 72 | losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets) 73 | 74 | with torch.no_grad(): 75 | point_coords = get_uncertain_point_coords_with_randomness( 76 | coarse_sem_seg_logits, 77 | calculate_uncertainty, 78 | self.train_num_points, 79 | self.oversample_ratio, 80 | self.importance_sample_ratio, 81 | ) 82 | coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False) 83 | 84 | fine_grained_features = cat( 85 | [ 86 | point_sample(features[in_feature], point_coords, align_corners=False) 87 | for in_feature in self.in_features 88 | ], 89 | dim=1, 90 | ) 91 | point_logits = self.point_head(fine_grained_features, coarse_features) 92 | point_targets = ( 93 | point_sample( 94 | 
targets.unsqueeze(1).to(torch.float), 95 | point_coords, 96 | mode="nearest", 97 | align_corners=False, 98 | ) 99 | .squeeze(1) 100 | .to(torch.long) 101 | ) 102 | losses["loss_sem_seg_point"] = F.cross_entropy( 103 | point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value 104 | ) 105 | return None, losses 106 | else: 107 | sem_seg_logits = coarse_sem_seg_logits.clone() 108 | for _ in range(self.subdivision_steps): 109 | sem_seg_logits = F.interpolate( 110 | sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False 111 | ) 112 | uncertainty_map = calculate_uncertainty(sem_seg_logits) 113 | point_indices, point_coords = get_uncertain_point_coords_on_grid( 114 | uncertainty_map, self.subdivision_num_points 115 | ) 116 | fine_grained_features = cat( 117 | [ 118 | point_sample(features[in_feature], point_coords, align_corners=False) 119 | for in_feature in self.in_features 120 | ] 121 | ) 122 | coarse_features = point_sample( 123 | coarse_sem_seg_logits, point_coords, align_corners=False 124 | ) 125 | point_logits = self.point_head(fine_grained_features, coarse_features) 126 | 127 | # put sem seg point predictions to the right places on the upsampled grid. 128 | N, C, H, W = sem_seg_logits.shape 129 | point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) 130 | sem_seg_logits = ( 131 | sem_seg_logits.reshape(N, C, H * W) 132 | .scatter_(2, point_indices, point_logits) 133 | .view(N, C, H, W) 134 | ) 135 | return sem_seg_logits, {} 136 | -------------------------------------------------------------------------------- /scripts/preproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | PointRend background removal + normalization for car images 3 | (c) Alex Yu 2020 4 | Usage: python [-S scale=4.37] [-s size=128] 5 | outputs to *_mask.png, then *_mask_*.png (for other instances). 
6 | also writes _crop.txt 7 | """ 8 | import sys 9 | import argparse 10 | import os 11 | import os.path as osp 12 | import json 13 | from math import floor, ceil 14 | 15 | ROOT_PATH = osp.dirname(os.path.abspath(__file__)) 16 | POINTREND_ROOT_PATH = osp.join(ROOT_PATH, "detectron2", "projects", "PointRend") 17 | INPUT_DIR = osp.join(ROOT_PATH, "..", "input") 18 | 19 | if not os.path.exists(POINTREND_ROOT_PATH): 20 | import urllib.request, zipfile 21 | 22 | print("Downloading minimal PointRend source package") 23 | zipfile_name = "pointrend_min.zip" 24 | urllib.request.urlretrieve( 25 | "https://alexyu.net/data/pointrend_min.zip", zipfile_name 26 | ) 27 | with zipfile.ZipFile(zipfile_name) as zipfile: 28 | zipfile.extractall(ROOT_PATH) 29 | os.remove(zipfile_name) 30 | 31 | sys.path.insert(0, POINTREND_ROOT_PATH) 32 | 33 | try: 34 | import detectron2 35 | except: 36 | print( 37 | "Please install Detectron2 by selecting the right version", 38 | "from https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md", 39 | ) 40 | # import PointRend project 41 | import point_rend 42 | 43 | from detectron2.utils.logger import setup_logger 44 | 45 | setup_logger() 46 | 47 | # import some common libraries 48 | import numpy as np 49 | import cv2 50 | import torch 51 | import tqdm 52 | import glob 53 | 54 | from matplotlib import pyplot as plt 55 | import matplotlib.patches as patches 56 | 57 | # import some common detectron2 utilities 58 | from detectron2 import model_zoo 59 | from detectron2.engine import DefaultPredictor 60 | from detectron2.config import get_cfg 61 | from detectron2.utils.visualizer import Visualizer, ColorMode 62 | from detectron2.data import MetadataCatalog 63 | 64 | 65 | def _crop_image(img, rect, const_border=False, value=0): 66 | """ 67 | Image cropping helper 68 | """ 69 | x, y, w, h = rect 70 | 71 | left = abs(x) if x < 0 else 0 72 | top = abs(y) if y < 0 else 0 73 | right = abs(img.shape[1] - (x + w)) if x + w >= img.shape[1] else 0 74 | bottom = abs(img.shape[0] - (y + h)) if y + h >= img.shape[0] else 0 75 | 76 | color = [value] * img.shape[2] if const_border else None 77 | new_img = cv2.copyMakeBorder( 78 | img, 79 | top, 80 | bottom, 81 | left, 82 | right, 83 | cv2.BORDER_CONSTANT if const_border else cv2.BORDER_REPLICATE, 84 | value=color, 85 | ) 86 | if len(new_img.shape) == 2: 87 | new_img = new_img[..., None] 88 | 89 | x = x + left 90 | y = y + top 91 | 92 | return new_img[y : (y + h), x : (x + w), :] 93 | 94 | 95 | def _is_image_path(f): 96 | return ( 97 | f.endswith(".jpg") 98 | or f.endswith(".jpeg") 99 | or f.endswith(".png") 100 | or f.endswith(".bmp") 101 | or f.endswith(".tiff") 102 | or f.endswith(".gif") 103 | ) 104 | 105 | 106 | class PointRendWrapper: 107 | def __init__(self, filter_class=-1): 108 | """ 109 | :param filter_class output only intances of filter_class (-1 to disable). Note: class 0 is person. 
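For example, PointRendWrapper(filter_class=2) keeps only COCO class 2 ('car') detections, which is the --coco_class default used by the command-line entry point below.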
110 | """ 111 | self.filter_class = filter_class 112 | self.coco_metadata = MetadataCatalog.get("coco_2017_val") 113 | self.cfg = get_cfg() 114 | 115 | # Add PointRend-specific config 116 | 117 | point_rend.add_pointrend_config(self.cfg) 118 | 119 | # Load a config from file 120 | self.cfg.merge_from_file( 121 | os.path.join( 122 | POINTREND_ROOT_PATH, 123 | "configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml", 124 | ) 125 | ) 126 | self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model 127 | # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models 128 | self.cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl" 129 | self.predictor = DefaultPredictor(self.cfg) 130 | 131 | def segment(self, im, out_name="", visualize=False): 132 | """ 133 | Run PointRend 134 | :param out_name if set, writes segments B&W mask to this image file 135 | :param visualize if set, and out_name is set, outputs visualization rater than B&W mask 136 | """ 137 | outputs = self.predictor(im) 138 | 139 | insts = outputs["instances"] 140 | if self.filter_class != -1: 141 | insts = insts[insts.pred_classes == self.filter_class] # 0 is person 142 | if visualize: 143 | v = Visualizer( 144 | im[:, :, ::-1], 145 | self.coco_metadata, 146 | scale=1.2, 147 | instance_mode=ColorMode.IMAGE_BW, 148 | ) 149 | 150 | point_rend_result = v.draw_instance_predictions(insts.to("cpu")).get_image() 151 | if out_name: 152 | cv2.imwrite(out_name + ".png", point_rend_result[:, :, ::-1]) 153 | return point_rend_result[:, :, ::-1] 154 | else: 155 | im_names = [] 156 | masks = [] 157 | for i in range(len(insts)): 158 | mask = insts[i].pred_masks.to("cpu").permute( 159 | 1, 2, 0 160 | ).numpy() * np.uint8(255) 161 | if out_name: 162 | im_name = out_name 163 | if i: 164 | im_name += "_" + str(i) + ".png" 165 | else: 166 | im_name += ".png" 167 | im_names.append(im_name) 168 | cv2.imwrite(im_name, mask) 169 | masks.append(mask) 170 | if out_name: 171 | with open(out_name + ".json", "w") as fp: 172 | json.dump({"files": im_names}, fp) 173 | return masks 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument( 179 | "--coco_class", 180 | type=int, 181 | default=2, 182 | help="COCO class wanted (0 = human, 2 = car)", 183 | ) 184 | parser.add_argument( 185 | "--size", 186 | "-s", 187 | type=int, 188 | default=128, 189 | help="output image side length (will be square)", 190 | ) 191 | parser.add_argument( 192 | "--scale", 193 | "-S", 194 | type=float, 195 | default=4.37, 196 | help="bbox scaling rel minor axis of fitted ellipse. " 197 | + "Will take max radius from this and major_scale.", 198 | ) 199 | parser.add_argument( 200 | "--major_scale", 201 | "-M", 202 | type=float, 203 | default=0.8, 204 | help="bbox scaling rel major axis of fitted ellipse. 
" 205 | + "Will take max radius from this and major_scale.", 206 | ) 207 | parser.add_argument( 208 | "--const_border", 209 | action="store_true", 210 | help="constant white border instead of replicate pad", 211 | ) 212 | args = parser.parse_args() 213 | 214 | pointrend = PointRendWrapper(args.coco_class) 215 | 216 | input_images = glob.glob(os.path.join(INPUT_DIR, "*")) 217 | input_images = [ 218 | f 219 | for f in input_images 220 | if _is_image_path(f) and not f.endswith("_normalize.png") 221 | ] 222 | 223 | os.makedirs(INPUT_DIR, exist_ok=True) 224 | 225 | for image_path in tqdm.tqdm(input_images): 226 | print(image_path) 227 | im = cv2.imread(image_path) 228 | img_no_ext = os.path.split(os.path.splitext(image_path)[0])[1] 229 | masks = pointrend.segment(im) 230 | if len(masks) == 0: 231 | print("WARNING: PointRend detected no objects in", image_path, "skipping") 232 | continue 233 | mask_main = masks[0] 234 | assert mask_main.shape[:2] == im.shape[:2] 235 | assert mask_main.shape[-1] == 1 236 | assert mask_main.dtype == "uint8" 237 | 238 | # mask is (H, W, 1) with values{0, 255} 239 | 240 | cnt, _ = cv2.findContours(mask_main, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 241 | ellipse = cv2.fitEllipse(cnt[0]) 242 | cen_pt = ellipse[0] 243 | min_ax, max_ax = min(ellipse[1]), max(ellipse[1]) 244 | 245 | # imgvis = np.zeros((*im.shape[:2], 3), dtype=np.uint8) 246 | # cv2.drawContours(imgvis, cnt, -1, (0,255,0), 3) 247 | # imgvis = cv2.ellipse(imgvis, ellipse, (255,0,0),2) 248 | # cv2.imwrite('vs.png', imgvis) 249 | # print(len(cnt), cnt[0].shape) 250 | # print(cen_pt, min_ax) 251 | 252 | # rows = np.any(mask_main, axis=1) 253 | # cols = np.any(mask_main, axis=0) 254 | # rnz = np.where(rows)[0] 255 | # cnz = np.where(cols)[0] 256 | # if len(rnz) == 0: 257 | # cmin = rmin = 0 258 | # cmax = mask_main.shape[-1] 259 | # rmax = mask_main.shape[-2] 260 | # else: 261 | # rmin, rmax = rnz[[0, -1]] 262 | # cmin, cmax = cnz[[0, -1]] 263 | # rcen = int(round((rmin + rmax) * 0.5)) 264 | # ccen = int(round((cmin + cmax) * 0.5)) 265 | # rad = int(ceil(min(cmax - cmin, rmax - rmin) * args.scale * 0.5)) 266 | 267 | ccen, rcen = map(int, map(round, cen_pt)) 268 | rad = max(min_ax * args.scale, max_ax * args.major_scale) * 0.5 269 | rad = int(ceil(rad)) 270 | rect_main = [ccen - rad, rcen - rad, 2 * rad, 2 * rad] 271 | 272 | im_crop = _crop_image(im, rect_main, args.const_border, value=255) 273 | mask_crop = _crop_image(mask_main, rect_main, True, value=0) 274 | 275 | mask_flt = mask_crop.astype(np.float32) / 255.0 276 | masked_crop = im_crop.astype(np.float32) * mask_flt + 255 * (1.0 - mask_flt) 277 | masked_crop = masked_crop.astype(np.uint8) 278 | 279 | # im_crop = cv2.resize(im_crop, (args.size, args.size), interpolation=cv2.INTER_LINEAR) 280 | mask_crop = cv2.resize( 281 | mask_crop, (args.size, args.size), interpolation=cv2.INTER_LINEAR 282 | ) 283 | masked_crop = cv2.resize( 284 | masked_crop, (args.size, args.size), interpolation=cv2.INTER_LINEAR 285 | ) 286 | 287 | if len(mask_crop.nonzero()[0]) == 0: 288 | print("WARNING: cropped mask is empty for", image_path, "skipping") 289 | continue 290 | 291 | # out_im_path = os.path.join(INPUT_DIR, 292 | # img_no_ext + ".jpg") 293 | # out_mask_path = os.path.join(INPUT_DIR, 294 | # img_no_ext + "_mask.png") 295 | out_masked_path = os.path.join(INPUT_DIR, img_no_ext + "_normalize.png") 296 | # cv2.imwrite(out_im_path, im_crop) 297 | # cv2.imwrite(out_mask_path, mask_crop) 298 | cv2.imwrite(out_masked_path, masked_crop) 299 | 300 | # 
np.savetxt(os.path.join(INPUT_DIR, 301 | # img_no_ext + "_crop.txt"), 302 | # rect_main, 303 | # fmt='%.18f') 304 | -------------------------------------------------------------------------------- /src/data/DVRDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import glob 5 | import imageio 6 | import numpy as np 7 | import cv2 8 | from util import get_image_to_tensor_balanced, get_mask_to_tensor 9 | 10 | 11 | class DVRDataset(torch.utils.data.Dataset): 12 | """ 13 | Dataset from DVR (Niemeyer et al. 2020) 14 | Provides 3D-R2N2 and NMR renderings 15 | """ 16 | 17 | def __init__( 18 | self, 19 | path, 20 | stage="train", 21 | list_prefix="softras_", 22 | image_size=None, 23 | sub_format="shapenet", 24 | scale_focal=True, 25 | max_imgs=100000, 26 | z_near=1.2, 27 | z_far=4.0, 28 | skip_step=None, 29 | ): 30 | """ 31 | :param path dataset root path, contains metadata.yml 32 | :param stage train | val | test 33 | :param list_prefix prefix for split lists: [train, val, test].lst 34 | :param image_size result image size (resizes if different); None to keep original size 35 | :param sub_format shapenet | dtu dataset sub-type. 36 | :param scale_focal if true, assume focal length is specified for 37 | image of side length 2 instead of actual image size. This is used 38 | where image coordinates are placed in [-1, 1]. 39 | """ 40 | super().__init__() 41 | self.base_path = path 42 | assert os.path.exists(self.base_path) 43 | 44 | cats = [x for x in glob.glob(os.path.join(path, "*")) if os.path.isdir(x)] 45 | 46 | if stage == "train": 47 | file_lists = [os.path.join(x, list_prefix + "train.lst") for x in cats] 48 | elif stage == "val": 49 | file_lists = [os.path.join(x, list_prefix + "val.lst") for x in cats] 50 | elif stage == "test": 51 | file_lists = [os.path.join(x, list_prefix + "test.lst") for x in cats] 52 | 53 | all_objs = [] 54 | for file_list in file_lists: 55 | if not os.path.exists(file_list): 56 | continue 57 | base_dir = os.path.dirname(file_list) 58 | cat = os.path.basename(base_dir) 59 | with open(file_list, "r") as f: 60 | objs = [(cat, os.path.join(base_dir, x.strip())) for x in f.readlines()] 61 | all_objs.extend(objs) 62 | 63 | self.all_objs = all_objs 64 | self.stage = stage 65 | 66 | self.image_to_tensor = get_image_to_tensor_balanced() 67 | self.mask_to_tensor = get_mask_to_tensor() 68 | print( 69 | "Loading DVR dataset", 70 | self.base_path, 71 | "stage", 72 | stage, 73 | len(self.all_objs), 74 | "objs", 75 | "type:", 76 | sub_format, 77 | ) 78 | 79 | self.image_size = image_size 80 | if sub_format == "dtu": 81 | self._coord_trans_world = torch.tensor( 82 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], 83 | dtype=torch.float32, 84 | ) 85 | self._coord_trans_cam = torch.tensor( 86 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], 87 | dtype=torch.float32, 88 | ) 89 | else: 90 | self._coord_trans_world = torch.tensor( 91 | [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], 92 | dtype=torch.float32, 93 | ) 94 | self._coord_trans_cam = torch.tensor( 95 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], 96 | dtype=torch.float32, 97 | ) 98 | self.sub_format = sub_format 99 | self.scale_focal = scale_focal 100 | self.max_imgs = max_imgs 101 | 102 | self.z_near = z_near 103 | self.z_far = z_far 104 | self.lindisp = False 105 | 106 | def __len__(self): 107 | return len(self.all_objs) 108 | 109 | def __getitem__(self, index): 110 | 
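# Loads every selected view of one object/scene. The dict assembled at the
# end of this method holds "images" (NV, 3, H, W), "poses" (NV, 4, 4) and
# "focal", plus "masks" (NV, 1, H, W) when mask files exist, and either
# "c" = (cx, cy) for the dtu sub-format or "bbox" (NV, 4) otherwise.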
cat, root_dir = self.all_objs[index] 111 | 112 | rgb_paths = [ 113 | x 114 | for x in glob.glob(os.path.join(root_dir, "image", "*")) 115 | if (x.endswith(".jpg") or x.endswith(".png")) 116 | ] 117 | rgb_paths = sorted(rgb_paths) 118 | mask_paths = sorted(glob.glob(os.path.join(root_dir, "mask", "*.png"))) 119 | if len(mask_paths) == 0: 120 | mask_paths = [None] * len(rgb_paths) 121 | 122 | if len(rgb_paths) <= self.max_imgs: 123 | sel_indices = np.arange(len(rgb_paths)) 124 | else: 125 | sel_indices = np.random.choice(len(rgb_paths), self.max_imgs, replace=False) 126 | rgb_paths = [rgb_paths[i] for i in sel_indices] 127 | mask_paths = [mask_paths[i] for i in sel_indices] 128 | 129 | cam_path = os.path.join(root_dir, "cameras.npz") 130 | all_cam = np.load(cam_path) 131 | 132 | all_imgs = [] 133 | all_poses = [] 134 | all_masks = [] 135 | all_bboxes = [] 136 | focal = None 137 | if self.sub_format != "shapenet": 138 | # Prepare to average intrinsics over images 139 | fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0 140 | 141 | for idx, (rgb_path, mask_path) in enumerate(zip(rgb_paths, mask_paths)): 142 | i = sel_indices[idx] 143 | img = imageio.imread(rgb_path)[..., :3] 144 | if self.scale_focal: 145 | x_scale = img.shape[1] / 2.0 146 | y_scale = img.shape[0] / 2.0 147 | xy_delta = 1.0 148 | else: 149 | x_scale = y_scale = 1.0 150 | xy_delta = 0.0 151 | 152 | if mask_path is not None: 153 | mask = imageio.imread(mask_path) 154 | if len(mask.shape) == 2: 155 | mask = mask[..., None] 156 | mask = mask[..., :1] 157 | if self.sub_format == "dtu": 158 | # Decompose projection matrix 159 | # DVR uses slightly different format for DTU set 160 | P = all_cam["world_mat_" + str(i)] 161 | P = P[:3] 162 | 163 | K, R, t = cv2.decomposeProjectionMatrix(P)[:3] 164 | K = K / K[2, 2] 165 | 166 | pose = np.eye(4, dtype=np.float32) 167 | pose[:3, :3] = R.transpose() 168 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 169 | 170 | scale_mtx = all_cam.get("scale_mat_" + str(i)) 171 | if scale_mtx is not None: 172 | norm_trans = scale_mtx[:3, 3:] 173 | norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None] 174 | 175 | pose[:3, 3:] -= norm_trans 176 | pose[:3, 3:] /= norm_scale 177 | 178 | fx += torch.tensor(K[0, 0]) * x_scale 179 | fy += torch.tensor(K[1, 1]) * y_scale 180 | cx += (torch.tensor(K[0, 2]) + xy_delta) * x_scale 181 | cy += (torch.tensor(K[1, 2]) + xy_delta) * y_scale 182 | else: 183 | # ShapeNet 184 | wmat_inv_key = "world_mat_inv_" + str(i) 185 | wmat_key = "world_mat_" + str(i) 186 | if wmat_inv_key in all_cam: 187 | extr_inv_mtx = all_cam[wmat_inv_key] 188 | else: 189 | extr_inv_mtx = all_cam[wmat_key] 190 | if extr_inv_mtx.shape[0] == 3: 191 | extr_inv_mtx = np.vstack((extr_inv_mtx, np.array([0, 0, 0, 1]))) 192 | extr_inv_mtx = np.linalg.inv(extr_inv_mtx) 193 | 194 | intr_mtx = all_cam["camera_mat_" + str(i)] 195 | fx, fy = intr_mtx[0, 0], intr_mtx[1, 1] 196 | assert abs(fx - fy) < 1e-9 197 | fx = fx * x_scale 198 | if focal is None: 199 | focal = fx 200 | else: 201 | assert abs(fx - focal) < 1e-5 202 | pose = extr_inv_mtx 203 | 204 | pose = ( 205 | self._coord_trans_world 206 | @ torch.tensor(pose, dtype=torch.float32) 207 | @ self._coord_trans_cam 208 | ) 209 | 210 | img_tensor = self.image_to_tensor(img) 211 | if mask_path is not None: 212 | mask_tensor = self.mask_to_tensor(mask) 213 | 214 | rows = np.any(mask, axis=1) 215 | cols = np.any(mask, axis=0) 216 | rnz = np.where(rows)[0] 217 | cnz = np.where(cols)[0] 218 | if len(rnz) == 0: 219 | raise RuntimeError( 220 | "ERROR: Bad image at", rgb_path, "please 
investigate!" 221 | ) 222 | rmin, rmax = rnz[[0, -1]] 223 | cmin, cmax = cnz[[0, -1]] 224 | bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32) 225 | all_masks.append(mask_tensor) 226 | all_bboxes.append(bbox) 227 | 228 | all_imgs.append(img_tensor) 229 | all_poses.append(pose) 230 | 231 | if self.sub_format != "shapenet": 232 | fx /= len(rgb_paths) 233 | fy /= len(rgb_paths) 234 | cx /= len(rgb_paths) 235 | cy /= len(rgb_paths) 236 | focal = torch.tensor((fx, fy), dtype=torch.float32) 237 | c = torch.tensor((cx, cy), dtype=torch.float32) 238 | all_bboxes = None 239 | elif mask_path is not None: 240 | all_bboxes = torch.stack(all_bboxes) 241 | 242 | all_imgs = torch.stack(all_imgs) 243 | all_poses = torch.stack(all_poses) 244 | if len(all_masks) > 0: 245 | all_masks = torch.stack(all_masks) 246 | else: 247 | all_masks = None 248 | 249 | if self.image_size is not None and all_imgs.shape[-2:] != self.image_size: 250 | scale = self.image_size[0] / all_imgs.shape[-2] 251 | focal *= scale 252 | if self.sub_format != "shapenet": 253 | c *= scale 254 | elif mask_path is not None: 255 | all_bboxes *= scale 256 | 257 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 258 | if all_masks is not None: 259 | all_masks = F.interpolate(all_masks, size=self.image_size, mode="area") 260 | 261 | result = { 262 | "path": root_dir, 263 | "img_id": index, 264 | "focal": focal, 265 | "images": all_imgs, 266 | "poses": all_poses, 267 | } 268 | if all_masks is not None: 269 | result["masks"] = all_masks 270 | if self.sub_format != "shapenet": 271 | result["c"] = c 272 | else: 273 | result["bbox"] = all_bboxes 274 | return result 275 | -------------------------------------------------------------------------------- /src/data/MultiObjectDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import imageio 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | 11 | from util import get_image_to_tensor_balanced, get_mask_to_tensor 12 | 13 | 14 | class MultiObjectDataset(torch.utils.data.Dataset): 15 | """Synthetic dataset of scenes with multiple Shapenet objects""" 16 | 17 | def __init__(self, path, stage="train", z_near=4, z_far=9, n_views=None): 18 | super().__init__() 19 | path = os.path.join(path, stage) 20 | self.base_path = path 21 | print("Loading NeRF synthetic dataset", self.base_path) 22 | trans_files = [] 23 | TRANS_FILE = "transforms.json" 24 | for root, directories, filenames in os.walk(self.base_path): 25 | if TRANS_FILE in filenames: 26 | trans_files.append(os.path.join(root, TRANS_FILE)) 27 | self.trans_files = trans_files 28 | self.image_to_tensor = get_image_to_tensor_balanced() 29 | self.mask_to_tensor = get_mask_to_tensor() 30 | 31 | self.z_near = z_near 32 | self.z_far = z_far 33 | self.lindisp = False 34 | self.n_views = n_views 35 | 36 | print("{} instances in split {}".format(len(self.trans_files), stage)) 37 | 38 | def __len__(self): 39 | return len(self.trans_files) 40 | 41 | def _check_valid(self, index): 42 | if self.n_views is None: 43 | return True 44 | trans_file = self.trans_files[index] 45 | dir_path = os.path.dirname(trans_file) 46 | try: 47 | with open(trans_file, "r") as f: 48 | transform = json.load(f) 49 | except Exception as e: 50 | print("Problematic transforms.json file", trans_file) 51 | print("JSON loading exception", e) 52 | return False 53 | if len(transform["frames"]) != 
self.n_views: 54 | return False 55 | if len(glob.glob(os.path.join(dir_path, "*.png"))) != self.n_views: 56 | return False 57 | return True 58 | 59 | def __getitem__(self, index): 60 | if not self._check_valid(index): 61 | return {} 62 | 63 | trans_file = self.trans_files[index] 64 | dir_path = os.path.dirname(trans_file) 65 | with open(trans_file, "r") as f: 66 | transform = json.load(f) 67 | 68 | all_imgs = [] 69 | all_bboxes = [] 70 | all_masks = [] 71 | all_poses = [] 72 | for frame in transform["frames"]: 73 | fpath = frame["file_path"] 74 | basename = os.path.splitext(os.path.basename(fpath))[0] 75 | obj_path = os.path.join(dir_path, "{}_obj.png".format(basename)) 76 | img = imageio.imread(obj_path) 77 | mask = self.mask_to_tensor(img[..., 3]) 78 | rows = np.any(img, axis=1) 79 | cols = np.any(img, axis=0) 80 | rnz = np.where(rows)[0] 81 | cnz = np.where(cols)[0] 82 | if len(rnz) == 0: 83 | cmin = rmin = 0 84 | cmax = mask.shape[-1] 85 | rmax = mask.shape[-2] 86 | else: 87 | rmin, rmax = rnz[[0, -1]] 88 | cmin, cmax = cnz[[0, -1]] 89 | bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32) 90 | 91 | img_tensor = self.image_to_tensor(img[..., :3]) 92 | img = img_tensor * mask + ( 93 | 1.0 - mask 94 | ) # solid white background where transparent 95 | all_imgs.append(img) 96 | all_bboxes.append(bbox) 97 | all_masks.append(mask) 98 | all_poses.append(torch.tensor(frame["transform_matrix"])) 99 | imgs = torch.stack(all_imgs) 100 | masks = torch.stack(all_masks) 101 | bboxes = torch.stack(all_bboxes) 102 | poses = torch.stack(all_poses) 103 | 104 | H, W = imgs.shape[-2:] 105 | camera_angle_x = transform.get("camera_angle_x") 106 | focal = 0.5 * W / np.tan(0.5 * camera_angle_x) 107 | 108 | result = { 109 | "path": dir_path, 110 | "img_id": index, 111 | "focal": focal, 112 | "images": imgs, 113 | "masks": masks, 114 | "bbox": bboxes, 115 | "poses": poses, 116 | } 117 | return result 118 | -------------------------------------------------------------------------------- /src/data/SRNDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import glob 5 | import imageio 6 | import numpy as np 7 | from util import get_image_to_tensor_balanced, get_mask_to_tensor 8 | 9 | 10 | class SRNDataset(torch.utils.data.Dataset): 11 | """ 12 | Dataset from SRN (V. Sitzmann et al. 
2020) 13 | """ 14 | 15 | def __init__( 16 | self, path, stage="train", image_size=(128, 128), world_scale=1.0, 17 | ): 18 | """ 19 | :param stage train | val | test 20 | :param image_size result image size (resizes if different) 21 | :param world_scale amount to scale entire world by 22 | """ 23 | super().__init__() 24 | self.base_path = path + "_" + stage 25 | self.dataset_name = os.path.basename(path) 26 | 27 | print("Loading SRN dataset", self.base_path, "name:", self.dataset_name) 28 | self.stage = stage 29 | assert os.path.exists(self.base_path) 30 | 31 | is_chair = "chair" in self.dataset_name 32 | if is_chair and stage == "train": 33 | # Ugly thing from SRN's public dataset 34 | tmp = os.path.join(self.base_path, "chairs_2.0_train") 35 | if os.path.exists(tmp): 36 | self.base_path = tmp 37 | 38 | self.intrins = sorted( 39 | glob.glob(os.path.join(self.base_path, "*", "intrinsics.txt")) 40 | ) 41 | self.image_to_tensor = get_image_to_tensor_balanced() 42 | self.mask_to_tensor = get_mask_to_tensor() 43 | 44 | self.image_size = image_size 45 | self.world_scale = world_scale 46 | self._coord_trans = torch.diag( 47 | torch.tensor([1, -1, -1, 1], dtype=torch.float32) 48 | ) 49 | 50 | if is_chair: 51 | self.z_near = 1.25 52 | self.z_far = 2.75 53 | else: 54 | self.z_near = 0.8 55 | self.z_far = 1.8 56 | self.lindisp = False 57 | 58 | def __len__(self): 59 | return len(self.intrins) 60 | 61 | def __getitem__(self, index): 62 | intrin_path = self.intrins[index] 63 | dir_path = os.path.dirname(intrin_path) 64 | rgb_paths = sorted(glob.glob(os.path.join(dir_path, "rgb", "*"))) 65 | pose_paths = sorted(glob.glob(os.path.join(dir_path, "pose", "*"))) 66 | 67 | assert len(rgb_paths) == len(pose_paths) 68 | 69 | with open(intrin_path, "r") as intrinfile: 70 | lines = intrinfile.readlines() 71 | focal, cx, cy, _ = map(float, lines[0].split()) 72 | height, width = map(int, lines[-1].split()) 73 | 74 | all_imgs = [] 75 | all_poses = [] 76 | all_masks = [] 77 | all_bboxes = [] 78 | for rgb_path, pose_path in zip(rgb_paths, pose_paths): 79 | img = imageio.imread(rgb_path)[..., :3] 80 | img_tensor = self.image_to_tensor(img) 81 | mask = (img != 255).all(axis=-1)[..., None].astype(np.uint8) * 255 82 | mask_tensor = self.mask_to_tensor(mask) 83 | 84 | pose = torch.from_numpy( 85 | np.loadtxt(pose_path, dtype=np.float32).reshape(4, 4) 86 | ) 87 | pose = pose @ self._coord_trans 88 | 89 | rows = np.any(mask, axis=1) 90 | cols = np.any(mask, axis=0) 91 | rnz = np.where(rows)[0] 92 | cnz = np.where(cols)[0] 93 | if len(rnz) == 0: 94 | raise RuntimeError( 95 | "ERROR: Bad image at", rgb_path, "please investigate!" 
96 | ) 97 | rmin, rmax = rnz[[0, -1]] 98 | cmin, cmax = cnz[[0, -1]] 99 | bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32) 100 | 101 | all_imgs.append(img_tensor) 102 | all_masks.append(mask_tensor) 103 | all_poses.append(pose) 104 | all_bboxes.append(bbox) 105 | 106 | all_imgs = torch.stack(all_imgs) 107 | all_poses = torch.stack(all_poses) 108 | all_masks = torch.stack(all_masks) 109 | all_bboxes = torch.stack(all_bboxes) 110 | 111 | if all_imgs.shape[-2:] != self.image_size: 112 | scale = self.image_size[0] / all_imgs.shape[-2] 113 | focal *= scale 114 | cx *= scale 115 | cy *= scale 116 | all_bboxes *= scale 117 | 118 | all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area") 119 | all_masks = F.interpolate(all_masks, size=self.image_size, mode="area") 120 | 121 | if self.world_scale != 1.0: 122 | focal *= self.world_scale 123 | all_poses[:, :3, 3] *= self.world_scale 124 | focal = torch.tensor(focal, dtype=torch.float32) 125 | 126 | result = { 127 | "path": dir_path, 128 | "img_id": index, 129 | "focal": focal, 130 | "c": torch.tensor([cx, cy], dtype=torch.float32), 131 | "images": all_imgs, 132 | "masks": all_masks, 133 | "bbox": all_bboxes, 134 | "poses": all_poses, 135 | } 136 | return result 137 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .MultiObjectDataset import MultiObjectDataset 4 | from .DVRDataset import DVRDataset 5 | from .SRNDataset import SRNDataset 6 | 7 | from .data_util import ColorJitterDataset 8 | 9 | 10 | def get_split_dataset(dataset_type, datadir, want_split="all", training=True, **kwargs): 11 | """ 12 | Retrieved desired dataset class 13 | :param dataset_type dataset type name (srn|dvr|dvr_gen, etc) 14 | :param datadir root directory name for the dataset. For SRN/multi_obj data: 15 | if data is in dir/cars_train, dir/cars_test, ... 
then put dir/cars 16 | :param want_split root directory name for the dataset 17 | :param training set to False in eval scripts 18 | """ 19 | dset_class, train_aug = None, None 20 | flags, train_aug_flags = {}, {} 21 | 22 | if dataset_type == "srn": 23 | # For ShapeNet single-category (from SRN) 24 | dset_class = SRNDataset 25 | elif dataset_type == "multi_obj": 26 | # For multiple-object 27 | dset_class = MultiObjectDataset 28 | elif dataset_type.startswith("dvr"): 29 | # For ShapeNet 64x64 30 | dset_class = DVRDataset 31 | if dataset_type == "dvr_gen": 32 | # For generalization training (train some categories, eval on others) 33 | flags["list_prefix"] = "gen_" 34 | elif dataset_type == "dvr_dtu": 35 | # DTU dataset 36 | flags["list_prefix"] = "new_" 37 | if training: 38 | flags["max_imgs"] = 49 39 | flags["sub_format"] = "dtu" 40 | flags["scale_focal"] = False 41 | flags["z_near"] = 0.1 42 | flags["z_far"] = 5.0 43 | # Apply color jitter during train 44 | train_aug = ColorJitterDataset 45 | train_aug_flags = {"extra_inherit_attrs": ["sub_format"]} 46 | else: 47 | raise NotImplementedError("Unsupported dataset type", dataset_type) 48 | 49 | want_train = want_split != "val" and want_split != "test" 50 | want_val = want_split != "train" and want_split != "test" 51 | want_test = want_split != "train" and want_split != "val" 52 | 53 | if want_train: 54 | train_set = dset_class(datadir, stage="train", **flags, **kwargs) 55 | if train_aug is not None: 56 | train_set = train_aug(train_set, **train_aug_flags) 57 | 58 | if want_val: 59 | val_set = dset_class(datadir, stage="val", **flags, **kwargs) 60 | 61 | if want_test: 62 | test_set = dset_class(datadir, stage="test", **flags, **kwargs) 63 | 64 | if want_split == "train": 65 | return train_set 66 | elif want_split == "val": 67 | return val_set 68 | elif want_split == "test": 69 | return test_set 70 | return train_set, val_set, test_set 71 | -------------------------------------------------------------------------------- /src/data/data_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import torchvision.transforms.functional_tensor as F_t 5 | import torchvision.transforms.functional as TF 6 | import numpy as np 7 | import imageio 8 | 9 | # from util import GaussianBlur 10 | 11 | 12 | class ColorJitterDataset(torch.utils.data.Dataset): 13 | def __init__( 14 | self, 15 | base_dset, 16 | hue_range=0.1, 17 | saturation_range=0.1, 18 | brightness_range=0.1, 19 | contrast_range=0.1, 20 | extra_inherit_attrs=[], 21 | ): 22 | self.hue_range = [-hue_range, hue_range] 23 | self.saturation_range = [1 - saturation_range, 1 + saturation_range] 24 | self.brightness_range = [1 - brightness_range, 1 + brightness_range] 25 | self.contrast_range = [1 - contrast_range, 1 + contrast_range] 26 | inherit_attrs = ["z_near", "z_far", "lindisp", "base_path", "image_to_tensor"] 27 | inherit_attrs.extend(extra_inherit_attrs) 28 | 29 | self.base_dset = base_dset 30 | for inherit_attr in inherit_attrs: 31 | setattr(self, inherit_attr, getattr(self.base_dset, inherit_attr)) 32 | 33 | def apply_color_jitter(self, images): 34 | # apply the same color jitter over batch of images 35 | hue_factor = np.random.uniform(*self.hue_range) 36 | saturation_factor = np.random.uniform(*self.saturation_range) 37 | brightness_factor = np.random.uniform(*self.brightness_range) 38 | contrast_factor = np.random.uniform(*self.contrast_range) 39 | for i in range(len(images)): 40 | tmp = 
(images[i] + 1.0) * 0.5 41 | tmp = F_t.adjust_saturation(tmp, saturation_factor) 42 | tmp = F_t.adjust_hue(tmp, hue_factor) 43 | tmp = F_t.adjust_contrast(tmp, contrast_factor) 44 | tmp = F_t.adjust_brightness(tmp, brightness_factor) 45 | images[i] = tmp * 2.0 - 1.0 46 | return images 47 | 48 | def __len__(self): 49 | return len(self.base_dset) 50 | 51 | def __getitem__(self, idx): 52 | data = self.base_dset[idx] 53 | data["images"] = self.apply_color_jitter(data["images"]) 54 | return data 55 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import PixelNeRFNet 2 | 3 | 4 | def make_model(conf, *args, **kwargs): 5 | """ Placeholder to allow more model types """ 6 | model_type = conf.get_string("type", "pixelnerf") # single 7 | if model_type == "pixelnerf": 8 | net = PixelNeRFNet(conf, *args, **kwargs) 9 | else: 10 | raise NotImplementedError("Unsupported model type", model_type) 11 | return net 12 | -------------------------------------------------------------------------------- /src/model/code.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.autograd.profiler as profiler 4 | 5 | 6 | class PositionalEncoding(torch.nn.Module): 7 | """ 8 | Implement NeRF's positional encoding 9 | """ 10 | 11 | def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True): 12 | super().__init__() 13 | self.num_freqs = num_freqs 14 | self.d_in = d_in 15 | self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) 16 | self.d_out = self.num_freqs * 2 * d_in 17 | self.include_input = include_input 18 | if include_input: 19 | self.d_out += d_in 20 | # f1 f1 f2 f2 ... to multiply x by 21 | self.register_buffer( 22 | "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) 23 | ) 24 | # 0 pi/2 0 pi/2 ... so that 25 | # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) 
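# Worked example: with num_freqs=2, d_in=1, include_input=True the encoding is
#   [x, sin(f*x), cos(f*x), sin(2f*x), cos(2f*x)]   (f = freq_factor)
# so d_out = 2 * num_freqs * d_in + d_in = 5.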
26 | _phases = torch.zeros(2 * self.num_freqs) 27 | _phases[1::2] = np.pi * 0.5 28 | self.register_buffer("_phases", _phases.view(1, -1, 1)) 29 | 30 | def forward(self, x): 31 | """ 32 | Apply positional encoding (new implementation) 33 | :param x (batch, self.d_in) 34 | :return (batch, self.d_out) 35 | """ 36 | with profiler.record_function("positional_enc"): 37 | embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1) 38 | embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs)) 39 | embed = embed.view(x.shape[0], -1) 40 | if self.include_input: 41 | embed = torch.cat((x, embed), dim=-1) 42 | return embed 43 | 44 | @classmethod 45 | def from_conf(cls, conf, d_in=3): 46 | # PyHocon construction 47 | return cls( 48 | conf.get_int("num_freqs", 6), 49 | d_in, 50 | conf.get_float("freq_factor", np.pi), 51 | conf.get_bool("include_input", True), 52 | ) 53 | -------------------------------------------------------------------------------- /src/model/custom_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import util 5 | 6 | 7 | class ConvEncoder(nn.Module): 8 | """ 9 | Basic, extremely simple convolutional encoder 10 | """ 11 | 12 | def __init__( 13 | self, 14 | dim_in=3, 15 | norm_layer=util.get_norm_layer("group"), 16 | padding_type="reflect", 17 | use_leaky_relu=True, 18 | use_skip_conn=True, 19 | ): 20 | super().__init__() 21 | self.dim_in = dim_in 22 | self.norm_layer = norm_layer 23 | self.activation = nn.LeakyReLU() if use_leaky_relu else nn.ReLU() 24 | self.padding_type = padding_type 25 | self.use_skip_conn = use_skip_conn 26 | 27 | # TODO: make these configurable 28 | first_layer_chnls = 64 29 | mid_layer_chnls = 128 30 | last_layer_chnls = 128 31 | n_down_layers = 3 32 | self.n_down_layers = n_down_layers 33 | 34 | self.conv_in = nn.Sequential( 35 | nn.Conv2d(dim_in, first_layer_chnls, kernel_size=7, stride=2, bias=False), 36 | norm_layer(first_layer_chnls), 37 | self.activation, 38 | ) 39 | 40 | chnls = first_layer_chnls 41 | for i in range(0, n_down_layers): 42 | conv = nn.Sequential( 43 | nn.Conv2d(chnls, 2 * chnls, kernel_size=3, stride=2, bias=False), 44 | norm_layer(2 * chnls), 45 | self.activation, 46 | ) 47 | setattr(self, "conv" + str(i), conv) 48 | 49 | deconv = nn.Sequential( 50 | nn.ConvTranspose2d( 51 | 4 * chnls, chnls, kernel_size=3, stride=2, bias=False 52 | ), 53 | norm_layer(chnls), 54 | self.activation, 55 | ) 56 | setattr(self, "deconv" + str(i), deconv) 57 | chnls *= 2 58 | 59 | self.conv_mid = nn.Sequential( 60 | nn.Conv2d(chnls, mid_layer_chnls, kernel_size=4, stride=4, bias=False), 61 | norm_layer(mid_layer_chnls), 62 | self.activation, 63 | ) 64 | 65 | self.deconv_last = nn.ConvTranspose2d( 66 | first_layer_chnls, last_layer_chnls, kernel_size=3, stride=2, bias=True 67 | ) 68 | 69 | self.dims = [last_layer_chnls] 70 | 71 | def forward(self, x): 72 | x = util.same_pad_conv2d(x, padding_type=self.padding_type, layer=self.conv_in) 73 | x = self.conv_in(x) 74 | 75 | inters = [] 76 | for i in range(0, self.n_down_layers): 77 | conv_i = getattr(self, "conv" + str(i)) 78 | x = util.same_pad_conv2d(x, padding_type=self.padding_type, layer=conv_i) 79 | x = conv_i(x) 80 | inters.append(x) 81 | 82 | x = util.same_pad_conv2d(x, padding_type=self.padding_type, layer=self.conv_mid) 83 | x = self.conv_mid(x) 84 | x = x.reshape(x.shape[0], -1, 1, 1).expand(-1, -1, *inters[-1].shape[-2:]) 85 | 86 | for i in reversed(range(0, 
self.n_down_layers)): 87 | if self.use_skip_conn: 88 | x = torch.cat((x, inters[i]), dim=1) 89 | deconv_i = getattr(self, "deconv" + str(i)) 90 | x = deconv_i(x) 91 | x = util.same_unpad_deconv2d(x, layer=deconv_i) 92 | x = self.deconv_last(x) 93 | x = util.same_unpad_deconv2d(x, layer=self.deconv_last) 94 | return x 95 | -------------------------------------------------------------------------------- /src/model/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements image encoders 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torchvision 8 | import util 9 | from model.custom_encoder import ConvEncoder 10 | import torch.autograd.profiler as profiler 11 | 12 | 13 | class SpatialEncoder(nn.Module): 14 | """ 15 | 2D (Spatial/Pixel-aligned/local) image encoder 16 | """ 17 | 18 | def __init__( 19 | self, 20 | backbone="resnet34", 21 | pretrained=True, 22 | num_layers=4, 23 | index_interp="bilinear", 24 | index_padding="border", 25 | upsample_interp="bilinear", 26 | feature_scale=1.0, 27 | use_first_pool=True, 28 | norm_type="batch", 29 | ): 30 | """ 31 | :param backbone Backbone network. Either custom, in which case 32 | model.custom_encoder.ConvEncoder is used OR resnet18/resnet34, in which case the relevant 33 | model from torchvision is used 34 | :param num_layers number of resnet layers to use, 1-5 35 | :param pretrained Whether to use model weights pretrained on ImageNet 36 | :param index_interp Interpolation to use for indexing 37 | :param index_padding Padding mode to use for indexing, border | zeros | reflection 38 | :param upsample_interp Interpolation to use for upscaling latent code 39 | :param feature_scale factor to scale all latent by. Useful (<1) if image 40 | is extremely large, to fit in memory. 
41 | :param use_first_pool if false, skips first maxpool layer to avoid downscaling image 42 | features too much (ResNet only) 43 | :param norm_type norm type to applied; pretrained model must use batch 44 | """ 45 | super().__init__() 46 | 47 | if norm_type != "batch": 48 | assert not pretrained 49 | 50 | self.use_custom_resnet = backbone == "custom" 51 | self.feature_scale = feature_scale 52 | self.use_first_pool = use_first_pool 53 | norm_layer = util.get_norm_layer(norm_type) 54 | 55 | if self.use_custom_resnet: 56 | print("WARNING: Custom encoder is experimental only") 57 | print("Using simple convolutional encoder") 58 | self.model = ConvEncoder(3, norm_layer=norm_layer) 59 | self.latent_size = self.model.dims[-1] 60 | else: 61 | print("Using torchvision", backbone, "encoder") 62 | self.model = getattr(torchvision.models, backbone)( 63 | pretrained=pretrained, norm_layer=norm_layer 64 | ) 65 | # Following 2 lines need to be uncommented for older configs 66 | self.model.fc = nn.Sequential() 67 | self.model.avgpool = nn.Sequential() 68 | self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers] 69 | 70 | self.num_layers = num_layers 71 | self.index_interp = index_interp 72 | self.index_padding = index_padding 73 | self.upsample_interp = upsample_interp 74 | self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False) 75 | self.register_buffer( 76 | "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False 77 | ) 78 | # self.latent (B, L, H, W) 79 | 80 | def index(self, uv, cam_z=None, image_size=(), z_bounds=None): 81 | """ 82 | Get pixel-aligned image features at 2D image coordinates 83 | :param uv (B, N, 2) image points (x,y) 84 | :param cam_z ignored (for compatibility) 85 | :param image_size image size, either (width, height) or single int. 86 | if not specified, assumes coords are in [-1, 1] 87 | :param z_bounds ignored (for compatibility) 88 | :return (B, L, N) L is latent size 89 | """ 90 | with profiler.record_function("encoder_index"): 91 | if uv.shape[0] == 1 and self.latent.shape[0] > 1: 92 | uv = uv.expand(self.latent.shape[0], -1, -1) 93 | 94 | with profiler.record_function("encoder_index_pre"): 95 | if len(image_size) > 0: 96 | if len(image_size) == 1: 97 | image_size = (image_size, image_size) 98 | scale = self.latent_scaling / image_size 99 | uv = uv * scale - 1.0 100 | 101 | uv = uv.unsqueeze(2) # (B, N, 1, 2) 102 | samples = F.grid_sample( 103 | self.latent, 104 | uv, 105 | align_corners=True, 106 | mode=self.index_interp, 107 | padding_mode=self.index_padding, 108 | ) 109 | return samples[:, :, :, 0] # (B, C, N) 110 | 111 | def forward(self, x): 112 | """ 113 | For extracting ResNet's features. 
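Side effect: caches the multi-scale feature map in self.latent and the
pixel-to-grid scale in self.latent_scaling, which index() samples from.
Rough shape sketch (assuming the default resnet34 backbone with
num_layers=4; `images` and `uv` are placeholder tensors):
    latent = encoder(images)                       # (B, 512, H/2, W/2)
    feats = encoder.index(uv, image_size=(W, H))   # (B, 512, N)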
114 | :param x image (B, C, H, W) 115 | :return latent (B, latent_size, H, W) 116 | """ 117 | if self.feature_scale != 1.0: 118 | x = F.interpolate( 119 | x, 120 | scale_factor=self.feature_scale, 121 | mode="bilinear" if self.feature_scale > 1.0 else "area", 122 | align_corners=True if self.feature_scale > 1.0 else None, 123 | recompute_scale_factor=True, 124 | ) 125 | x = x.to(device=self.latent.device) 126 | 127 | if self.use_custom_resnet: 128 | self.latent = self.model(x) 129 | else: 130 | x = self.model.conv1(x) 131 | x = self.model.bn1(x) 132 | x = self.model.relu(x) 133 | 134 | latents = [x] 135 | if self.num_layers > 1: 136 | if self.use_first_pool: 137 | x = self.model.maxpool(x) 138 | x = self.model.layer1(x) 139 | latents.append(x) 140 | if self.num_layers > 2: 141 | x = self.model.layer2(x) 142 | latents.append(x) 143 | if self.num_layers > 3: 144 | x = self.model.layer3(x) 145 | latents.append(x) 146 | if self.num_layers > 4: 147 | x = self.model.layer4(x) 148 | latents.append(x) 149 | 150 | self.latents = latents 151 | align_corners = None if self.index_interp == "nearest " else True 152 | latent_sz = latents[0].shape[-2:] 153 | for i in range(len(latents)): 154 | latents[i] = F.interpolate( 155 | latents[i], 156 | latent_sz, 157 | mode=self.upsample_interp, 158 | align_corners=align_corners, 159 | ) 160 | self.latent = torch.cat(latents, dim=1) 161 | self.latent_scaling[0] = self.latent.shape[-1] 162 | self.latent_scaling[1] = self.latent.shape[-2] 163 | self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0 164 | return self.latent 165 | 166 | @classmethod 167 | def from_conf(cls, conf): 168 | return cls( 169 | conf.get_string("backbone"), 170 | pretrained=conf.get_bool("pretrained", True), 171 | num_layers=conf.get_int("num_layers", 4), 172 | index_interp=conf.get_string("index_interp", "bilinear"), 173 | index_padding=conf.get_string("index_padding", "border"), 174 | upsample_interp=conf.get_string("upsample_interp", "bilinear"), 175 | feature_scale=conf.get_float("feature_scale", 1.0), 176 | use_first_pool=conf.get_bool("use_first_pool", True), 177 | ) 178 | 179 | 180 | class ImageEncoder(nn.Module): 181 | """ 182 | Global image encoder 183 | """ 184 | 185 | def __init__(self, backbone="resnet34", pretrained=True, latent_size=128): 186 | """ 187 | :param backbone Backbone network. Assumes it is resnet* 188 | e.g. resnet34 | resnet50 189 | :param num_layers number of resnet layers to use, 1-5 190 | :param pretrained Whether to use model pretrained on ImageNet 191 | """ 192 | super().__init__() 193 | self.model = getattr(torchvision.models, backbone)(pretrained=pretrained) 194 | self.model.fc = nn.Sequential() 195 | self.register_buffer("latent", torch.empty(1, 1), persistent=False) 196 | # self.latent (B, L) 197 | self.latent_size = latent_size 198 | if latent_size != 512: 199 | self.fc = nn.Linear(512, latent_size) 200 | 201 | def index(self, uv, cam_z=None, image_size=(), z_bounds=()): 202 | """ 203 | Params ignored (compatibility) 204 | :param uv (B, N, 2) only used for shape 205 | :return latent vector (B, L, N) 206 | """ 207 | return self.latent.unsqueeze(-1).expand(-1, -1, uv.shape[1]) 208 | 209 | def forward(self, x): 210 | """ 211 | For extracting ResNet's features. 
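Side effect: stores the pooled vector in self.latent so that index() can
broadcast it to (B, latent_size, N) for any set of query points.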
212 | :param x image (B, C, H, W) 213 | :return latent (B, latent_size) 214 | """ 215 | x = x.to(device=self.latent.device) 216 | x = self.model.conv1(x) 217 | x = self.model.bn1(x) 218 | x = self.model.relu(x) 219 | 220 | x = self.model.maxpool(x) 221 | x = self.model.layer1(x) 222 | x = self.model.layer2(x) 223 | x = self.model.layer3(x) 224 | x = self.model.layer4(x) 225 | 226 | x = self.model.avgpool(x) 227 | x = torch.flatten(x, 1) 228 | 229 | if self.latent_size != 512: 230 | x = self.fc(x) 231 | 232 | self.latent = x # (B, latent_size) 233 | return self.latent 234 | 235 | @classmethod 236 | def from_conf(cls, conf): 237 | return cls( 238 | conf.get_string("backbone"), 239 | pretrained=conf.get_bool("pretrained", True), 240 | latent_size=conf.get_int("latent_size", 128), 241 | ) 242 | -------------------------------------------------------------------------------- /src/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AlphaLossNV2(torch.nn.Module): 5 | """ 6 | Implement Neural Volumes alpha loss 2 7 | """ 8 | 9 | def __init__(self, lambda_alpha, clamp_alpha, init_epoch, force_opaque=False): 10 | super().__init__() 11 | self.lambda_alpha = lambda_alpha 12 | self.clamp_alpha = clamp_alpha 13 | self.init_epoch = init_epoch 14 | self.force_opaque = force_opaque 15 | if force_opaque: 16 | self.bceloss = torch.nn.BCELoss() 17 | self.register_buffer( 18 | "epoch", torch.tensor(0, dtype=torch.long), persistent=True 19 | ) 20 | 21 | def sched_step(self, num=1): 22 | self.epoch += num 23 | 24 | def forward(self, alpha_fine): 25 | if self.lambda_alpha > 0.0 and self.epoch.item() >= self.init_epoch: 26 | alpha_fine = torch.clamp(alpha_fine, 0.01, 0.99) 27 | if self.force_opaque: 28 | alpha_loss = self.lambda_alpha * self.bceloss( 29 | alpha_fine, torch.ones_like(alpha_fine) 30 | ) 31 | else: 32 | alpha_loss = torch.log(alpha_fine) + torch.log(1.0 - alpha_fine) 33 | alpha_loss = torch.clamp_min(alpha_loss, -self.clamp_alpha) 34 | alpha_loss = self.lambda_alpha * alpha_loss.mean() 35 | else: 36 | alpha_loss = torch.zeros(1, device=alpha_fine.device) 37 | return alpha_loss 38 | 39 | 40 | def get_alpha_loss(conf): 41 | lambda_alpha = conf.get_float("lambda_alpha") 42 | clamp_alpha = conf.get_float("clamp_alpha") 43 | init_epoch = conf.get_int("init_epoch") 44 | force_opaque = conf.get_bool("force_opaque", False) 45 | 46 | return AlphaLossNV2( 47 | lambda_alpha, clamp_alpha, init_epoch, force_opaque=force_opaque 48 | ) 49 | 50 | 51 | class RGBWithUncertainty(torch.nn.Module): 52 | """Implement the uncertainty loss from Kendall '17""" 53 | 54 | def __init__(self, conf): 55 | super().__init__() 56 | self.element_loss = ( 57 | torch.nn.L1Loss(reduction="none") 58 | if conf.get_bool("use_l1") 59 | else torch.nn.MSELoss(reduction="none") 60 | ) 61 | 62 | def forward(self, outputs, targets, betas): 63 | """computes the error per output, weights each element by the log variance 64 | outputs is B x 3, targets is B x 3, betas is B""" 65 | weighted_element_err = ( 66 | torch.mean(self.element_loss(outputs, targets), -1) / betas 67 | ) 68 | return torch.mean(weighted_element_err) + torch.mean(torch.log(betas)) 69 | 70 | 71 | class RGBWithBackground(torch.nn.Module): 72 | """Implement the uncertainty loss from Kendall '17""" 73 | 74 | def __init__(self, conf): 75 | super().__init__() 76 | self.element_loss = ( 77 | torch.nn.L1Loss(reduction="none") 78 | if conf.get_bool("use_l1") 79 | else torch.nn.MSELoss(reduction="none") 80 | ) 
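# Note: forward() below divides the per-element error by (1 + lambda_bg) and
# adds a mean(log(lambda_bg)) term; get_rgb_loss() currently only constructs
# this class in its commented-out `using_bg` branch.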
81 | 82 | def forward(self, outputs, targets, lambda_bg): 83 | """If we're using background, then the color is color_fg + lambda_bg * color_bg. 84 | We want to weight the background rays less, while not putting all alpha on bg""" 85 | weighted_element_err = torch.mean(self.element_loss(outputs, targets), -1) / ( 86 | 1 + lambda_bg 87 | ) 88 | return torch.mean(weighted_element_err) + torch.mean(torch.log(lambda_bg)) 89 | 90 | 91 | def get_rgb_loss(conf, coarse=True, using_bg=False, reduction="mean"): 92 | if conf.get_bool("use_uncertainty", False) and not coarse: 93 | print("using loss with uncertainty") 94 | return RGBWithUncertainty(conf) 95 | # if using_bg: 96 | # print("using loss with background") 97 | # return RGBWithBackground(conf) 98 | print("using vanilla rgb loss") 99 | return ( 100 | torch.nn.L1Loss(reduction=reduction) 101 | if conf.get_bool("use_l1") 102 | else torch.nn.MSELoss(reduction=reduction) 103 | ) 104 | -------------------------------------------------------------------------------- /src/model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import util 5 | 6 | 7 | class ImplicitNet(nn.Module): 8 | """ 9 | Represents a MLP; 10 | Original code from IGR 11 | """ 12 | 13 | def __init__( 14 | self, 15 | d_in, 16 | dims, 17 | skip_in=(), 18 | d_out=4, 19 | geometric_init=True, 20 | radius_init=0.3, 21 | beta=0.0, 22 | output_init_gain=2.0, 23 | num_position_inputs=3, 24 | sdf_scale=1.0, 25 | dim_excludes_skip=False, 26 | combine_layer=1000, 27 | combine_type="average", 28 | ): 29 | """ 30 | :param d_in input size 31 | :param dims dimensions of hidden layers. Num hidden layers == len(dims) 32 | :param skip_in layers with skip connections from input (residual) 33 | :param d_out output size 34 | :param geometric_init if true, uses geometric initialization 35 | (to SDF of sphere) 36 | :param radius_init if geometric_init, then SDF sphere will have 37 | this radius 38 | :param beta softplus beta, 100 is reasonable; if <=0 uses ReLU activations instead 39 | :param output_init_gain output layer normal std, only used for 40 | output dimension >= 1, when d_out >= 1 41 | :param dim_excludes_skip if true, dimension sizes do not include skip 42 | connections 43 | """ 44 | super().__init__() 45 | 46 | dims = [d_in] + dims + [d_out] 47 | if dim_excludes_skip: 48 | for i in range(1, len(dims) - 1): 49 | if i in skip_in: 50 | dims[i] += d_in 51 | 52 | self.num_layers = len(dims) 53 | self.skip_in = skip_in 54 | self.dims = dims 55 | self.combine_layer = combine_layer 56 | self.combine_type = combine_type 57 | 58 | for layer in range(0, self.num_layers - 1): 59 | if layer + 1 in skip_in: 60 | out_dim = dims[layer + 1] - d_in 61 | else: 62 | out_dim = dims[layer + 1] 63 | lin = nn.Linear(dims[layer], out_dim) 64 | 65 | # if true preform geometric initialization 66 | if geometric_init: 67 | if layer == self.num_layers - 2: 68 | # Note our geometric init is negated (compared to IDR) 69 | # since we are using the opposite SDF convention: 70 | # inside is + 71 | nn.init.normal_( 72 | lin.weight[0], 73 | mean=-np.sqrt(np.pi) / np.sqrt(dims[layer]) * sdf_scale, 74 | std=0.00001, 75 | ) 76 | nn.init.constant_(lin.bias[0], radius_init) 77 | if d_out > 1: 78 | # More than SDF output 79 | nn.init.normal_(lin.weight[1:], mean=0.0, std=output_init_gain) 80 | nn.init.constant_(lin.bias[1:], 0.0) 81 | else: 82 | nn.init.constant_(lin.bias, 0.0) 83 | nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / 
np.sqrt(out_dim)) 84 | if d_in > num_position_inputs and (layer == 0 or layer in skip_in): 85 | # Special handling for input to allow positional encoding 86 | nn.init.constant_(lin.weight[:, -d_in + num_position_inputs :], 0.0) 87 | else: 88 | nn.init.constant_(lin.bias, 0.0) 89 | nn.init.kaiming_normal_(lin.weight, a=0, mode="fan_in") 90 | 91 | setattr(self, "lin" + str(layer), lin) 92 | 93 | if beta > 0: 94 | self.activation = nn.Softplus(beta=beta) 95 | else: 96 | # Vanilla ReLU 97 | self.activation = nn.ReLU() 98 | 99 | def forward(self, x, combine_inner_dims=(1,)): 100 | """ 101 | :param x (..., d_in) 102 | :param combine_inner_dims Combining dimensions for use with multiview inputs. 103 | Tensor will be reshaped to (-1, combine_inner_dims, ...) and reduced using combine_type 104 | on dim 1, at combine_layer 105 | """ 106 | x_init = x 107 | for layer in range(0, self.num_layers - 1): 108 | lin = getattr(self, "lin" + str(layer)) 109 | 110 | if layer == self.combine_layer: 111 | x = util.combine_interleaved(x, combine_inner_dims, self.combine_type) 112 | x_init = util.combine_interleaved( 113 | x_init, combine_inner_dims, self.combine_type 114 | ) 115 | 116 | if layer < self.combine_layer and layer in self.skip_in: 117 | x = torch.cat([x, x_init], -1) / np.sqrt(2) 118 | 119 | x = lin(x) 120 | if layer < self.num_layers - 2: 121 | x = self.activation(x) 122 | 123 | return x 124 | 125 | @classmethod 126 | def from_conf(cls, conf, d_in, **kwargs): 127 | # PyHocon construction 128 | return cls( 129 | d_in, 130 | conf.get_list("dims"), 131 | skip_in=conf.get_list("skip_in"), 132 | beta=conf.get_float("beta", 0.0), 133 | dim_excludes_skip=conf.get_bool("dim_excludes_skip", False), 134 | combine_layer=conf.get_int("combine_layer", 1000), 135 | combine_type=conf.get_string("combine_type", "average"), # average | max 136 | **kwargs 137 | ) 138 | -------------------------------------------------------------------------------- /src/model/model_util.py: -------------------------------------------------------------------------------- 1 | from .encoder import SpatialEncoder, ImageEncoder 2 | from .resnetfc import ResnetFC 3 | 4 | 5 | def make_mlp(conf, d_in, d_latent=0, allow_empty=False, **kwargs): 6 | mlp_type = conf.get_string("type", "mlp") # mlp | resnet 7 | if mlp_type == "mlp": 8 | net = ImplicitNet.from_conf(conf, d_in + d_latent, **kwargs) 9 | elif mlp_type == "resnet": 10 | net = ResnetFC.from_conf(conf, d_in, d_latent=d_latent, **kwargs) 11 | elif mlp_type == "empty" and allow_empty: 12 | net = None 13 | else: 14 | raise NotImplementedError("Unsupported MLP type") 15 | return net 16 | 17 | 18 | def make_encoder(conf, **kwargs): 19 | enc_type = conf.get_string("type", "spatial") # spatial | global 20 | if enc_type == "spatial": 21 | net = SpatialEncoder.from_conf(conf, **kwargs) 22 | elif enc_type == "global": 23 | net = ImageEncoder.from_conf(conf, **kwargs) 24 | else: 25 | raise NotImplementedError("Unsupported encoder type") 26 | return net 27 | -------------------------------------------------------------------------------- /src/model/resnetfc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | # import torch_scatter 5 | import torch.autograd.profiler as profiler 6 | import util 7 | 8 | 9 | # Resnet Blocks 10 | class ResnetBlockFC(nn.Module): 11 | """ 12 | Fully connected ResNet Block class. 13 | Taken from DVR code. 
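Computes x_s + fc_1(activation(fc_0(activation(x)))), where x_s is x itself,
or a learned linear shortcut when size_in != size_out.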
14 | :param size_in (int): input dimension 15 | :param size_out (int): output dimension 16 | :param size_h (int): hidden dimension 17 | """ 18 | 19 | def __init__(self, size_in, size_out=None, size_h=None, beta=0.0): 20 | super().__init__() 21 | # Attributes 22 | if size_out is None: 23 | size_out = size_in 24 | 25 | if size_h is None: 26 | size_h = min(size_in, size_out) 27 | 28 | self.size_in = size_in 29 | self.size_h = size_h 30 | self.size_out = size_out 31 | # Submodules 32 | self.fc_0 = nn.Linear(size_in, size_h) 33 | self.fc_1 = nn.Linear(size_h, size_out) 34 | 35 | # Init 36 | nn.init.constant_(self.fc_0.bias, 0.0) 37 | nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in") 38 | nn.init.constant_(self.fc_1.bias, 0.0) 39 | nn.init.zeros_(self.fc_1.weight) 40 | 41 | if beta > 0: 42 | self.activation = nn.Softplus(beta=beta) 43 | else: 44 | self.activation = nn.ReLU() 45 | 46 | if size_in == size_out: 47 | self.shortcut = None 48 | else: 49 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 50 | nn.init.constant_(self.shortcut.bias, 0.0) 51 | nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in") 52 | 53 | def forward(self, x): 54 | with profiler.record_function("resblock"): 55 | net = self.fc_0(self.activation(x)) 56 | dx = self.fc_1(self.activation(net)) 57 | 58 | if self.shortcut is not None: 59 | x_s = self.shortcut(x) 60 | else: 61 | x_s = x 62 | return x_s + dx 63 | 64 | 65 | class ResnetFC(nn.Module): 66 | def __init__( 67 | self, 68 | d_in, 69 | d_out=4, 70 | n_blocks=5, 71 | d_latent=0, 72 | d_hidden=128, 73 | beta=0.0, 74 | combine_layer=1000, 75 | combine_type="average", 76 | use_spade=False, 77 | ): 78 | """ 79 | :param d_in input size 80 | :param d_out output size 81 | :param n_blocks number of Resnet blocks 82 | :param d_latent latent size, added in each resnet block (0 = disable) 83 | :param d_hidden hiddent dimension throughout network 84 | :param beta softplus beta, 100 is reasonable; if <=0 uses ReLU activations instead 85 | """ 86 | super().__init__() 87 | if d_in > 0: 88 | self.lin_in = nn.Linear(d_in, d_hidden) 89 | nn.init.constant_(self.lin_in.bias, 0.0) 90 | nn.init.kaiming_normal_(self.lin_in.weight, a=0, mode="fan_in") 91 | 92 | self.lin_out = nn.Linear(d_hidden, d_out) 93 | nn.init.constant_(self.lin_out.bias, 0.0) 94 | nn.init.kaiming_normal_(self.lin_out.weight, a=0, mode="fan_in") 95 | 96 | self.n_blocks = n_blocks 97 | self.d_latent = d_latent 98 | self.d_in = d_in 99 | self.d_out = d_out 100 | self.d_hidden = d_hidden 101 | 102 | self.combine_layer = combine_layer 103 | self.combine_type = combine_type 104 | self.use_spade = use_spade 105 | 106 | self.blocks = nn.ModuleList( 107 | [ResnetBlockFC(d_hidden, beta=beta) for i in range(n_blocks)] 108 | ) 109 | 110 | if d_latent != 0: 111 | n_lin_z = min(combine_layer, n_blocks) 112 | self.lin_z = nn.ModuleList( 113 | [nn.Linear(d_latent, d_hidden) for i in range(n_lin_z)] 114 | ) 115 | for i in range(n_lin_z): 116 | nn.init.constant_(self.lin_z[i].bias, 0.0) 117 | nn.init.kaiming_normal_(self.lin_z[i].weight, a=0, mode="fan_in") 118 | 119 | if self.use_spade: 120 | self.scale_z = nn.ModuleList( 121 | [nn.Linear(d_latent, d_hidden) for _ in range(n_lin_z)] 122 | ) 123 | for i in range(n_lin_z): 124 | nn.init.constant_(self.scale_z[i].bias, 0.0) 125 | nn.init.kaiming_normal_(self.scale_z[i].weight, a=0, mode="fan_in") 126 | 127 | if beta > 0: 128 | self.activation = nn.Softplus(beta=beta) 129 | else: 130 | self.activation = nn.ReLU() 131 | 132 | def forward(self, zx, 
combine_inner_dims=(1,), combine_index=None, dim_size=None): 133 | """ 134 | :param zx (..., d_latent + d_in) 135 | :param combine_inner_dims Combining dimensions for use with multiview inputs. 136 | Tensor will be reshaped to (-1, combine_inner_dims, ...) and reduced using combine_type 137 | on dim 1, at combine_layer 138 | """ 139 | with profiler.record_function("resnetfc_infer"): 140 | assert zx.size(-1) == self.d_latent + self.d_in 141 | if self.d_latent > 0: 142 | z = zx[..., : self.d_latent] 143 | x = zx[..., self.d_latent :] 144 | else: 145 | x = zx 146 | if self.d_in > 0: 147 | x = self.lin_in(x) 148 | else: 149 | x = torch.zeros(self.d_hidden, device=zx.device) 150 | 151 | for blkid in range(self.n_blocks): 152 | if blkid == self.combine_layer: 153 | # The following implements camera frustum culling, requires torch_scatter 154 | # if combine_index is not None: 155 | # combine_type = ( 156 | # "mean" 157 | # if self.combine_type == "average" 158 | # else self.combine_type 159 | # ) 160 | # if dim_size is not None: 161 | # assert isinstance(dim_size, int) 162 | # x = torch_scatter.scatter( 163 | # x, 164 | # combine_index, 165 | # dim=0, 166 | # dim_size=dim_size, 167 | # reduce=combine_type, 168 | # ) 169 | # else: 170 | x = util.combine_interleaved( 171 | x, combine_inner_dims, self.combine_type 172 | ) 173 | 174 | if self.d_latent > 0 and blkid < self.combine_layer: 175 | tz = self.lin_z[blkid](z) 176 | if self.use_spade: 177 | sz = self.scale_z[blkid](z) 178 | x = sz * x + tz 179 | else: 180 | x = x + tz 181 | 182 | x = self.blocks[blkid](x) 183 | out = self.lin_out(self.activation(x)) 184 | return out 185 | 186 | @classmethod 187 | def from_conf(cls, conf, d_in, **kwargs): 188 | # PyHocon construction 189 | return cls( 190 | d_in, 191 | n_blocks=conf.get_int("n_blocks", 5), 192 | d_hidden=conf.get_int("d_hidden", 128), 193 | beta=conf.get_float("beta", 0.0), 194 | combine_layer=conf.get_int("combine_layer", 1000), 195 | combine_type=conf.get_string("combine_type", "average"), # average | max 196 | use_spade=conf.get_bool("use_spade", False), 197 | **kwargs 198 | ) 199 | -------------------------------------------------------------------------------- /src/render/__init__.py: -------------------------------------------------------------------------------- 1 | from .nerf import NeRFRenderer 2 | -------------------------------------------------------------------------------- /src/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from . import args 3 | 4 | # from . 
import recon 5 | -------------------------------------------------------------------------------- /src/util/args.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 5 | import argparse 6 | from pyhocon import ConfigFactory 7 | 8 | 9 | def parse_args( 10 | callback=None, 11 | training=False, 12 | default_conf="conf/default_mv.conf", 13 | default_expname="example", 14 | default_data_format="dvr", 15 | default_num_epochs=10000000, 16 | default_lr=1e-4, 17 | default_gamma=1.00, 18 | default_datadir="data", 19 | default_ray_batch_size=50000, 20 | ): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--conf", "-c", type=str, default=None) 23 | parser.add_argument("--resume", "-r", action="store_true", help="continue training") 24 | parser.add_argument( 25 | "--gpu_id", type=str, default="0", help="GPU(s) to use, space delimited" 26 | ) 27 | parser.add_argument( 28 | "--name", "-n", type=str, default=default_expname, help="experiment name" 29 | ) 30 | parser.add_argument( 31 | "--dataset_format", 32 | "-F", 33 | type=str, 34 | default=None, 35 | help="Dataset format, multi_obj | dvr | dvr_gen | dvr_dtu | srn", 36 | ) 37 | parser.add_argument( 38 | "--exp_group_name", 39 | "-G", 40 | type=str, 41 | default=None, 42 | help="if we want to group some experiments together", 43 | ) 44 | parser.add_argument( 45 | "--logs_path", type=str, default="logs", help="logs output directory", 46 | ) 47 | parser.add_argument( 48 | "--checkpoints_path", 49 | type=str, 50 | default="checkpoints", 51 | help="checkpoints output directory", 52 | ) 53 | parser.add_argument( 54 | "--visual_path", 55 | type=str, 56 | default="visuals", 57 | help="visualization output directory", 58 | ) 59 | parser.add_argument( 60 | "--epochs", 61 | type=int, 62 | default=default_num_epochs, 63 | help="number of epochs to train for", 64 | ) 65 | parser.add_argument("--lr", type=float, default=default_lr, help="learning rate") 66 | parser.add_argument( 67 | "--gamma", type=float, default=default_gamma, help="learning rate decay factor" 68 | ) 69 | parser.add_argument( 70 | "--datadir", "-D", type=str, default=None, help="Dataset directory" 71 | ) 72 | parser.add_argument( 73 | "--ray_batch_size", "-R", type=int, default=default_ray_batch_size, help="Ray batch size" 74 | ) 75 | if callback is not None: 76 | parser = callback(parser) 77 | args = parser.parse_args() 78 | 79 | if args.exp_group_name is not None: 80 | args.logs_path = os.path.join(args.logs_path, args.exp_group_name) 81 | args.checkpoints_path = os.path.join(args.checkpoints_path, args.exp_group_name) 82 | args.visual_path = os.path.join(args.visual_path, args.exp_group_name) 83 | 84 | os.makedirs(os.path.join(args.checkpoints_path, args.name), exist_ok=True) 85 | os.makedirs(os.path.join(args.visual_path, args.name), exist_ok=True) 86 | 87 | PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) 88 | EXPCONF_PATH = os.path.join(PROJECT_ROOT, "expconf.conf") 89 | expconf = ConfigFactory.parse_file(EXPCONF_PATH) 90 | 91 | if args.conf is None: 92 | args.conf = expconf.get_string("config." + args.name, default_conf) 93 | 94 | if args.conf is None: 95 | args.conf = expconf.get_string("config." + args.name, default_conf) 96 | if args.datadir is None: 97 | args.datadir = expconf.get_string("datadir." 
+ args.name, default_datadir) 98 | 99 | conf = ConfigFactory.parse_file(args.conf) 100 | 101 | if args.dataset_format is None: 102 | args.dataset_format = conf.get_string("data.format", default_data_format) 103 | 104 | args.gpu_id = list(map(int, args.gpu_id.split())) 105 | 106 | print("EXPERIMENT NAME:", args.name) 107 | if training: 108 | print("CONTINUE?", "yes" if args.resume else "no") 109 | print("* Config file:", args.conf) 110 | print("* Dataset format:", args.dataset_format) 111 | print("* Dataset location:", args.datadir) 112 | return args, conf 113 | -------------------------------------------------------------------------------- /src/util/recon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mesh reconstruction tools 3 | """ 4 | import mcubes 5 | import torch 6 | import numpy as np 7 | import util 8 | import tqdm 9 | import warnings 10 | 11 | 12 | def marching_cubes( 13 | occu_net, 14 | c1=[-1, -1, -1], 15 | c2=[1, 1, 1], 16 | reso=[128, 128, 128], 17 | isosurface=50.0, 18 | sigma_idx=3, 19 | eval_batch_size=100000, 20 | coarse=True, 21 | device=None, 22 | ): 23 | """ 24 | Run marching cubes on network. Uses PyMCubes. 25 | WARNING: does not make much sense with viewdirs in current form, since 26 | sigma depends on viewdirs. 27 | :param occu_net main NeRF type network 28 | :param c1 corner 1 of marching cube bounds x,y,z 29 | :param c2 corner 2 of marching cube bounds x,y,z (all > c1) 30 | :param reso resolutions of marching cubes x,y,z 31 | :param isosurface sigma-isosurface of marching cubes 32 | :param sigma_idx index of 'sigma' value in last dimension of occu_net's output 33 | :param eval_batch_size batch size for evaluation 34 | :param coarse whether to use coarse NeRF for evaluation 35 | :param device optionally, device to put points for evaluation. 36 | By default uses device of occu_net's first parameter. 37 | """ 38 | if occu_net.use_viewdirs: 39 | warnings.warn( 40 | "Running marching cubes with fake view dirs (pointing to origin), output may be invalid" 41 | ) 42 | with torch.no_grad(): 43 | grid = util.gen_grid(*zip(c1, c2, reso), ij_indexing=True) 44 | is_train = occu_net.training 45 | 46 | print("Evaluating sigma @", grid.size(0), "points") 47 | occu_net.eval() 48 | 49 | all_sigmas = [] 50 | if device is None: 51 | device = next(occu_net.parameters()).device 52 | grid_spl = torch.split(grid, eval_batch_size, dim=0) 53 | if occu_net.use_viewdirs: 54 | fake_viewdirs = -grid / torch.norm(grid, dim=-1).unsqueeze(-1) 55 | vd_spl = torch.split(fake_viewdirs, eval_batch_size, dim=0) 56 | for pnts, vd in tqdm.tqdm(zip(grid_spl, vd_spl), total=len(grid_spl)): 57 | outputs = occu_net( 58 | pnts.to(device=device), coarse=coarse, viewdirs=vd.to(device=device) 59 | ) 60 | sigmas = outputs[..., sigma_idx] 61 | all_sigmas.append(sigmas.cpu()) 62 | else: 63 | for pnts in tqdm.tqdm(grid_spl): 64 | outputs = occu_net(pnts.to(device=device), coarse=coarse) 65 | sigmas = outputs[..., sigma_idx] 66 | all_sigmas.append(sigmas.cpu()) 67 | sigmas = torch.cat(all_sigmas, dim=0) 68 | sigmas = sigmas.view(*reso).cpu().numpy() 69 | 70 | print("Running marching cubes") 71 | vertices, triangles = mcubes.marching_cubes(sigmas, isosurface) 72 | # Scale 73 | c1, c2 = np.array(c1), np.array(c2) 74 | vertices *= (c2 - c1) / np.array(reso) 75 | 76 | if is_train: 77 | occu_net.train() 78 | return vertices + c1, triangles 79 | 80 | 81 | def save_obj(vertices, triangles, path, vert_rgb=None): 82 | """ 83 | Save OBJ file, optionally with vertex colors. 
84 | This version is faster than PyMCubes and supports color. 85 | Taken from PIFu. 86 | :param vertices (N, 3) 87 | :param triangles (N, 3) 88 | :param vert_rgb (N, 3) rgb 89 | """ 90 | file = open(path, "w") 91 | if vert_rgb is None: 92 | # No color 93 | for v in vertices: 94 | file.write("v %.4f %.4f %.4f\n" % (v[0], v[1], v[2])) 95 | else: 96 | # Color 97 | for idx, v in enumerate(vertices): 98 | c = vert_rgb[idx] 99 | file.write( 100 | "v %.4f %.4f %.4f %.4f %.4f %.4f\n" 101 | % (v[0], v[1], v[2], c[0], c[1], c[2]) 102 | ) 103 | for f in triangles: 104 | f_plus = f + 1 105 | file.write("f %d %d %d\n" % (f_plus[0], f_plus[1], f_plus[2])) 106 | file.close() 107 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | # Training to a set of multiple objects (e.g. ShapeNet or DTU) 2 | # tensorboard logs available in logs/ 3 | 4 | import sys 5 | import os 6 | 7 | sys.path.insert( 8 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")) 9 | ) 10 | 11 | import warnings 12 | import trainlib 13 | from model import make_model, loss 14 | from render import NeRFRenderer 15 | from data import get_split_dataset 16 | import util 17 | import numpy as np 18 | import torch.nn.functional as F 19 | import torch 20 | from dotmap import DotMap 21 | 22 | 23 | def extra_args(parser): 24 | parser.add_argument( 25 | "--batch_size", "-B", type=int, default=4, help="Object batch size ('SB')" 26 | ) 27 | parser.add_argument( 28 | "--nviews", 29 | "-V", 30 | type=str, 31 | default="1", 32 | help="Number of source views (multiview); put multiple (space delim) to pick randomly per batch ('NV')", 33 | ) 34 | parser.add_argument( 35 | "--freeze_enc", 36 | action="store_true", 37 | default=None, 38 | help="Freeze encoder weights and only train MLP", 39 | ) 40 | 41 | parser.add_argument( 42 | "--no_bbox_step", 43 | type=int, 44 | default=100000, 45 | help="Step to stop using bbox sampling", 46 | ) 47 | parser.add_argument( 48 | "--fixed_test", 49 | action="store_true", 50 | default=None, 51 | help="Use a fixed batch from the test set for visualization", 52 | ) 53 | return parser 54 | 55 | 56 | args, conf = util.args.parse_args(extra_args, training=True, default_ray_batch_size=128) 57 | device = util.get_cuda(args.gpu_id[0]) 58 | 59 | dset, val_dset, _ = get_split_dataset(args.dataset_format, args.datadir) 60 | print( 61 | "dset z_near {}, z_far {}, lindisp {}".format(dset.z_near, dset.z_far, dset.lindisp) 62 | ) 63 | 64 | net = make_model(conf["model"]).to(device=device) 65 | net.stop_encoder_grad = args.freeze_enc 66 | if args.freeze_enc: 67 | print("Encoder frozen") 68 | net.encoder.eval() 69 | 70 | renderer = NeRFRenderer.from_conf(conf["renderer"], lindisp=dset.lindisp,).to( 71 | device=device 72 | ) 73 | 74 | # Parallelize renderer across GPUs 75 | render_par = renderer.bind_parallel(net, args.gpu_id).eval() 76 | 77 | nviews = list(map(int, args.nviews.split())) 78 | 79 | 80 | class PixelNeRFTrainer(trainlib.Trainer): 81 | def __init__(self): 82 | super().__init__(net, dset, val_dset, args, conf["train"], device=device) 83 | self.renderer_state_path = "%s/%s/_renderer" % ( 84 | self.args.checkpoints_path, 85 | self.args.name, 86 | ) 87 | 88 | self.lambda_coarse = conf.get_float("loss.lambda_coarse") 89 | self.lambda_fine = conf.get_float("loss.lambda_fine", 1.0) 90 | print( 91 | "lambda coarse {} and fine {}".format(self.lambda_coarse, self.lambda_fine) 92 | ) 93 | self.rgb_coarse_crit =
loss.get_rgb_loss(conf["loss.rgb"], True) 94 | fine_loss_conf = conf["loss.rgb"] 95 | if "rgb_fine" in conf["loss"]: 96 | print("using fine loss") 97 | fine_loss_conf = conf["loss.rgb_fine"] 98 | self.rgb_fine_crit = loss.get_rgb_loss(fine_loss_conf, False) 99 | 100 | if args.resume: 101 | if os.path.exists(self.renderer_state_path): 102 | renderer.load_state_dict( 103 | torch.load(self.renderer_state_path, map_location=device) 104 | ) 105 | 106 | self.z_near = dset.z_near 107 | self.z_far = dset.z_far 108 | 109 | self.use_bbox = args.no_bbox_step > 0 110 | 111 | def post_batch(self, epoch, batch): 112 | renderer.sched_step(args.batch_size) 113 | 114 | def extra_save_state(self): 115 | torch.save(renderer.state_dict(), self.renderer_state_path) 116 | 117 | def calc_losses(self, data, is_train=True, global_step=0): 118 | if "images" not in data: 119 | return {} 120 | all_images = data["images"].to(device=device) # (SB, NV, 3, H, W) 121 | 122 | SB, NV, _, H, W = all_images.shape 123 | all_poses = data["poses"].to(device=device) # (SB, NV, 4, 4) 124 | all_bboxes = data.get("bbox") # (SB, NV, 4) cmin rmin cmax rmax 125 | all_focals = data["focal"] # (SB) 126 | all_c = data.get("c") # (SB) 127 | 128 | if self.use_bbox and global_step >= args.no_bbox_step: 129 | self.use_bbox = False 130 | print(">>> Stopped using bbox sampling @ iter", global_step) 131 | 132 | if not is_train or not self.use_bbox: 133 | all_bboxes = None 134 | 135 | all_rgb_gt = [] 136 | all_rays = [] 137 | 138 | curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()] 139 | if curr_nviews == 1: 140 | image_ord = torch.randint(0, NV, (SB, 1)) 141 | else: 142 | image_ord = torch.empty((SB, curr_nviews), dtype=torch.long) 143 | for obj_idx in range(SB): 144 | if all_bboxes is not None: 145 | bboxes = all_bboxes[obj_idx] 146 | images = all_images[obj_idx] # (NV, 3, H, W) 147 | poses = all_poses[obj_idx] # (NV, 4, 4) 148 | focal = all_focals[obj_idx] 149 | c = None 150 | if "c" in data: 151 | c = data["c"][obj_idx] 152 | if curr_nviews > 1: 153 | # Somewhat inefficient, don't know better way 154 | image_ord[obj_idx] = torch.from_numpy( 155 | np.random.choice(NV, curr_nviews, replace=False) 156 | ) 157 | images_0to1 = images * 0.5 + 0.5 158 | 159 | cam_rays = util.gen_rays( 160 | poses, W, H, focal, self.z_near, self.z_far, c=c 161 | ) # (NV, H, W, 8) 162 | rgb_gt_all = images_0to1 163 | rgb_gt_all = ( 164 | rgb_gt_all.permute(0, 2, 3, 1).contiguous().reshape(-1, 3) 165 | ) # (NV, H, W, 3) 166 | 167 | if all_bboxes is not None: 168 | pix = util.bbox_sample(bboxes, args.ray_batch_size) 169 | pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2] 170 | else: 171 | pix_inds = torch.randint(0, NV * H * W, (args.ray_batch_size,)) 172 | 173 | rgb_gt = rgb_gt_all[pix_inds] # (ray_batch_size, 3) 174 | rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to( 175 | device=device 176 | ) # (ray_batch_size, 8) 177 | 178 | all_rgb_gt.append(rgb_gt) 179 | all_rays.append(rays) 180 | 181 | all_rgb_gt = torch.stack(all_rgb_gt) # (SB, ray_batch_size, 3) 182 | all_rays = torch.stack(all_rays) # (SB, ray_batch_size, 8) 183 | 184 | image_ord = image_ord.to(device) 185 | src_images = util.batched_index_select_nd( 186 | all_images, image_ord 187 | ) # (SB, NS, 3, H, W) 188 | src_poses = util.batched_index_select_nd(all_poses, image_ord) # (SB, NS, 4, 4) 189 | 190 | all_bboxes = all_poses = all_images = None 191 | 192 | net.encode( 193 | src_images, 194 | src_poses, 195 | all_focals.to(device=device), 196 | c=all_c.to(device=device) if 
all_c is not None else None, 197 | ) 198 | 199 | render_dict = DotMap(render_par(all_rays, want_weights=True,)) 200 | coarse = render_dict.coarse 201 | fine = render_dict.fine 202 | using_fine = len(fine) > 0 203 | 204 | loss_dict = {} 205 | 206 | rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt) 207 | loss_dict["rc"] = rgb_loss.item() * self.lambda_coarse 208 | if using_fine: 209 | fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt) 210 | rgb_loss = rgb_loss * self.lambda_coarse + fine_loss * self.lambda_fine 211 | loss_dict["rf"] = fine_loss.item() * self.lambda_fine 212 | 213 | loss = rgb_loss 214 | if is_train: 215 | loss.backward() 216 | loss_dict["t"] = loss.item() 217 | 218 | return loss_dict 219 | 220 | def train_step(self, data, global_step): 221 | return self.calc_losses(data, is_train=True, global_step=global_step) 222 | 223 | def eval_step(self, data, global_step): 224 | renderer.eval() 225 | losses = self.calc_losses(data, is_train=False, global_step=global_step) 226 | renderer.train() 227 | return losses 228 | 229 | def vis_step(self, data, global_step, idx=None): 230 | if "images" not in data: 231 | return {} 232 | if idx is None: 233 | batch_idx = np.random.randint(0, data["images"].shape[0]) 234 | else: 235 | print(idx) 236 | batch_idx = idx 237 | images = data["images"][batch_idx].to(device=device) # (NV, 3, H, W) 238 | poses = data["poses"][batch_idx].to(device=device) # (NV, 4, 4) 239 | focal = data["focal"][batch_idx : batch_idx + 1] # (1) 240 | c = data.get("c") 241 | if c is not None: 242 | c = c[batch_idx : batch_idx + 1] # (1) 243 | NV, _, H, W = images.shape 244 | cam_rays = util.gen_rays( 245 | poses, W, H, focal, self.z_near, self.z_far, c=c 246 | ) # (NV, H, W, 8) 247 | images_0to1 = images * 0.5 + 0.5 # (NV, 3, H, W) 248 | 249 | curr_nviews = nviews[torch.randint(0, len(nviews), (1,)).item()] 250 | views_src = np.sort(np.random.choice(NV, curr_nviews, replace=False)) 251 | view_dest = np.random.randint(0, NV - curr_nviews) 252 | for vs in range(curr_nviews): 253 | view_dest += view_dest >= views_src[vs] 254 | views_src = torch.from_numpy(views_src) 255 | 256 | # set renderer net to eval mode 257 | renderer.eval() 258 | source_views = ( 259 | images_0to1[views_src] 260 | .permute(0, 2, 3, 1) 261 | .cpu() 262 | .numpy() 263 | .reshape(-1, H, W, 3) 264 | ) 265 | 266 | gt = images_0to1[view_dest].permute(1, 2, 0).cpu().numpy().reshape(H, W, 3) 267 | with torch.no_grad(): 268 | test_rays = cam_rays[view_dest] # (H, W, 8) 269 | test_images = images[views_src] # (NS, 3, H, W) 270 | net.encode( 271 | test_images.unsqueeze(0), 272 | poses[views_src].unsqueeze(0), 273 | focal.to(device=device), 274 | c=c.to(device=device) if c is not None else None, 275 | ) 276 | test_rays = test_rays.reshape(1, H * W, -1) 277 | render_dict = DotMap(render_par(test_rays, want_weights=True)) 278 | coarse = render_dict.coarse 279 | fine = render_dict.fine 280 | 281 | using_fine = len(fine) > 0 282 | 283 | alpha_coarse_np = coarse.weights[0].sum(dim=-1).cpu().numpy().reshape(H, W) 284 | rgb_coarse_np = coarse.rgb[0].cpu().numpy().reshape(H, W, 3) 285 | depth_coarse_np = coarse.depth[0].cpu().numpy().reshape(H, W) 286 | 287 | if using_fine: 288 | alpha_fine_np = fine.weights[0].sum(dim=1).cpu().numpy().reshape(H, W) 289 | depth_fine_np = fine.depth[0].cpu().numpy().reshape(H, W) 290 | rgb_fine_np = fine.rgb[0].cpu().numpy().reshape(H, W, 3) 291 | 292 | print("c rgb min {} max {}".format(rgb_coarse_np.min(), rgb_coarse_np.max())) 293 | print( 294 | "c alpha min {}, max {}".format( 295 
| alpha_coarse_np.min(), alpha_coarse_np.max() 296 | ) 297 | ) 298 | alpha_coarse_cmap = util.cmap(alpha_coarse_np) / 255 299 | depth_coarse_cmap = util.cmap(depth_coarse_np) / 255 300 | vis_list = [ 301 | *source_views, 302 | gt, 303 | depth_coarse_cmap, 304 | rgb_coarse_np, 305 | alpha_coarse_cmap, 306 | ] 307 | 308 | vis_coarse = np.hstack(vis_list) 309 | vis = vis_coarse 310 | 311 | if using_fine: 312 | print("f rgb min {} max {}".format(rgb_fine_np.min(), rgb_fine_np.max())) 313 | print( 314 | "f alpha min {}, max {}".format( 315 | alpha_fine_np.min(), alpha_fine_np.max() 316 | ) 317 | ) 318 | depth_fine_cmap = util.cmap(depth_fine_np) / 255 319 | alpha_fine_cmap = util.cmap(alpha_fine_np) / 255 320 | vis_list = [ 321 | *source_views, 322 | gt, 323 | depth_fine_cmap, 324 | rgb_fine_np, 325 | alpha_fine_cmap, 326 | ] 327 | 328 | vis_fine = np.hstack(vis_list) 329 | vis = np.vstack((vis_coarse, vis_fine)) 330 | rgb_psnr = rgb_fine_np 331 | else: 332 | rgb_psnr = rgb_coarse_np 333 | 334 | psnr = util.psnr(rgb_psnr, gt) 335 | vals = {"psnr": psnr} 336 | print("psnr", psnr) 337 | 338 | # set the renderer network back to train mode 339 | renderer.train() 340 | return vis, vals 341 | 342 | 343 | trainer = PixelNeRFTrainer() 344 | trainer.start() 345 | -------------------------------------------------------------------------------- /train/trainlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /train/trainlib/trainer.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import numpy as np 4 | from torch.utils.tensorboard import SummaryWriter 5 | import tqdm 6 | import warnings 7 | 8 | 9 | class Trainer: 10 | def __init__(self, net, train_dataset, test_dataset, args, conf, device=None): 11 | self.args = args 12 | self.net = net 13 | self.train_dataset = train_dataset 14 | self.test_dataset = test_dataset 15 | 16 | self.train_data_loader = torch.utils.data.DataLoader( 17 | train_dataset, 18 | batch_size=args.batch_size, 19 | shuffle=True, 20 | num_workers=8, 21 | pin_memory=False, 22 | ) 23 | self.test_data_loader = torch.utils.data.DataLoader( 24 | test_dataset, 25 | batch_size=min(args.batch_size, 16), 26 | shuffle=True, 27 | num_workers=4, 28 | pin_memory=False, 29 | ) 30 | 31 | self.num_total_batches = len(self.train_dataset) 32 | self.exp_name = args.name 33 | self.save_interval = conf.get_int("save_interval") 34 | self.print_interval = conf.get_int("print_interval") 35 | self.vis_interval = conf.get_int("vis_interval") 36 | self.eval_interval = conf.get_int("eval_interval") 37 | self.num_epoch_repeats = conf.get_int("num_epoch_repeats", 1) 38 | self.num_epochs = args.epochs 39 | self.accu_grad = conf.get_int("accu_grad", 1) 40 | self.summary_path = os.path.join(args.logs_path, args.name) 41 | self.writer = SummaryWriter(self.summary_path) 42 | 43 | self.fixed_test = hasattr(args, "fixed_test") and args.fixed_test 44 | 45 | os.makedirs(self.summary_path, exist_ok=True) 46 | 47 | # Currently only Adam supported 48 | self.optim = torch.optim.Adam(net.parameters(), lr=args.lr) 49 | if args.gamma != 1.0: 50 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( 51 | optimizer=self.optim, gamma=args.gamma 52 | ) 53 | else: 54 | self.lr_scheduler = None 55 | 56 | # Load weights 57 | self.managed_weight_saving = hasattr(net, "load_weights") 58 | if 
self.managed_weight_saving: 59 | net.load_weights(self.args) 60 | self.iter_state_path = "%s/%s/_iter" % ( 61 | self.args.checkpoints_path, 62 | self.args.name, 63 | ) 64 | self.optim_state_path = "%s/%s/_optim" % ( 65 | self.args.checkpoints_path, 66 | self.args.name, 67 | ) 68 | self.lrsched_state_path = "%s/%s/_lrsched" % ( 69 | self.args.checkpoints_path, 70 | self.args.name, 71 | ) 72 | self.default_net_state_path = "%s/%s/net" % ( 73 | self.args.checkpoints_path, 74 | self.args.name, 75 | ) 76 | self.start_iter_id = 0 77 | if args.resume: 78 | if os.path.exists(self.optim_state_path): 79 | try: 80 | self.optim.load_state_dict( 81 | torch.load(self.optim_state_path, map_location=device) 82 | ) 83 | except Exception: 84 | warnings.warn( 85 | "Failed to load optimizer state at " + self.optim_state_path 86 | ) 87 | if self.lr_scheduler is not None and os.path.exists( 88 | self.lrsched_state_path 89 | ): 90 | self.lr_scheduler.load_state_dict( 91 | torch.load(self.lrsched_state_path, map_location=device) 92 | ) 93 | if os.path.exists(self.iter_state_path): 94 | self.start_iter_id = torch.load( 95 | self.iter_state_path, map_location=device 96 | )["iter"] 97 | if not self.managed_weight_saving and os.path.exists( 98 | self.default_net_state_path 99 | ): 100 | net.load_state_dict( 101 | torch.load(self.default_net_state_path, map_location=device) 102 | ) 103 | 104 | self.visual_path = os.path.join(self.args.visual_path, self.args.name) 105 | self.conf = conf 106 | 107 | def post_batch(self, epoch, batch): 108 | """ 109 | Run after each batch 110 | """ 111 | pass 112 | 113 | def extra_save_state(self): 114 | """ 115 | Run at each save step for saving extra state 116 | """ 117 | pass 118 | 119 | def train_step(self, data, global_step): 120 | """ 121 | Training step 122 | """ 123 | raise NotImplementedError() 124 | 125 | def eval_step(self, data, global_step): 126 | """ 127 | Evaluation step 128 | """ 129 | raise NotImplementedError() 130 | 131 | def vis_step(self, data, global_step): 132 | """ 133 | Visualization step 134 | """ 135 | return None, None 136 | 137 | def start(self): 138 | def fmt_loss_str(losses): 139 | return "loss " + (" ".join(k + ":" + str(losses[k]) for k in losses)) 140 | 141 | def data_loop(dl): 142 | """ 143 | Loop an iterable infinitely 144 | """ 145 | while True: 146 | for x in iter(dl): 147 | yield x 148 | 149 | test_data_iter = data_loop(self.test_data_loader) 150 | 151 | step_id = self.start_iter_id 152 | 153 | progress = tqdm.tqdm(bar_format="[{rate_fmt}] ") 154 | for epoch in range(self.num_epochs): 155 | self.writer.add_scalar( 156 | "lr", self.optim.param_groups[0]["lr"], global_step=step_id 157 | ) 158 | 159 | batch = 0 160 | for _ in range(self.num_epoch_repeats): 161 | for data in self.train_data_loader: 162 | losses = self.train_step(data, global_step=step_id) 163 | loss_str = fmt_loss_str(losses) 164 | if batch % self.print_interval == 0: 165 | print( 166 | "E", 167 | epoch, 168 | "B", 169 | batch, 170 | loss_str, 171 | " lr", 172 | self.optim.param_groups[0]["lr"], 173 | ) 174 | 175 | if batch % self.eval_interval == 0: 176 | test_data = next(test_data_iter) 177 | self.net.eval() 178 | with torch.no_grad(): 179 | test_losses = self.eval_step(test_data, global_step=step_id) 180 | self.net.train() 181 | test_loss_str = fmt_loss_str(test_losses) 182 | self.writer.add_scalars("train", losses, global_step=step_id) 183 | self.writer.add_scalars( 184 | "test", test_losses, global_step=step_id 185 | ) 186 | print("*** Eval:", "E", epoch, "B", batch, test_loss_str, " lr") 187
| 188 | if batch % self.save_interval == 0 and (epoch > 0 or batch > 0): 189 | print("saving") 190 | if self.managed_weight_saving: 191 | self.net.save_weights(self.args) 192 | else: 193 | torch.save( 194 | self.net.state_dict(), self.default_net_state_path 195 | ) 196 | torch.save(self.optim.state_dict(), self.optim_state_path) 197 | if self.lr_scheduler is not None: 198 | torch.save( 199 | self.lr_scheduler.state_dict(), self.lrsched_state_path 200 | ) 201 | torch.save({"iter": step_id + 1}, self.iter_state_path) 202 | self.extra_save_state() 203 | 204 | if batch % self.vis_interval == 0: 205 | print("generating visualization") 206 | if self.fixed_test: 207 | test_data = next(iter(self.test_data_loader)) 208 | else: 209 | test_data = next(test_data_iter) 210 | self.net.eval() 211 | with torch.no_grad(): 212 | vis, vis_vals = self.vis_step( 213 | test_data, global_step=step_id 214 | ) 215 | if vis_vals is not None: 216 | self.writer.add_scalars( 217 | "vis", vis_vals, global_step=step_id 218 | ) 219 | self.net.train() 220 | if vis is not None: 221 | import imageio 222 | 223 | vis_u8 = (vis * 255).astype(np.uint8) 224 | imageio.imwrite( 225 | os.path.join( 226 | self.visual_path, 227 | "{:04}_{:04}_vis.png".format(epoch, batch), 228 | ), 229 | vis_u8, 230 | ) 231 | 232 | if ( 233 | batch == self.num_total_batches - 1 234 | or batch % self.accu_grad == self.accu_grad - 1 235 | ): 236 | self.optim.step() 237 | self.optim.zero_grad() 238 | 239 | self.post_batch(epoch, batch) 240 | step_id += 1 241 | batch += 1 242 | progress.update(1) 243 | if self.lr_scheduler is not None: 244 | self.lr_scheduler.step() 245 | -------------------------------------------------------------------------------- /viewlist/2obj_eval_views.txt: -------------------------------------------------------------------------------- 1 | 0 2 4 6 8 10 12 14 16 18 2 | -------------------------------------------------------------------------------- /viewlist/srn_eval_views.txt: -------------------------------------------------------------------------------- 1 | 0 24 49 74 99 124 149 174 199 224 249 2 | --------------------------------------------------------------------------------
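Note on the multiview feature reduction used above: both ResnetFC.forward and the MLP's forward call util.combine_interleaved at combine_layer, but src/util/util.py itself is not reproduced in this listing. The sketch below is a minimal stand-in inferred from the docstrings ("reshaped to (-1, combine_inner_dims, ...) and reduced using combine_type on dim 1") and the average | max options in from_conf; the body is an assumption for illustration, not necessarily the repository's exact implementation.

import torch

def combine_interleaved(t, inner_dims=(1,), agg_type="average"):
    # Nothing to combine when there is a single source view.
    if len(inner_dims) == 1 and inner_dims[0] == 1:
        return t
    # Fold the interleaved view axis back out of the flat batch, e.g.
    # (SB * NV * B, d) with inner_dims=(NV, B) -> (SB, NV, B, d),
    # then reduce across the view axis (dim 1).
    t = t.reshape(-1, *inner_dims, *t.shape[1:])
    if agg_type == "average":
        t = torch.mean(t, dim=1)
    elif agg_type == "max":
        t = torch.max(t, dim=1)[0]
    else:
        raise NotImplementedError("Unsupported combine type " + agg_type)
    return t

Under this sketch, with two source views per object, features of shape (SB * 2 * B, d_hidden) entering combine_layer with combine_inner_dims=(2, B) leave as (SB, B, d_hidden), so the remaining ResNet blocks run once per object rather than once per source view.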