├── README.md ├── config ├── __init__.py └── defaults.py ├── configs ├── config_taekwondo.yml └── config_walking.yml ├── data ├── __init__.py ├── build.py ├── collate_batch.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── frame_dataset.cpython-36.pyc │ │ ├── frame_dataset.cpython-38.pyc │ │ ├── ibr_dynamic.cpython-36.pyc │ │ ├── ibr_dynamic.cpython-38.pyc │ │ ├── ray_dataset.cpython-36.pyc │ │ ├── ray_dataset.cpython-38.pyc │ │ ├── ray_source.cpython-36.pyc │ │ ├── ray_source.cpython-38.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-38.pyc │ ├── frame_dataset.py │ ├── ray_dataset.py │ └── utils.py └── transforms │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-38.pyc │ ├── random_transforms.cpython-36.pyc │ └── random_transforms.cpython-38.pyc │ ├── build.py │ └── random_transforms.py ├── demo ├── taekwondo_demo.py └── walking_demo.py ├── engine ├── __init__.py ├── layered_trainer.py └── render.py ├── images └── teaser.jpg ├── layers ├── RaySamplePoint-1.py ├── RaySamplePoint.py ├── __init__.py ├── __pycache__ │ ├── RaySamplePoint.cpython-36.pyc │ ├── RaySamplePoint.cpython-38.pyc │ ├── RaySamplePoint1.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── camera_transform.cpython-38.pyc │ ├── loss.cpython-36.pyc │ ├── loss.cpython-38.pyc │ ├── render_layer.cpython-36.pyc │ └── render_layer.cpython-38.pyc ├── camera_transform.py ├── loss.py └── render_layer.py ├── modeling ├── __init__.py ├── layered_rfrender.py ├── motion_net.py └── spacenet.py ├── outputs ├── taekwondo │ └── layered_rfnr_checkpoint_1.pt └── walking │ └── layered_rfnr_checkpoint_1.pt ├── render ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── bkgd_renderer.cpython-38.pyc │ ├── layered_neural_renderer.cpython-38.pyc │ ├── neural_renderer.cpython-38.pyc │ └── 
render_functions.cpython-38.pyc ├── bkgd_renderer.py ├── layered_neural_renderer.py ├── neural_renderer.py └── render_functions.py ├── solver ├── __init__.py ├── build.py └── lr_scheduler.py └── utils ├── __init__.py ├── batchify_rays.py ├── dimension_kernel.py ├── high_dim_dics.py ├── logger.py ├── metrics.py ├── ray_sampling.py ├── render_helpers.py ├── sample_pdf.py └── vis_density.py /README.md: -------------------------------------------------------------------------------- 1 | # st-nerf 2 | 3 | We provide PyTorch implementations for our paper: 4 | [Editable Free-viewpoint Video Using a Layered Neural Representation](https://arxiv.org/abs/2104.14786) 5 | 6 | SIGGRAPH 2021 7 | 8 | Jiakai Zhang, Xinhang Liu, Xinyi Ye, Fuqiang Zhao, Yanshun Zhang, Minye Wu, Yingliang Zhang, Lan Xu and Jingyi Yu 9 | 10 | 11 | 12 | 13 | **st-nerf: [Project](https://jiakai-zhang.github.io/st-nerf/) | [Paper](https://arxiv.org/abs/2104.14786)** 14 | 15 | ## Getting Started 16 | ### Installation 17 | 18 | - Clone this repo: 19 | ```bash 20 | git clone https://github.com/DarlingHang/st-nerf 21 | cd st-nerf 22 | ``` 23 | 24 | - Install [PyTorch](http://pytorch.org) and other dependencies using: 25 | ``` 26 | conda create -n st-nerf python=3.8.5 27 | conda activate st-nerf 28 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch 29 | conda install imageio matplotlib 30 | pip install yacs kornia robpy open3d 31 | ``` 32 | 33 | 34 | ### Datasets 35 | The walking and taekwondo datasets can be downloaded from [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/xliufe_connect_ust_hk/EjqArjZxmmtDplj_IrwlUq0BMUyG69zr5YqXFBxgku4rRQ?e=n2fSBs). 36 | 37 | ### Apply a pre-trained model to render demo videos 38 | - We provide our pretrained models which can be found under the `outputs` folder. 39 | - We provide some example scripts under the `demo` folder. 
40 | - To run our demo scripts, you first need to download the corresponding dataset and put it under the folder specified by `DATASETS` -> `TRAIN` in `configs/config_taekwondo.yml` and `configs/config_walking.yml`. 41 | - For the walking sequence, you can render videos in which some performers are hidden by running the command: 42 | ``` 43 | python demo/walking_demo.py -c configs/config_walking.yml 44 | ``` 45 | - For the taekwondo sequence, you can render videos in which performers are translated and scaled by running the command: 46 | ``` 47 | python demo/taekwondo_demo.py -c configs/config_taekwondo.yml 48 | ``` 49 | - The rendered images and videos will be saved under `outputs/taekwondo/rendered` and `outputs/walking/rendered`. 50 | 51 | ## Acknowledgements 52 | We borrowed some code from [Multi-view Neural Human Rendering (NHR)](https://github.com/wuminye/NHR). 53 | 54 | ## Citation 55 | If you use this code for your research, please cite our paper. 56 | ``` 57 | @article{zhang2021editable, 58 | title={Editable free-viewpoint video using a layered neural representation}, 59 | author={Zhang, Jiakai and Liu, Xinhang and Ye, Xinyi and Zhao, Fuqiang and Zhang, Yanshun and Wu, Minye and Zhang, Yingliang and Xu, Lan and Yu, Jingyi}, 60 | journal={ACM Transactions on Graphics (TOG)}, 61 | volume={40}, 62 | number={4}, 63 | pages={1--18}, 64 | year={2021}, 65 | publisher={ACM New York, NY, USA} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # 
----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | # or _TEST for a test-specific parameter. 9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.deep_rgb = True  # NOTE(review): lowercase top-level key next to MODEL.DEEP_RGB below; confirm whether anything still reads it 20 | 21 | _C.MODEL = CN() 22 | _C.MODEL.DEVICE = "cuda" 23 | _C.MODEL.COARSE_RAY_SAMPLING = 64 24 | _C.MODEL.FINE_RAY_SAMPLING = 80 25 | _C.MODEL.SAMPLE_METHOD = "NEAR_FAR"  # "NEAR_FAR" or "BBOX" (the configs/*.yml files use "BBOX") 26 | _C.MODEL.BOARDER_WEIGHT = 1e10 27 | _C.MODEL.SAME_SPACENET = False 28 | 29 | _C.MODEL.TKERNEL_INC_RAW = True 30 | _C.MODEL.POSE_REFINEMENT = True 31 | _C.MODEL.USE_DIR = True 32 | _C.MODEL.REMOVE_OUTLIERS = False 33 | _C.MODEL.TRAIN_BY_POINTCLOUD = False 34 | _C.MODEL.USE_DEFORM_VIEW = False # Use deformnet to deform view inconsistency 35 | _C.MODEL.USE_DEFORM_TIME = False # Use deformnet to deform time inconsistency 36 | _C.MODEL.BKGD_USE_DEFORM_TIME = False 37 | _C.MODEL.BKGD_USE_SPACE_TIME = False 38 | _C.MODEL.USE_SPACE_TIME = False 39 | _C.MODEL.DEEP_RGB = True 40 | 41 | 42 | 43 | 44 | # ----------------------------------------------------------------------------- 45 | # INPUT 46 | # ----------------------------------------------------------------------------- 47 | _C.INPUT = CN() 48 | # Size of the image during training 49 | _C.INPUT.SIZE_TRAIN = [400,250] 50 | # Size of the image during test 51 | _C.INPUT.SIZE_TEST = [400,250] 52 | # Size of the image 
during sample layer 53 | _C.INPUT.SIZE_LAYER = [400,250] 54 | # Minimum scale for the image during training 55 | _C.INPUT.MIN_SCALE_TRAIN = 0.5 56 | # Maximum scale for the image during training 57 | _C.INPUT.MAX_SCALE_TRAIN = 1.2 58 | # Random probability for image horizontal flip 59 | _C.INPUT.PROB = 0.5 60 | # Values to be used for image normalization 61 | _C.INPUT.PIXEL_MEAN = [0.1307, ] 62 | # Values to be used for image normalization 63 | _C.INPUT.PIXEL_STD = [0.3081, ] 64 | 65 | # ----------------------------------------------------------------------------- 66 | # Dataset 67 | # ----------------------------------------------------------------------------- 68 | _C.DATASETS = CN() 69 | # Root directory of the training dataset (read as a filesystem path by the dataset classes) 70 | _C.DATASETS.TRAIN = "" 71 | _C.DATASETS.TMP_RAYS = "rays_tmp" 72 | # List of the dataset names for testing, as present in paths_catalog.py 73 | _C.DATASETS.TEST = () 74 | _C.DATASETS.SHIFT = 0.0 75 | _C.DATASETS.MAXRATION = 0.0 76 | _C.DATASETS.ROTATION = 0.0 77 | _C.DATASETS.USE_MASK = False 78 | _C.DATASETS.NUM_FRAME = 1 79 | _C.DATASETS.FACTOR = 1 80 | _C.DATASETS.FIXED_NEAR = -1.0  # -1.0 for both FIXED_NEAR and FIXED_FAR: near/far are derived per camera from the layer point cloud 81 | _C.DATASETS.FIXED_FAR = -1.0 82 | 83 | _C.DATASETS.CENTER_X = 0.0 84 | _C.DATASETS.CENTER_Y = 0.0 85 | _C.DATASETS.CENTER_Z = 0.0 86 | _C.DATASETS.SCALE = 1.0  # multiplies point-cloud xyz and camera translations 87 | _C.DATASETS.FILE_OFFSET = 0 88 | _C.DATASETS.FRAME_OFFSET = 0  # frames FRAME_OFFSET+1 .. FRAME_OFFSET+FRAME_NUM are loaded 89 | _C.DATASETS.FRAME_NUM = 0 90 | _C.DATASETS.LAYER_NUM = 0 91 | _C.DATASETS.CAMERA_NUM = 0  # 0: use every camera in pose/RT_c2w.txt; non-zero also shifts camera ids by FILE_OFFSET 92 | _C.DATASETS.BKGD_SAMPLE_RATE = 0.1  # sample rate used for layer 0 (background) 93 | _C.DATASETS.CAMERA_STEPSIZE = 1 94 | 95 | _C.DATASETS.USE_LABEL = False 96 | _C.DATASETS.VIEW_MASK = None  # optional path to a per-camera mask file; None (or a missing file) keeps every view 97 | _C.DATASETS.FIXED_LAYER = []  # layer ids listed here get sample_rate 0 in Ray_Dataset 98 | 99 | # ----------------------------------------------------------------------------- 100 | # DataLoader 101 | # ----------------------------------------------------------------------------- 102 | _C.DATALOADER = CN() 103 | # Number of data loading threads 104 | _C.DATALOADER.NUM_WORKERS = 8 105 | 106 | # 
---------------------------------------------------------------------------- # 107 | # Solver 108 | # ---------------------------------------------------------------------------- # 109 | _C.SOLVER = CN() 110 | _C.SOLVER.OPTIMIZER_NAME = "SGD" 111 | 112 | _C.SOLVER.MAX_EPOCHS = 50 113 | 114 | _C.SOLVER.BASE_LR = 0.001 115 | _C.SOLVER.BIAS_LR_FACTOR = 2 116 | 117 | _C.SOLVER.MOMENTUM = 0.9 118 | 119 | _C.SOLVER.WEIGHT_DECAY = 0.0005 120 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 121 | 122 | _C.SOLVER.GAMMA = 0.1 123 | _C.SOLVER.STEPS = (30000,) 124 | 125 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 126 | _C.SOLVER.WARMUP_ITERS = 500 127 | _C.SOLVER.WARMUP_METHOD = "linear" 128 | 129 | _C.SOLVER.CHECKPOINT_PERIOD = 10 130 | _C.SOLVER.LOG_PERIOD = 100 131 | _C.SOLVER.BUNCH = 4096 132 | _C.SOLVER.START_ITERS=50 133 | _C.SOLVER.END_ITERS=200 134 | _C.SOLVER.LR_SCALE=0.1 135 | _C.SOLVER.COARSE_STAGE = 10 136 | 137 | # Number of images per batch 138 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 139 | # see 2 images per batch 140 | _C.SOLVER.IMS_PER_BATCH = 16 141 | 142 | _C.SOLVER.BBOX_ID = 0 143 | 144 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 145 | # see 2 images per batch 146 | _C.TEST = CN() 147 | _C.TEST.IMS_PER_BATCH = 8 148 | _C.TEST.WEIGHT = "" 149 | 150 | # ---------------------------------------------------------------------------- # 151 | # Misc options 152 | # ---------------------------------------------------------------------------- # 153 | _C.OUTPUT_DIR = "" 154 | -------------------------------------------------------------------------------- /configs/config_taekwondo.yml: -------------------------------------------------------------------------------- 1 | 2 | SOLVER: 3 | OPTIMIZER_NAME: "Adam" 4 | BASE_LR: 0.0004 5 | WEIGHT_DECAY: 0.0000000 6 | IMS_PER_BATCH: 2000 7 | START_ITERS: 3000 8 | END_ITERS: 60000 9 | LR_SCALE: 0.09 10 | WARMUP_ITERS: 1000 11 | 12 | MAX_EPOCHS: 100 13 | CHECKPOINT_PERIOD: 3000 14 
| LOG_PERIOD: 30 15 | BUNCH: 3000 16 | COARSE_STAGE: 1 17 | 18 | BBOX_ID: 0 19 | 20 | INPUT: 21 | # (4130,2202) / 4 = (1033,551) 22 | SIZE_TRAIN: [1920,1080] 23 | SIZE_LAYER: [1920,1080] 24 | SIZE_TEST: [1920,1080] 25 | 26 | DATASETS: 27 | TRAIN: 28 | TMP_RAYS: "rays_tmp_1920" 29 | NUM_FRAME: 1 30 | SHIFT: 0.0 31 | MAXRATION: 0.0 32 | ROTATION: 0.0 33 | FACTOR: 8 34 | FIXED_NEAR: -1.0 35 | FIXED_FAR: -1.0 36 | SCALE: 0.1 # scale in xyz position of world coordinate 37 | FILE_OFFSET: 0 38 | FRAME_OFFSET: 0 39 | BKGD_SAMPLE_RATE: 0.05 40 | 41 | USE_LABEL: True 42 | 43 | USE_MASK: False 44 | 45 | FRAME_NUM: 101 46 | LAYER_NUM: 2 47 | 48 | 49 | 50 | 51 | MODEL: 52 | COARSE_RAY_SAMPLING: 90 53 | FINE_RAY_SAMPLING: 30 54 | SAMPLE_METHOD: "BBOX" # "NEAR_FAR" "BBOX" 55 | BOARDER_WEIGHT: 1e10 56 | SAME_SPACENET: False 57 | TKERNEL_INC_RAW: True 58 | POSE_REFINEMENT: False # If doing the camera pose refinement 59 | USE_DIR: True 60 | REMOVE_OUTLIERS: True # Use masks to remove the density out of the mask 61 | USE_DEFORM_VIEW: False # Use deformnet to deform view inconsisdency 62 | USE_DEFORM_TIME: True # Use deformnet to deform time inconsisdency 63 | USE_SPACE_TIME: True 64 | BKGD_USE_DEFORM_TIME: False 65 | BKGD_USE_SPACE_TIME: False 66 | DEEP_RGB: False 67 | 68 | 69 | TEST: 70 | IMS_PER_BATCH: 1 71 | 72 | OUTPUT_DIR: "outputs/taekwondo" 73 | -------------------------------------------------------------------------------- /configs/config_walking.yml: -------------------------------------------------------------------------------- 1 | 2 | SOLVER: 3 | OPTIMIZER_NAME: "Adam" 4 | BASE_LR: 0.0004 5 | WEIGHT_DECAY: 0.0000000 6 | IMS_PER_BATCH: 2000 7 | START_ITERS: 3000 8 | END_ITERS: 60000 9 | LR_SCALE: 0.09 10 | WARMUP_ITERS: 1000 11 | 12 | MAX_EPOCHS: 100 13 | CHECKPOINT_PERIOD: 3000 14 | LOG_PERIOD: 30 15 | BUNCH: 3000 16 | COARSE_STAGE: 1 17 | 18 | INPUT: 19 | SIZE_TRAIN: [1920,1080] 20 | SIZE_LAYER: [1920,1080] 21 | SIZE_TEST: [1920,1080] 22 | 23 | DATASETS: 24 | TRAIN: 25 
| TMP_RAYS: "rays_tmp_1920_BBOX" 26 | NUM_FRAME: 1 27 | SHIFT: 0.0 28 | MAXRATION: 0.0 29 | ROTATION: 0.0 30 | FACTOR: 8 31 | FIXED_NEAR: -1.0 32 | FIXED_FAR: -1.0 33 | SCALE: 1.0 # scale in xyz position of world coordinate 34 | FILE_OFFSET: 0 35 | FRAME_OFFSET: 25 36 | BKGD_SAMPLE_RATE: 0.0 37 | 38 | USE_LABEL: False 39 | 40 | USE_MASK: False 41 | 42 | FRAME_NUM: 50 43 | LAYER_NUM: 2 44 | 45 | 46 | MODEL: 47 | COARSE_RAY_SAMPLING: 90 48 | FINE_RAY_SAMPLING: 30 49 | SAMPLE_METHOD: "BBOX" # "NEAR_FAR" "BBOX" 50 | BOARDER_WEIGHT: 1e10 51 | SAME_SPACENET: False 52 | TKERNEL_INC_RAW: True 53 | POSE_REFINEMENT: False # If doing the camera pose refinement 54 | USE_DIR: True 55 | REMOVE_OUTLIERS: False # Use masks to remove the density out of the mask 56 | USE_DEFORM_VIEW: False # Use deformnet to deform view inconsisdency 57 | USE_DEFORM_TIME: True # Use deformnet to deform time inconsisdency 58 | USE_SPACE_TIME: False 59 | BKGD_USE_DEFORM_TIME: False 60 | BKGD_USE_SPACE_TIME: False 61 | 62 | TEST: 63 | IMS_PER_BATCH: 1 64 | 65 | OUTPUT_DIR: "outputs/walking" 66 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_ray_data_loader, make_ray_data_loader_view, make_ray_data_loader_render 8 | from .datasets.utils import get_iteration_path, get_iteration_path_and_iter 9 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: Minye Wu 4 | @GITHUB: wuminye 5 | """ 6 | 7 | from torch.utils import data 8 | import numpy as np 9 | from .datasets.ray_dataset import Ray_Dataset, Ray_Dataset_View, Ray_Dataset_Render, Ray_Frame_Layer_Dataset 10 | from .transforms 
import build_transforms, build_layered_transforms 11 | 12 | 13 | def make_ray_data_loader(cfg, is_train=True):  # Build a shuffled DataLoader over Ray_Dataset (separate background/layer transforms); returns (data_loader, dataset) 14 | 15 | batch_size = cfg.SOLVER.IMS_PER_BATCH  # number of dataset items per batch (presumably rays, given Ray_Dataset) — confirm 16 | 17 | transforms_bkgd = build_layered_transforms(cfg, is_train=is_train, is_layer=False) 18 | transforms_layer = build_layered_transforms(cfg, is_train=is_train, is_layer=True) 19 | 20 | datasets = Ray_Dataset(cfg, transforms_bkgd, transforms_layer) 21 | 22 | num_workers = cfg.DATALOADER.NUM_WORKERS 23 | data_loader = data.DataLoader( 24 | datasets, batch_size=batch_size, shuffle=True, num_workers=num_workers 25 | ) 26 | 27 | return data_loader, datasets 28 | 29 | def make_ray_data_loader_view(cfg, is_train=False):  # Build a shuffled DataLoader over Ray_Dataset_View using the plain transforms; returns (data_loader, dataset) 30 | 31 | batch_size = cfg.SOLVER.IMS_PER_BATCH  # NOTE(review): uses SOLVER.IMS_PER_BATCH although cfg.TEST.IMS_PER_BATCH exists — confirm intended 32 | 33 | transforms = build_transforms(cfg, is_train) 34 | 35 | datasets = Ray_Dataset_View(cfg, transforms) 36 | 37 | num_workers = cfg.DATALOADER.NUM_WORKERS 38 | data_loader = data.DataLoader( 39 | datasets, batch_size=batch_size, shuffle=True, num_workers=num_workers 40 | ) 41 | 42 | return data_loader, datasets 43 | 44 | def make_ray_data_loader_render(cfg, is_train=False):  # Build an unshuffled DataLoader over Ray_Dataset_Render (dataset order preserved); returns (data_loader, dataset) 45 | 46 | batch_size = cfg.SOLVER.IMS_PER_BATCH  # NOTE(review): same SOLVER vs TEST batch-size question as the view loader 47 | 48 | 49 | transforms = build_transforms(cfg, is_train) 50 | 51 | datasets = Ray_Dataset_Render(cfg, transforms) 52 | 53 | num_workers = cfg.DATALOADER.NUM_WORKERS 54 | data_loader = data.DataLoader( 55 | datasets, batch_size=batch_size, shuffle=False, num_workers=num_workers 56 | ) 57 | 58 | return data_loader, datasets -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 
3 | @author: Minye Wu 4 | @GITHUB: wuminye 5 | """ 6 | 7 | # from .ibr_dynamic import IBRDynamicDataset -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/frame_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/frame_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/frame_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/frame_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ibr_dynamic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ibr_dynamic.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/ibr_dynamic.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ibr_dynamic.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ray_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ray_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ray_source.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_source.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ray_source.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_source.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/utils.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /data/datasets/frame_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from .utils import campose_to_extrinsic, read_intrinsics, read_mask 5 | from PIL import Image 6 | import torchvision 7 | import torch.distributions as tdist 8 | import open3d as o3d 9 | 10 | class FrameDataset(torch.utils.data.Dataset):  # Per-frame dataset over all layers; expects <dataset_path>/<frame_id>/{images,labels,pointclouds} (no 'frame' prefix, unlike FrameLayerDataset) 11 | 12 | def __init__(self, dataset_path, transform, frame_id, layer_num, file_offset):  # NOTE(review): never sets self.cam_num or self.pointcloud_path1, yet both are read later in this class 13 | 14 | super(FrameDataset, self).__init__() 15 | 16 | # 1. Set the dataset path for the next loading 17 | self.file_offset = file_offset 18 | self.frame_id = frame_id # The frame number 19 | self.layer_num = layer_num # The number of layers 20 | self.image_path = os.path.join(dataset_path,str(self.frame_id),'images') 21 | self.label_path = os.path.join(dataset_path,str(self.frame_id),'labels') 22 | self.pointcloud_path = os.path.join(dataset_path,str(self.frame_id),'pointclouds') 23 | self.pose_path = os.path.join(dataset_path, 'pose') 24 | self.transform = transform 25 | # 2. 
Loading Intrinsics & Camera poses 26 | camposes = np.loadtxt(os.path.join(self.pose_path,'RT_c2w.txt')) 27 | # Ts are camera poses 28 | self.Ts = torch.Tensor(campose_to_extrinsic(camposes)) 29 | self.Ks = torch.Tensor(read_intrinsics(os.path.join(self.pose_path,'K.txt'))) 30 | # 3. Load pointclouds for different layers 31 | self.pointclouds = [] # Finally (layer_num,) 32 | self.bboxs = [] # Finally (layer_num,) 33 | 34 | 35 | 36 | for i in range(layer_num): 37 | # Start from 1.ply to layer_num.ply 38 | pointcloud_name = os.path.join(self.pointcloud_path, '%d.ply' % (i+1)) 39 | 40 | if not os.path.exists(pointcloud_name): 41 | pointcloud_name = os.path.join(self.pointcloud_path1, '%d.ply' % (i+1))  # NOTE(review): self.pointcloud_path1 is never assigned in FrameDataset, so this fallback raises AttributeError whenever the first path is missing — TODO define it or raise FileNotFoundError 42 | 43 | if not os.path.exists(pointcloud_name): 44 | print('Cannot find corresponding pointcloud in path: ', pointcloud_name)  # NOTE(review): only warns; the read below still uses the missing path 45 | pointcloud = o3d.io.read_point_cloud(pointcloud_name) 46 | xyz = np.asarray(pointcloud.points) 47 | 48 | xyz = torch.Tensor(xyz) 49 | self.pointclouds.append(xyz) 50 | 51 | max_xyz = torch.max(xyz, dim=0)[0]  # presumably fails if the cloud came back empty (missing .ply) — verify 52 | min_xyz = torch.min(xyz, dim=0)[0] 53 | 54 | # Default scalar is 0.3 55 | tmp = (max_xyz - min_xyz) * 0.0  # padding factor is 0.0, so this is the tight AABB despite the 0.3 mentioned above 56 | 57 | max_xyz = max_xyz + tmp 58 | min_xyz = min_xyz - tmp 59 | 60 | minx, miny, minz = min_xyz[0],min_xyz[1],min_xyz[2] 61 | maxx, maxy, maxz = max_xyz[0],max_xyz[1],max_xyz[2] 62 | bbox = torch.Tensor([[minx,miny,minz],[maxx,miny,minz],[maxx,maxy,minz],[minx,maxy,minz], 63 | [minx,miny,maxz],[maxx,miny,maxz],[maxx,maxy,maxz],[minx,maxy,maxz]]) 64 | 65 | bbox = bbox.reshape(1,8,3)  # 8 AABB corners: min-z face first, then max-z face 66 | 67 | self.bboxs.append(torch.Tensor(bbox)) 68 | 69 | print('Frame %d dataset loaded, there are totally %d layers' %(frame_id,layer_num)) 70 | 71 | def __len__(self): 72 | return self.cam_num * self.layer_num  # NOTE(review): self.cam_num is never assigned in FrameDataset, so len() raises AttributeError — self.Ts.shape[0] may be what was intended 73 | 74 | def get_data(self, camera_id, layer_id):  # returns (image, label, K, T, ROI, bbox) 75 | # Find K,T, bbox 76 | T = self.Ts[camera_id] 77 | K = self.Ks[camera_id] 78 | bbox = self.bboxs[layer_id-1]  # layer_id is 1-based (layers load from '%d.ply' % (i+1)); layer_id=0 would silently pick the last layer 79 | # Load image 80 | image_path = os.path.join(self.image_path, 
'%03d.png' % (camera_id + self.file_offset))  # file names are shifted by file_offset, but Ts/Ks above are indexed by the raw camera_id 81 | image = Image.open(image_path) 82 | # Load label 83 | label_path = os.path.join(self.label_path, '%03d.npy' % (camera_id + self.file_offset)) 84 | if not os.path.exists(label_path): 85 | label_path = os.path.join(self.label_path, '%03d_label.npy' % (camera_id + self.file_offset)) 86 | label = np.load(label_path) 87 | 88 | # Transform image label K T to right scale 89 | image, label, K, T, ROI = self.transform(image, Ks=K, Ts=T, label=label) 90 | 91 | return image, label, K, T, ROI, bbox 92 | 93 | 94 | class FrameLayerDataset(torch.utils.data.Dataset):  # One (frame, layer) slice; layer_id 0 is the background 95 | 96 | def __init__(self, cfg, transform, frame_id, layer_id): 97 | 98 | super(FrameLayerDataset, self).__init__() 99 | 100 | # 1. Set the dataset path for the next loading 101 | dataset_path = cfg.DATASETS.TRAIN 102 | fixed_near, fixed_far = cfg.DATASETS.FIXED_NEAR, cfg.DATASETS.FIXED_FAR 103 | scale=cfg.DATASETS.SCALE 104 | camera_stepsize=cfg.DATASETS.CAMERA_STEPSIZE  # NOTE(review): not used anywhere in this class 105 | self.file_offset = cfg.DATASETS.FILE_OFFSET 106 | 107 | self.frame_id = frame_id # The frame id 108 | self.layer_id = layer_id # The layer id 109 | self.image_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'images') 110 | self.label_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'labels') 111 | #TODO: Need to fix when background is deformable 112 | self.pointcloud_path1 = "None"  # a placeholder string, not None: for layer 0 the fallback below resolves to "None/0.ply" 113 | 114 | if layer_id != 0: 115 | self.pointcloud_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'pointclouds') 116 | self.pointcloud_path1 = os.path.join(dataset_path,'background') 117 | else: 118 | self.pointcloud_path = os.path.join(dataset_path,'background') 119 | self.pose_path = os.path.join(dataset_path, 'pose') 120 | self.transform = transform 121 | # 2. 
Loading Intrinsics & Camera poses 122 | camposes = np.loadtxt(os.path.join(self.pose_path,'RT_c2w.txt')) 123 | 124 | # Ts are camera poses 125 | self.Ts = torch.Tensor(campose_to_extrinsic(camposes)) 126 | self.Ts[:,0:3,3] = self.Ts[:,0:3,3] * scale  # scale only the translation so cameras stay consistent with the scaled point cloud (xyz * scale) 127 | print('scale is ', scale) 128 | 129 | self.Ks = torch.Tensor(read_intrinsics(os.path.join(self.pose_path,'K.txt'))) 130 | 131 | self.cfg = cfg 132 | if cfg.DATASETS.CAMERA_NUM == 0: 133 | self.cam_num = self.Ts.shape[0] 134 | else: 135 | self.cam_num = cfg.DATASETS.CAMERA_NUM 136 | 137 | self.mask = np.ones(self.Ts.shape[0]) 138 | self.mask_path = cfg.DATASETS.VIEW_MASK 139 | if self.mask_path != None:  # NOTE(review): prefer 'is not None'; a VIEW_MASK path that does not exist is silently ignored 140 | if os.path.exists(self.mask_path): 141 | self.mask = read_mask(self.mask_path) 142 | 143 | pointcloud_name = os.path.join(self.pointcloud_path, '%d.ply' % (layer_id)) 144 | 145 | self.pointcloud = None 146 | if not os.path.exists(pointcloud_name): 147 | pointcloud_name = os.path.join(self.pointcloud_path1, '%d.ply' % (layer_id)) 148 | 149 | bbox_name = 'bbox_tmp'  # bbox/center are cached under <TRAIN>/bbox_tmp/frame<id>/layer<id> 150 | if not os.path.exists(pointcloud_name): 151 | print('Warning: Cannot find corresponding pointcloud in path: ', pointcloud_name) 152 | self.bbox = None  # stays None unless a cached bbox is found below 153 | self.center = torch.Tensor([0,0,0]) 154 | tmp_bbox_path = os.path.join(dataset_path,bbox_name,'frame'+str(frame_id),'layer'+str(layer_id)) 155 | if os.path.exists(os.path.join(tmp_bbox_path,'center.pt')): 156 | print('There are bbox generated for layer %d, frame %d before, loading bbox...' 
% (layer_id, frame_id)) 157 | # pointcloud = o3d.io.read_point_cloud(pointcloud_name) 158 | # xyz = np.asarray(pointcloud.points) 159 | 160 | # xyz = torch.Tensor(xyz) 161 | # self.pointcloud = xyz 162 | self.center = torch.load(os.path.join(tmp_bbox_path,'center.pt')) 163 | self.bbox = torch.load(os.path.join(tmp_bbox_path,'bbox.pt')) 164 | else: 165 | tmp_bbox_path = os.path.join(dataset_path,bbox_name,'frame'+str(frame_id),'layer'+str(layer_id)) 166 | if not os.path.exists(os.path.join(tmp_bbox_path,'center.pt')):  # cache miss: build the bbox from the scaled point cloud and save it 167 | print('There is no bbox generated before, generating bbox...') 168 | if not os.path.exists(tmp_bbox_path): 169 | os.makedirs(tmp_bbox_path) 170 | pointcloud = o3d.io.read_point_cloud(pointcloud_name) 171 | xyz = np.asarray(pointcloud.points) 172 | 173 | xyz = torch.Tensor(xyz) 174 | self.pointcloud = xyz * scale 175 | 176 | max_xyz = torch.max(self.pointcloud, dim=0)[0] 177 | min_xyz = torch.min(self.pointcloud, dim=0)[0] 178 | 179 | # Default scalar is 0.3 180 | tmp = (max_xyz - min_xyz) * 0.0  # padding factor is 0.0, so this is the tight AABB despite the 0.3 mentioned above 181 | 182 | max_xyz = max_xyz + tmp 183 | min_xyz = min_xyz - tmp 184 | 185 | minx, miny, minz = min_xyz[0],min_xyz[1],min_xyz[2] 186 | maxx, maxy, maxz = max_xyz[0],max_xyz[1],max_xyz[2] 187 | bbox = torch.Tensor([[minx,miny,minz],[maxx,miny,minz],[maxx,maxy,minz],[minx,maxy,minz], 188 | [minx,miny,maxz],[maxx,miny,maxz],[maxx,maxy,maxz],[minx,maxy,maxz]]) 189 | 190 | bbox = bbox.reshape(1,8,3) 191 | 192 | self.center = np.array([(min_xyz[0]+max_xyz[0])/2, (min_xyz[1]+max_xyz[1])/2, (min_xyz[2]+max_xyz[2])/2])  # NOTE(review): np.ndarray here, but torch.Tensor in the missing-pointcloud branch above 193 | self.bbox = torch.Tensor(bbox) 194 | if not os.path.exists(os.path.join(tmp_bbox_path,'center.pt')): 195 | torch.save(self.center, os.path.join(tmp_bbox_path,'center.pt')) 196 | if not os.path.exists(os.path.join(tmp_bbox_path,'bbox.pt')): 197 | torch.save(self.bbox, os.path.join(tmp_bbox_path,'bbox.pt')) 198 | else: 199 | print('There are bbox generated for layer %d, frame %d before, loading bbox...' 
% (layer_id, frame_id)) 200 | # pointcloud = o3d.io.read_point_cloud(pointcloud_name) 201 | # xyz = np.asarray(pointcloud.points) 202 | 203 | # xyz = torch.Tensor(xyz) 204 | # self.pointcloud = xyz 205 | self.center = torch.load(os.path.join(tmp_bbox_path,'center.pt')) 206 | self.bbox = torch.load(os.path.join(tmp_bbox_path,'bbox.pt')) 207 | 208 | 209 | if fixed_near == -1.0 and fixed_far == -1.0:  # derive per-camera near/far from the point cloud, cached under near_far_tmp 210 | near_far_name = 'near_far_tmp' 211 | tmp_near_far_path = os.path.join(dataset_path,near_far_name,'frame'+str(frame_id),'layer'+str(layer_id)) 212 | if not os.path.exists(os.path.join(tmp_near_far_path,'near.pt')): 213 | if not os.path.exists(os.path.join(tmp_near_far_path)): 214 | os.makedirs(tmp_near_far_path) 215 | inv_Ts = torch.inverse(self.Ts).unsqueeze(1) #(M,1,4,4); presumably world-to-camera so z below is depth — confirm campose_to_extrinsic's convention 216 | 217 | if self.pointcloud is None:  # NOTE(review): when the .ply was missing, this reads a non-existent path and max/min below presumably fail on an empty cloud 218 | pointcloud = o3d.io.read_point_cloud(pointcloud_name) 219 | xyz = np.asarray(pointcloud.points) 220 | 221 | xyz = torch.Tensor(xyz) 222 | self.pointcloud = xyz * scale 223 | vs = self.pointcloud.clone().unsqueeze(-1) #(N,3,1) 224 | vs = torch.cat([vs,torch.ones(vs.size(0),1,vs.size(2)) ],dim=1) #(N,4,1) 225 | 226 | pts = torch.matmul(inv_Ts,vs) #(M,N,4,1) 227 | 228 | pts_max = torch.max(pts, dim=1)[0].squeeze() #(M,4); NOTE(review): bare squeeze() also drops M when there is a single camera, so [:,2] below fails — squeeze(-1) keeps (M,4) 229 | pts_min = torch.min(pts, dim=1)[0].squeeze() #(M,4) 230 | 231 | pts_max = pts_max[:,2] #(M) 232 | pts_min = pts_min[:,2] #(M) 233 | 234 | self.near = pts_min 235 | # self.near[self.near<(pts_max*0.1)] = pts_max[self.near<(pts_max*0.1)]*0.1 236 | 237 | self.far = pts_max 238 | torch.save(self.near,os.path.join(tmp_near_far_path,'near.pt')) 239 | torch.save(self.far,os.path.join(tmp_near_far_path,'far.pt')) 240 | else: 241 | self.near = torch.load(os.path.join(tmp_near_far_path,'near.pt')) 242 | self.far = torch.load(os.path.join(tmp_near_far_path,'far.pt')) 243 | else: 244 | self.near = torch.ones(self.Ts.shape[0]) * fixed_near  # NOTE(review): also reached when only one of FIXED_NEAR/FIXED_FAR is -1.0, which is then used literally 245 | self.far = torch.ones(self.Ts.shape[0]) * fixed_far 246 | 247 | print('Layer %d, Frame %d dataset 
loaded' %(layer_id,frame_id)) 248 | 249 | def __len__(self): 250 | return self.cam_num 251 | 252 | def get_data(self, camera_id):  # returns (image, label, K, T, ROI, bbox, near_far of shape (1,2), view mask); masked-out views return Nones and mask 0 253 | # When CAMERA_NUM != 0, camera_id is shifted by FILE_OFFSET for both the camera parameters (Ts/Ks/mask/near/far) and the image/label files; otherwise no offset is applied 254 | if self.cfg.DATASETS.CAMERA_NUM != 0: 255 | camera_id = camera_id + self.file_offset 256 | if self.mask[camera_id] == 0: 257 | return None, None, None, None, None, None, None, 0 258 | # Find K,T, bbox 259 | 260 | T = self.Ts[camera_id] 261 | K = self.Ks[camera_id] 262 | bbox = self.bbox 263 | # Load image 264 | image_path = os.path.join(self.image_path, '%03d.png' % (camera_id)) 265 | if not os.path.exists(image_path): 266 | image_path = os.path.join(self.image_path, '%d.png' % (camera_id)) 267 | if not os.path.exists(image_path): 268 | image = None  # presumably self.transform accepts image=None — confirm 269 | else: 270 | image = Image.open(image_path) 271 | # Load label 272 | label = None 273 | label_path = os.path.join(self.label_path, '%03d.npy' % (camera_id)) 274 | if not os.path.exists(label_path): 275 | label_path = os.path.join(self.label_path, '%03d_label.npy' % (camera_id)) 276 | if not os.path.exists(label_path): 277 | label_path = os.path.join(self.label_path, '%d.npy' % (camera_id)) 278 | if not os.path.exists(label_path): 279 | if image == None:  # NOTE(review): prefer 'is None' 280 | label = None 281 | else: 282 | width, height = image.size 283 | label = np.ones((height, width)) * self.layer_id  # no label file: mark every pixel as this layer 284 | print('Warning: There is no label map for this dataset, and we trying to train layer %d, for frame %d, so generate a full label map with it' % (self.layer_id, self.frame_id)) 285 | else: 286 | label = np.load(label_path) 287 | 288 | # Transform image label K T to right scale 289 | image, label, K, T, ROI = self.transform(image, Ks=K, Ts=T, label=label) 290 | 291 | return image, label, K, T, ROI, bbox, torch.tensor([self.near[camera_id],self.far[camera_id]]).unsqueeze(0), self.mask[camera_id] 292 | 293 | def get_original_size(self): 294 | 295 | image_path = 
os.path.join(self.image_path, '%03d.png' % (0)) 296 | if not os.path.exists(image_path): 297 | image_path = os.path.join(self.image_path, '%d.png' % (0)) 298 | if not os.path.exists(image_path): 299 | image = None 300 | else: 301 | image = Image.open(image_path)  # NOTE(review): opened lazily and never closed 302 | 303 | return image.size  # NOTE(review): image is None when neither 000.png nor 0.png exists, so this raises AttributeError — TODO raise FileNotFoundError instead 304 | 305 | 306 | -------------------------------------------------------------------------------- /data/datasets/ray_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from math import sin, cos, pi 4 | import os 5 | from .utils import campose_to_extrinsic, read_intrinsics 6 | from PIL import Image 7 | import torchvision 8 | import torch.distributions as tdist 9 | 10 | from .frame_dataset import FrameDataset, FrameLayerDataset 11 | from utils import ray_sampling, ray_sampling_label_bbox, lookat, getSphericalPosition, generate_rays, ray_sampling_label_label 12 | 13 | class Ray_Dataset(torch.utils.data.Dataset): 14 | 15 | def __init__(self, cfg, transforms_bkgd, transforms_layer): 16 | 17 | super(Ray_Dataset, self).__init__() 18 | 19 | frame_num=cfg.DATASETS.FRAME_NUM 20 | layer_num=cfg.DATASETS.LAYER_NUM 21 | 22 | frame_offset=cfg.DATASETS.FRAME_OFFSET 23 | bkgd_sample_rate = cfg.DATASETS.BKGD_SAMPLE_RATE 24 | 25 | # [[bkgd_frame1,bkgd_frame2,...,],[layer1_frame_1,layer1_frame2,...,],[layer2_frame1,layer2_frame2,...,],...,] 26 | self.datasets = [] 27 | self.bboxes = torch.zeros(frame_num+frame_offset, layer_num, 8, 3) 28 | for layer_id in range(layer_num+1): 29 | datasets_layer = [] 30 | for frame_id in range(1+frame_offset,frame_offset+frame_num+1): 31 | if layer_id == 0: 32 | sample_rate=bkgd_sample_rate 33 | use_label_map=True 34 | transform=transforms_bkgd 35 | else: 36 | sample_rate=1 37 | for i in range(len(cfg.DATASETS.FIXED_LAYER)): 38 | if cfg.DATASETS.FIXED_LAYER[i] == layer_id: 39 | sample_rate = 0 40 | use_label_map=cfg.DATASETS.USE_LABEL 41 | transform=transforms_layer 42 | 
dataset_frame_layer = Ray_Frame_Layer_Dataset(cfg, transform, frame_id, layer_id, use_label_map, sample_rate) 43 | if layer_id != 0: 44 | self.bboxes[frame_id-1, layer_id-1] = dataset_frame_layer.layer_bbox 45 | datasets_layer.append(dataset_frame_layer) 46 | self.datasets.append(datasets_layer) 47 | 48 | self.frame_num = frame_num 49 | self.layer_num = layer_num 50 | 51 | self.bkgd_sample_rate = bkgd_sample_rate 52 | 53 | self.ray_length = np.zeros(layer_num+1) 54 | 55 | for l in range(len(self.datasets)): 56 | layer_datasets = self.datasets[l] 57 | for layer_frame_dataset in layer_datasets: 58 | self.ray_length[l] += len(layer_frame_dataset) 59 | 60 | for l in range(len(self.datasets)): 61 | print('Layer %d has %d rays' % (l, int(self.ray_length[l]))) 62 | self.length = int(sum(self.ray_length)) 63 | 64 | print('The whole ray number is %d' % self.length) 65 | self.camera_num = self.datasets[0][0].camera_num 66 | 67 | 68 | def __len__(self): 69 | 70 | return self.length 71 | 72 | def __getitem__(self, index): 73 | # if index < self.bkgd_length: 74 | # index = int(index / self.bkgd_sample_rate) 75 | # else: 76 | # index = (index-self.bkgd_length) + self.original_bkgd_length 77 | temp = 0 78 | for layer_datasets in self.datasets: 79 | for layer_frame_dataset in layer_datasets: 80 | if temp + len(layer_frame_dataset) > index: 81 | return layer_frame_dataset[index-temp] 82 | else: 83 | temp += len(layer_frame_dataset) 84 | 85 | class Ray_Dataset_View(torch.utils.data.Dataset): 86 | 87 | def __init__(self, cfg, transform): 88 | 89 | super(Ray_Dataset_View, self).__init__() 90 | 91 | # Save input 92 | self.dataset_path = cfg.DATASETS.TRAIN 93 | self.frame_num = cfg.DATASETS.FRAME_NUM 94 | self.layer_num = cfg.DATASETS.LAYER_NUM 95 | self.frame_offset = cfg.DATASETS.FRAME_OFFSET 96 | 97 | self.pose_refinement = cfg.MODEL.POSE_REFINEMENT 98 | self.use_deform_view = cfg.MODEL.USE_DEFORM_VIEW 99 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME 100 | self.use_space_time = 
cfg.MODEL.USE_SPACE_TIME 101 | remove_outliers = cfg.MODEL.REMOVE_OUTLIERS 102 | 103 | self.transform = transform 104 | 105 | self.layer_frame_datasets = [] 106 | for layer_id in range(self.layer_num+1): 107 | datasets_layer = [] 108 | for frame_id in range(1+self.frame_offset,self.frame_offset+self.frame_num+1): 109 | dataset_frame_layer = FrameLayerDataset(cfg, transform, frame_id, layer_id) 110 | datasets_layer.append(dataset_frame_layer) 111 | self.layer_frame_datasets.append(datasets_layer) 112 | self.camera_num = self.layer_frame_datasets[0][0].cam_num 113 | 114 | def __len__(self): 115 | return 1 116 | 117 | def get_fixed_image(self, index_view, index_frame): 118 | 119 | print(index_view) 120 | print(index_frame) 121 | 122 | bboxes = [] 123 | K = None 124 | T = None 125 | label = None 126 | image = None 127 | for i in range(self.layer_num+1): 128 | image_tmp, label_tmp, K_tmp, T_tmp, _, bbox, near_far = self.layer_frame_datasets[i][index_frame].get_data(index_view) 129 | if K is None: 130 | K = K_tmp 131 | if T is None: 132 | T = T_tmp 133 | if label is None: 134 | label = label_tmp 135 | if image is None: 136 | image = image_tmp 137 | bboxes.append(bbox) 138 | 139 | rays, labels, rgbs, ray_mask, layered_bboxes = ray_sampling_label_bbox(image,label,K,T,bboxes=bboxes) 140 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) ) 141 | if self.pose_refinement: 142 | rays_o, rays_d = rays[:, :3], rays[:, 3:6] 143 | ids=torch.ones((rays_o.size(0),1))*index 144 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1) 145 | 146 | if self.use_deform_view: 147 | camera_ids=torch.ones((rays.size(0),1)) * index 148 | rays=torch.cat([rays, camera_ids],dim=-1) 149 | 150 | if self.use_deform_time or self.use_space_time: 151 | frame_ids = torch.Tensor([index_frame+self.frame_offset+1]).reshape(1,1).repeat(rays.shape[0],1) 152 | rays=torch.cat([rays, frame_ids],dim=-1) 153 | 154 | return rays, rgbs, labels, image, 
label, ray_mask, layered_bboxes, near_far.repeat(rays.size(0),1) 155 | 156 | def __getitem__(self, index): 157 | 158 | index_frame = np.random.randint(0,self.frame_num) 159 | index_view = np.random.randint(0,self.camera_num) 160 | _, _, _, _, _, _, _, mask = self.layer_frame_datasets[0][index_frame].get_data(index_view) 161 | while (mask == 0): 162 | 163 | index_view = np.random.randint(0,self.camera_num) 164 | _, _, _, _, _, _, _, mask = self.layer_frame_datasets[0][index_frame].get_data(index_view) 165 | 166 | print(index_view) 167 | print(index_frame) 168 | 169 | bboxes = [] 170 | K = None 171 | T = None 172 | label = None 173 | image = None 174 | for i in range(self.layer_num+1): 175 | image_tmp, label_tmp, K_tmp, T_tmp, _, bbox, near_far, _ = self.layer_frame_datasets[i][index_frame].get_data(index_view) 176 | if K is None: 177 | K = K_tmp 178 | if T is None: 179 | T = T_tmp 180 | if label is None: 181 | label = label_tmp 182 | if image is None: 183 | image = image_tmp 184 | bboxes.append(bbox) 185 | 186 | rays, labels, rgbs, ray_mask, layered_bboxes = ray_sampling_label_bbox(image,label,K,T,bboxes=bboxes) 187 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) ) 188 | if self.pose_refinement: 189 | rays_o, rays_d = rays[:, :3], rays[:, 3:6] 190 | ids=torch.ones((rays_o.size(0),1))*index 191 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1) 192 | 193 | if self.use_deform_view: 194 | camera_ids=torch.ones((rays.size(0),1)) * index 195 | rays=torch.cat([rays, camera_ids],dim=-1) 196 | 197 | if self.use_deform_time or self.use_space_time: 198 | frame_ids = torch.Tensor([index_frame+self.frame_offset+1]).reshape(1,1).repeat(rays.shape[0],1) 199 | rays=torch.cat([rays, frame_ids],dim=-1) 200 | 201 | return rays, rgbs, labels, image, label, ray_mask, layered_bboxes, near_far.repeat(rays.size(0),1) 202 | 203 | class Ray_Dataset_Render(torch.utils.data.Dataset): 204 | 205 | def __init__(self, cfg, 
transform): 206 | 207 | super(Ray_Dataset_Render, self).__init__() 208 | 209 | # Save input 210 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME 211 | self.use_space_time = cfg.MODEL.USE_SPACE_TIME 212 | 213 | frame_offset = cfg.DATASETS.FRAME_OFFSET 214 | layer_num = cfg.DATASETS.LAYER_NUM 215 | frame_num = cfg.DATASETS.FRAME_NUM 216 | 217 | self.layer_num = layer_num 218 | 219 | self.datasets = [] 220 | 221 | self.bboxes = torch.zeros(frame_num+frame_offset, layer_num, 8, 3) 222 | 223 | 224 | for layer_id in range(layer_num+1): 225 | datasets_layer = [] 226 | for frame_id in range(1+frame_offset,frame_offset+frame_num+1): 227 | dataset_frame_layer = FrameLayerDataset(cfg, transform, frame_id, layer_id) 228 | datasets_layer.append(dataset_frame_layer) 229 | if layer_id != 0: 230 | self.bboxes[frame_id-1, layer_id-1] = dataset_frame_layer.bbox 231 | self.datasets.append(datasets_layer) 232 | 233 | self.camera_num = self.datasets[0][0].cam_num 234 | self.poses = self.datasets[0][0].Ts 235 | 236 | # Default layer size is original size 237 | self.Ks = self.datasets[0][0].Ks 238 | col, row = self.datasets[0][0].get_original_size() 239 | self.Ks[:,0,0] = self.Ks[:,0,0] * cfg.INPUT.SIZE_TEST[0] / col 240 | self.Ks[:,1,1] = self.Ks[:,1,1] * cfg.INPUT.SIZE_TEST[0] / col 241 | self.Ks[:,0,2] = self.Ks[:,0,2] * cfg.INPUT.SIZE_TEST[0] / col 242 | self.Ks[:,1,2] = self.Ks[:,1,2] * cfg.INPUT.SIZE_TEST[0] / col 243 | 244 | # Use original image size, intrinsic and bbox 245 | image, _, self.K, _, _, _, _, _ = self.datasets[0][0].get_data(0) 246 | 247 | # for i in range(len(self.datasets[0][0])): 248 | # _, _, K, _, _, _, _, _ = self.datasets[0][0].get_data(i) 249 | # self.Ks.append(K) 250 | 251 | 252 | _, self.height, self.width = image.shape 253 | 254 | self.near_far = torch.Tensor([cfg.DATASETS.FIXED_NEAR,cfg.DATASETS.FIXED_FAR]).reshape(1,2) 255 | 256 | def get_image_label(self, camera_id, frame_id): 257 | image, label, _, _, _, _, _, _ = 
self.datasets[frame_id][0].get_data(camera_id) 258 | return image, label 259 | 260 | def get_rays_by_pose_and_K(self, T, K, layer_frame_pair): 261 | 262 | T = torch.Tensor(T) 263 | rays, _ = generate_rays(K, T, None, self.height, self.width) 264 | 265 | #TODO: now bbox and near far is no use 266 | near_fars = self.near_far.repeat(rays.size(0),1) 267 | bboxes = torch.zeros(rays.size(0),8,3) 268 | labels = torch.zeros(rays.size(0)) 269 | 270 | # bboxes = [] 271 | # for layer_id, frame_id in layer_frame_pair: 272 | # bboxes.append(self.bboxes[frame_id-1,layer_id]) 273 | 274 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) ) 275 | 276 | if self.use_deform_time or self.use_space_time: 277 | frame_ids = torch.zeros(rays.size(0),self.layer_num+1) 278 | for layer_id, frame_id in layer_frame_pair: 279 | frame_ids[:,layer_id] = frame_id 280 | 281 | rays=torch.cat([rays, frame_ids],dim=-1) 282 | 283 | return rays, labels, bboxes, near_fars 284 | #Use the first K of the dataset by default 285 | def get_rays_by_pose(self, T, layer_frame_pair): 286 | 287 | T = torch.Tensor(T) 288 | rays, _ = generate_rays(self.K, T, None, self.height, self.width) 289 | 290 | #TODO: now bbox and near far is no use 291 | near_fars = self.near_far.repeat(rays.size(0),1) 292 | bboxes = torch.zeros(rays.size(0),8,3) 293 | labels = torch.zeros(rays.size(0)) 294 | 295 | # bboxes = [] 296 | # for layer_id, frame_id in layer_frame_pair: 297 | # bboxes.append(self.bboxes[frame_id-1,layer_id]) 298 | 299 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) ) 300 | 301 | if self.use_deform_time: 302 | frame_ids = torch.zeros(rays.size(0),self.layer_num+1) 303 | for layer_id, frame_id in layer_frame_pair: 304 | frame_ids[:,layer_id] = frame_id 305 | 306 | rays=torch.cat([rays, frame_ids],dim=-1) 307 | 308 | return rays, labels, bboxes, near_fars 309 | 310 | def 
get_rays_by_lookat(self,eye,center,up, layer_frame_pair): 311 | 312 | T = torch.Tensor(lookat(eye,center,up)) 313 | return self.get_rays_by_pose(T, layer_frame_pair) 314 | 315 | def get_rays_by_spherical(self, theta, phi, radius,offsets, up, layer_frame_pair): 316 | up = np.array(up) 317 | offsets = np.array(offsets) 318 | 319 | pos = getSphericalPosition(radius,theta,phi) 320 | pos += self.center 321 | pos += offsets 322 | T = torch.Tensor(lookat(pos,self.center,up)) 323 | 324 | return self.get_rays_by_pose(T, layer_frame_pair) 325 | 326 | def get_pose_by_lookat(self, eye,center,up): 327 | return torch.Tensor(lookat(eye,center,up)) 328 | 329 | def get_pose_by_spherical(self, theta, phi, radius, offsets, up): 330 | up = np.array(up) 331 | offsets = np.array(offsets) 332 | 333 | pos = getSphericalPosition(radius,theta,phi) 334 | pos += self.center 335 | pos += offsets 336 | T = torch.Tensor(lookat(pos,self.center,up)) 337 | return T 338 | 339 | class Ray_Frame_Layer_Dataset(torch.utils.data.Dataset): 340 | 341 | def __init__(self, cfg, transform, frame_id, layer_id, use_label_map, sample_rate): 342 | 343 | super(Ray_Frame_Layer_Dataset, self).__init__() 344 | 345 | 346 | # Save input 347 | self.dataset_path = cfg.DATASETS.TRAIN 348 | self.tmp_rays = cfg.DATASETS.TMP_RAYS 349 | self.camera_stepsize = cfg.DATASETS.CAMERA_STEPSIZE 350 | 351 | self.pose_refinement = cfg.MODEL.POSE_REFINEMENT 352 | self.use_deform_view = cfg.MODEL.USE_DEFORM_VIEW 353 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME 354 | self.use_space_time = cfg.MODEL.USE_SPACE_TIME 355 | 356 | self.transform = transform 357 | self.frame_id = frame_id 358 | self.layer_id = layer_id 359 | 360 | # Generate Frame Dataset 361 | self.frame_dataset = FrameLayerDataset(cfg, transform, frame_id, layer_id) 362 | self.camera_num = self.frame_dataset.cam_num 363 | # Save layered rays, rgbs, labels, bboxs, near_fars 364 | self.layer_rays = [] 365 | self.layer_rgbs = [] 366 | self.layer_labels = [] 367 | if 
self.frame_dataset.bbox != None: 368 | self.layer_bbox = self.frame_dataset.bbox 369 | else: 370 | self.layer_bbox = torch.zeros(8,3) 371 | self.near_fars = [] 372 | 373 | # Check if we already generate rays 374 | tmp_ray_path = os.path.join(self.dataset_path,self.tmp_rays,'frame'+str(frame_id)) 375 | if not os.path.exists(tmp_ray_path): 376 | print('There is no rays generated before, generating rays...') 377 | os.makedirs(tmp_ray_path) 378 | 379 | # tranverse every camera 380 | tmp_layer_ray_path = os.path.join(tmp_ray_path,'layer'+str(layer_id)) 381 | if sample_rate == 0.0: 382 | print('Skiping layer %d, frame %d rays for zero sample rate...' % (layer_id, frame_id)) 383 | self.layer_rays = torch.tensor([]) 384 | self.layer_rgbs = torch.tensor([]) 385 | self.layer_labels = torch.tensor([]) 386 | self.near_fars = torch.tensor([]) 387 | elif not os.path.exists(tmp_layer_ray_path) or cfg.clean_ray: 388 | rays_tmp = [] 389 | rgbs_tmp = [] 390 | labels_tmp = [] 391 | near_fars_tmp = [] 392 | print('There is no rays generated for layer %d, frame %d before, generating rays...' 
% (layer_id, frame_id)) 393 | for i in range(0,self.frame_dataset.cam_num,self.camera_stepsize): 394 | print('Generating Layer %d, Camera %d rays...'% (layer_id,i)) 395 | 396 | image, label, K, T, ROI, bbox, near_far, mask = self.frame_dataset.get_data(i) 397 | 398 | if not mask: 399 | print('Skiping Camera %d by mask'% (i)) 400 | continue 401 | 402 | if not use_label_map: 403 | rays, labels, rgbs, _ = ray_sampling_label_bbox(image,label,K,T,bbox) 404 | else: 405 | rays, labels, rgbs, _ = ray_sampling_label_label(image,label,K,T,layer_id) 406 | 407 | if self.pose_refinement: 408 | rays_o, rays_d = rays[:, :3], rays[:, 3:6] 409 | ids=torch.ones((rays_o.size(0),1))*i 410 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1) 411 | 412 | if self.use_deform_view: 413 | camera_ids=torch.ones((rays.size(0),1))*i 414 | rays=torch.cat([rays, camera_ids],dim=-1) 415 | 416 | if self.use_deform_time or self.use_space_time: 417 | frame_ids = torch.Tensor([frame_id]).reshape(1,1).repeat(rays.shape[0],1) 418 | rays=torch.cat([rays, frame_ids],dim=-1) 419 | 420 | near_fars_tmp.append(near_far.repeat(rays.size(0),1)) 421 | rays_tmp.append(rays) 422 | rgbs_tmp.append(rgbs) 423 | labels_tmp.append(labels) 424 | 425 | self.layer_rays = torch.cat(rays_tmp,0) 426 | self.layer_rgbs = torch.cat(rgbs_tmp,0) 427 | self.layer_labels = torch.cat(labels_tmp,0) 428 | self.near_fars = torch.cat(near_fars_tmp,0) 429 | if sample_rate != 1: 430 | rand_idx = torch.randperm(self.layer_rays.size(0)) 431 | self.layer_rays = self.layer_rays[rand_idx] 432 | self.layer_rgbs = self.layer_rgbs[rand_idx] 433 | self.layer_labels = self.layer_labels[rand_idx] 434 | self.near_fars = self.near_fars[rand_idx] 435 | end = int(self.layer_rays.size(0) * sample_rate) 436 | self.layer_rays = self.layer_rays[:end,:].clone().detach() 437 | self.layer_rgbs = self.layer_rgbs[:end,:].clone().detach() 438 | self.layer_labels = self.layer_labels[:end,:].clone().detach() 439 | self.near_fars = 
self.near_fars[:end,:].clone().detach() 440 | if not os.path.exists(tmp_layer_ray_path): 441 | os.mkdir(tmp_layer_ray_path) 442 | torch.save(self.layer_rays, os.path.join(tmp_layer_ray_path,'rays.pt')) 443 | torch.save(self.layer_rgbs, os.path.join(tmp_layer_ray_path,'rgbs.pt')) 444 | torch.save(self.layer_labels, os.path.join(tmp_layer_ray_path,'labels.pt')) 445 | torch.save(self.near_fars, os.path.join(tmp_layer_ray_path,'near_fars.pt')) 446 | else: 447 | print('There are rays generated for layer %d, frame %d before, loading rays...' % (layer_id, frame_id)) 448 | self.layer_rays = torch.load(os.path.join(tmp_layer_ray_path,'rays.pt'),map_location='cpu') 449 | self.layer_rgbs = torch.load(os.path.join(tmp_layer_ray_path,'rgbs.pt'),map_location='cpu') 450 | self.layer_labels = torch.load(os.path.join(tmp_layer_ray_path,'labels.pt'),map_location='cpu') 451 | self.near_fars = torch.load(os.path.join(tmp_layer_ray_path,'near_fars.pt'),map_location='cpu') 452 | 453 | # Fix to the layer id 454 | self.layer_bbox_labels = self.layer_id * torch.ones_like(self.layer_labels) 455 | print('Generating %d rays' % self.layer_rays.shape[0]) 456 | def __len__(self): 457 | return self.layer_rays.shape[0] 458 | 459 | def __getitem__(self, index): 460 | return self.layer_rays[index,:], self.layer_rgbs[index,:], self.layer_labels[index,:], self.layer_bbox_labels[index,:], self.layer_bbox[0], self.near_fars[index,:] -------------------------------------------------------------------------------- /data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | 5 | 6 | def campose_to_extrinsic(camposes): 7 | if camposes.shape[1]!=12: 8 | raise Exception(" wrong campose data structure!") 9 | 10 | res = np.zeros((camposes.shape[0],4,4)) 11 | 12 | res[:,0,:] = camposes[:,0:4] 13 | res[:,1,:] = camposes[:,4:8] 14 | res[:,2,:] = camposes[:,8:12] 15 | res[:,3,3] = 1.0 16 | 17 | return res 18 | 19 | 
20 | def read_intrinsics(fn_instrinsic): 21 | fo = open(fn_instrinsic) 22 | data= fo.readlines() 23 | i = 0 24 | Ks = [] 25 | while i max_iter: 57 | max_iter = temp 58 | if not os.path.exists(os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter)): 59 | return None 60 | return os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter) 61 | 62 | def get_iteration_path_and_iter(root_dir, fix_iter = -1): 63 | if fix_iter != -1: 64 | return os.path.join(root_dir,'frame','layered_rfnr_checkpoint_%d.pt' % fix_iter) 65 | 66 | if not os.path.exists(root_dir): 67 | return None 68 | file_names = glob.glob(os.path.join(root_dir,'layered_rfnr_checkpoint_*.pt')) 69 | max_iter = -1 70 | for file_name in file_names: 71 | num_name = file_name.split('_')[-1] 72 | temp = int(num_name.split('.')[0]) 73 | if temp > max_iter: 74 | max_iter = temp 75 | if not os.path.exists(os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter)): 76 | return None 77 | return os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter), max_iter 78 | 79 | 80 | def read_mask(path): 81 | fo = open(path) 82 | data= fo.readlines() 83 | mask = [] 84 | for i in range(len(data)): 85 | tmp = int(data[i]) 86 | mask.append(tmp) 87 | mask = np.array(mask) 88 | fo.close() 89 | 90 | return mask -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: Minye Wu 4 | @GITHUB: wuminye 5 | """ 6 | 7 | from .build import build_transforms,build_layered_transforms 8 | -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/random_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/random_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/random_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/random_transforms.cpython-38.pyc 
# ------------------------------------------------------------------------------
# /data/transforms/build.py:
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
Factories for the ``Random_Transforms`` objects used by the datasets.

@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import torchvision.transforms as T

from .random_transforms import Random_Transforms


def build_transforms(cfg, is_train=True):
    """Build the image/camera transform for the training or test split.

    Args:
        cfg: config node. Reads ``INPUT.SIZE_TRAIN`` (training) or
            ``INPUT.SIZE_TEST`` (test) and, for training, the augmentation
            ranges ``DATASETS.SHIFT``, ``DATASETS.MAXRATION`` and
            ``DATASETS.ROTATION``.
        is_train: build the randomly augmenting training transform when True,
            otherwise a transform with zero shift (and the default zero
            zoom/rotation of ``Random_Transforms``).

    Returns:
        A ``Random_Transforms`` instance.
    """
    # Random_Transforms indexes its size as (height, width); the INPUT.SIZE_*
    # entries are indexed the other way round, hence the swapped indices.
    if is_train:
        return Random_Transforms(
            (cfg.INPUT.SIZE_TRAIN[1], cfg.INPUT.SIZE_TRAIN[0]),
            cfg.DATASETS.SHIFT,
            cfg.DATASETS.MAXRATION,
            cfg.DATASETS.ROTATION,
        )
    return Random_Transforms(
        (cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0]), 0, isTrain=is_train
    )


def build_layered_transforms(cfg, is_layer=True, is_train=True):
    """Build the image/camera transform for a layered dataset.

    Args:
        cfg: config node. Reads ``INPUT.SIZE_LAYER`` or ``INPUT.SIZE_TRAIN``
            (training, chosen by ``is_layer``), ``INPUT.SIZE_TEST`` (test) and,
            for training, ``DATASETS.SHIFT``, ``DATASETS.MAXRATION`` and
            ``DATASETS.ROTATION``.
        is_layer: use ``INPUT.SIZE_LAYER`` instead of ``INPUT.SIZE_TRAIN`` as
            the training output size. Ignored when ``is_train`` is False.
        is_train: build the randomly augmenting training transform when True.

    Returns:
        A ``Random_Transforms`` instance.
    """
    if not is_train:
        # Same as build_transforms: a test transform must not run the
        # training-only mask re-centering inside Random_Transforms.
        return Random_Transforms(
            (cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0]), 0, isTrain=False
        )
    size = cfg.INPUT.SIZE_LAYER if is_layer else cfg.INPUT.SIZE_TRAIN
    return Random_Transforms(
        (size[1], size[0]),
        cfg.DATASETS.SHIFT,
        cfg.DATASETS.MAXRATION,
        cfg.DATASETS.ROTATION,
    )
# ------------------------------------------------------------------------------
# /data/transforms/random_transforms.py:
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
Random geometric augmentation that keeps camera intrinsics/extrinsics in sync.

@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import torchvision.transforms as T

from torch.utils import data
import torch

import numpy as np

import random
import PIL
from PIL import Image
import collections
import collections.abc  # explicit: collections.abc.Iterable is used below
import math


def calc_center(mask):
    """Return the centroid ``(mean_row, mean_col)`` of the non-zero pixels of ``mask``.

    Args:
        mask: 2-D ``(h, w)`` numpy array; every non-zero pixel is counted.

    Returns:
        Tuple of two floats. Both are NaN (NumPy emits a RuntimeWarning) when
        the mask has no non-zero pixel.
    """
    # Builtin ``bool``: the deprecated ``np.bool`` alias no longer exists in NumPy >= 1.24.
    rows, cols = np.nonzero(mask.astype(bool))
    return np.mean(rows), np.mean(cols)


def rodrigues_rotation_matrix(axis, theta):
    """Return the 3x3 rotation matrix for an angle ``theta`` (radians) about ``axis``.

    Uses the Euler-Rodrigues parameterisation; ``axis`` is normalised first, so
    it only needs to be non-zero.
    """
    axis = np.asarray(axis)
    theta = np.asarray(theta)
    axis = axis / math.sqrt(np.dot(axis, axis))
    # Euler-Rodrigues parameters (a, b, c, d).
    a = math.cos(theta / 2.0)
    b, c, d = -axis * math.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])


class Random_Transforms(object):
    """Randomly rotate, shift and zoom an image and update its camera to match.

    The image, a full-frame region-of-interest (ROI) map and the optional
    ``mask`` / ``label`` maps all receive the same warp. The intrinsics ``K``
    and pose ``T`` are adjusted so they stay consistent with the warped image.

    Args:
        size: output size, indexed as ``(height, width)``.
        random_range: maximum shift, in output pixels, drawn independently for
            x and y from ``[-random_range, random_range]``.
        random_ration: extra zoom drawn from ``[1, 1 + random_ration)``.
        random_rotation: width, in degrees, of the interval (centred on 0)
            from which the in-plane rotation is drawn.
        interpolation: PIL filter used to resize the image, ROI and mask.
        isTrain: when True and a ``mask`` is given, move the mask centroid to
            the image centre before applying the random shift.
        is_center: replace the random shift/zoom by a fixed zoom that puts the
            principal point at the centre of the crop (rotation still applies).
    """

    def __init__(self, size, random_range=0, random_ration=0, random_rotation=0,
                 interpolation=Image.BICUBIC, isTrain=True, is_center=False):
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.random_range = random_range
        self.random_scale = random_ration
        self.isTrain = isTrain
        self.random_rotation = random_rotation
        self.is_center = is_center

    def _warp(self, pil_img, rotation, translation, center, crop_size, resample, interpolation):
        """Rotate, translate, crop the top-left ``crop_size`` and resize ``pil_img``; return a tensor.

        ``resample`` is the filter for the rotation; ``interpolation`` the one
        for the final resize to ``self.size``.
        """
        # NOTE(review): ``resample=`` matches the pinned torchvision 0.8.x API;
        # newer torchvision releases renamed it to ``interpolation``.
        pil_img = T.functional.rotate(pil_img, angle=np.rad2deg(rotation), resample=resample, center=center)
        pil_img = T.functional.affine(pil_img, angle=0, translate=translation, scale=1, shear=0)
        pil_img = T.functional.crop(pil_img, 0, 0, crop_size[0], crop_size[1])
        pil_img = T.functional.resize(pil_img, self.size, interpolation)
        return T.functional.to_tensor(pil_img)

    def __call__(self, img, Ks=None, Ts=None, mask=None, label=None):
        """Apply one randomly sampled warp.

        Args:
            img: PIL image.
            Ks: intrinsic matrix tensor (cloned; the argument is not modified).
            Ts: pose tensor whose top-left 3x3 block is the rotation (cloned;
                the argument is not modified).
            mask: optional PIL image, warped exactly like ``img``.
            label: optional 2-D array of integer ids (cast with ``np.uint8``).

        Returns:
            ``(img, K, T, mask, ROI)`` when ``label`` is None, otherwise
            ``(img, label, K, T, ROI)``. ``img``/``mask`` come from
            ``to_tensor``; ``ROI`` is a ``1 x H x W`` tensor; ``label`` is a
            float tensor whose values are the integer ids.
        """
        K = Ks.clone()
        Tc = Ts.clone()
        img_np = np.asarray(img)

        # Sample all random parameters up front: two shifts, rotation, zoom.
        offset = random.randint(-self.random_range, self.random_range)
        offset2 = random.randint(-self.random_range, self.random_range)
        rotation = (random.random() - 0.5) * np.deg2rad(self.random_rotation)
        ration = random.random() * self.random_scale + 1.0

        width, height = img.size

        # Fold the in-plane image rotation into the camera rotation.
        R = torch.Tensor(rodrigues_rotation_matrix(np.array([0, 0, 1]), rotation))
        Tc[0:3, 0:3] = torch.matmul(Tc[0:3, 0:3], R)

        # Input pixels per output pixel (along the height).
        m_scale = height / self.size[0]

        cx, cy = 0, 0
        if mask is not None and self.isTrain:
            mask_np = np.asarray(mask)
            if mask_np.ndim == 3:
                mask_np = mask_np[:, :, 0]
            cy, cx = calc_center(mask_np)
            # Offset of the mask centroid from the image centre.
            cx = cx - width / 2
            cy = cy - height / 2

        translation = (offset * m_scale - cx, offset2 * m_scale - cy)

        if self.is_center:
            # NOTE(review): the nesting of this branch was reconstructed from a
            # listing that lost indentation -- confirm against the upstream file.
            translation = [width / 2 - K[0, 2], height / 2 - K[1, 2]]
            ration = 1.05
            if (self.size[1] / 2) / (self.size[0] * ration / height) - K[0, 2] != translation[0]:
                ration = 1.2
            # Shift the principal point to the centre of the crop taken below.
            translation[1] = (self.size[0] / 2) / (self.size[0] * ration / height) - K[1, 2]
            translation[0] = (self.size[1] / 2) / (self.size[0] * ration / height) - K[0, 2]
            translation = tuple(translation)

        # All warps pivot on the principal point as it is before K is updated.
        center = (K[0, 2], K[1, 2])
        # Top-left crop with the aspect ratio of ``self.size``, shrunk by the zoom.
        crop_size = (int(height / ration), int(height * self.size[1] / ration / self.size[0]))

        img = self._warp(img, rotation, translation, center, crop_size, Image.BICUBIC, self.interpolation)

        # A white frame warped like the image: non-zero where the output still
        # shows source pixels (rotate/affine fill the rest with 0).
        ROI = Image.fromarray(np.uint8(np.ones_like(img_np) * 255.0))
        ROI = self._warp(ROI, rotation, translation, center, crop_size, Image.BICUBIC, self.interpolation)
        ROI = ROI[0:1, :, :]

        if mask is not None:
            mask = self._warp(mask, rotation, translation, center, crop_size, Image.BICUBIC, self.interpolation)

        if label is not None:
            # Label maps are categorical: resample with NEAREST so a boundary
            # between two ids never produces a third id (bicubic filtering would).
            label = Image.fromarray(np.uint8(label))
            label = self._warp(label, rotation, translation, center, crop_size, Image.NEAREST, Image.NEAREST)
            # to_tensor scales 8-bit images to [0, 1]; restore the integer ids.
            label = label * 255.0

        # Move the principal point with the image, then scale K to output pixels.
        K[0, 2] = K[0, 2] + translation[0]
        K[1, 2] = K[1, 2] + translation[1]
        s = self.size[0] * ration / height
        K = K * s
        K[2, 2] = 1

        if label is None:
            return img, K, Tc, mask, ROI
        return img, label, K, Tc, ROI

    def __repr__(self):
        return self.__class__.__name__ + '()'


# ------------------------------------------------------------------------------
# /demo/taekwondo_demo.py:
# Demo: render the taekwondo scene three times -- unchanged, with per-layer
# shifts, and with per-layer scales -- each retimed onto a shared key-frame
# timeline and saved as a video.
import argparse
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import sys
from os import mkdir
import shutil
import torch
import torch.nn.functional as F
import random
from torchvision import utils as vutils
import numpy as np
import imageio
import matplotlib.pyplot as plt

# Resolve the repository root from this file's location (the script lives in
# <repo>/demo) so the project packages import regardless of the working dir.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
from config import cfg
from engine.layered_trainer import do_train
from solver import make_optimizer, WarmupMultiStepLR,build_scheduler
from layers import make_loss
from utils.logger import setup_logger
from layers.RaySamplePoint import RaySamplePoint
from utils import batchify_ray, vis_density
from render import LayeredNeuralRenderer

text = 'This is the program to render the nerf by the specific frame id and layer id, try to get help by using '
parser = argparse.ArgumentParser(description=text)
parser.add_argument('-c', '--config', default='', help='set the config file path to render the network')
parser.add_argument('-g','--gpu', type=int, default=0, help='set gpu id to render the network')
args = parser.parse_args()
if not args.config:
    # cfg.merge_from_file('') would otherwise fail with an unhelpful error.
    parser.error('a config file is required: -c/--config <path>')

torch.cuda.set_device(args.gpu)
torch.autograd.set_detect_anomaly(True)
torch.set_default_dtype(torch.float32)


cfg.merge_from_file(args.config)
cfg.freeze()

key_frames_layer_1 = [21,49,74,87] # performer 1 time line
key_frames_layer_2 = [13,42,80,90] # performer 2 time line
key_frames = [20,50,74,85] # new time line
density_threshold = 0 # Can be set to higher to hide glass
inverse_y_axis = False # For some y-inversed model


def render_variant(save_dir, **renderer_kwargs):
    """Build a renderer, retime both performers onto `key_frames`, render and save a video.

    save_dir: name passed to LayeredNeuralRenderer.set_save_dir.
    renderer_kwargs: extra LayeredNeuralRenderer keyword arguments
        (e.g. shift=..., scale=...); none for the unedited scene.
    """
    renderer = LayeredNeuralRenderer(cfg, **renderer_kwargs)
    renderer.set_save_dir(save_dir)
    renderer.retime_by_key_frames(1, key_frames_layer_1, key_frames)
    renderer.retime_by_key_frames(2, key_frames_layer_2, key_frames)
    renderer.set_fps(25)
    renderer.set_smooth_path_poses(101, around=False)
    renderer.render_path(inverse_y_axis,density_threshold,auto_save=True)
    renderer.save_video()


# One renderer is constructed per variant, and only when it is used.
render_variant('origin')
render_variant('shift', shift=[[0,0,0],[0,2,0],[0,-2,0]])
render_variant('scale', scale=[1,0.75,1.5])

# ---- /demo/walking_demo.py ----
# Demo: render the walking scene unchanged, then with performer layer 1
# hidden, then with layers 1 and 2 hidden, saving one video per variant.
import argparse
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import sys
from os import mkdir
import shutil
import torch
import torch.nn.functional as F
import random
from torchvision import utils as vutils
import numpy as np
import imageio
import matplotlib.pyplot as plt

# Resolve the repository root from this file's location. A CWD-relative '..'
# breaks the imports when the script is launched from the repository root.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
from config import cfg
from engine.layered_trainer import do_train
from modeling import build_model
from solver import make_optimizer, WarmupMultiStepLR,build_scheduler
from layers import make_loss
from utils.logger import setup_logger
from layers.RaySamplePoint import RaySamplePoint
from utils import batchify_ray, vis_density
from render import LayeredNeuralRenderer

text = 'This is the program to render the nerf by the specific frame id and layer id, try to get help by using '
parser = argparse.ArgumentParser(description=text)
parser.add_argument('-c', '--config', default='', help='set the config file path to render the network')
parser.add_argument('-g','--gpu', type=int, default=0, help='set gpu id to render the network')
args = parser.parse_args()
if not args.config:
    # cfg.merge_from_file('') would otherwise fail with an unhelpful error.
    parser.error('a config file is required: -c/--config <path>')

torch.cuda.set_device(args.gpu)
torch.autograd.set_detect_anomaly(True)
torch.set_default_dtype(torch.float32)


cfg.merge_from_file(args.config)
cfg.freeze()

neural_renderer = LayeredNeuralRenderer(cfg)


density_threshold = 20 # Can be set to higher to hide glass
bkgd_density_threshold = 0.8
inverse_y_axis = False # For some y-inversed model

neural_renderer.set_fps(25)
neural_renderer.set_pose_duration(1,14) # [ min , max )
neural_renderer.set_smooth_path_poses(100, around=False)
neural_renderer.set_near(4)
neural_renderer.invert_poses()


def render_and_save(save_dir):
    """Render the camera path with the current layer visibility and save it under `save_dir`."""
    neural_renderer.set_save_dir(save_dir)
    neural_renderer.render_path(inverse_y_axis,density_threshold,bkgd_density_threshold,auto_save=True)
    neural_renderer.save_video()


render_and_save("origin")

neural_renderer.hide_layer(1)
render_and_save("hide_man_1")

neural_renderer.hide_layer(2)
render_and_save("hide_both")
# NOTE(review): this span of the dump holds engine/__init__.py, engine/layered_trainer.py
# and engine/render.py. Summary of the definitions that follow:
# - engine/__init__.py re-exports `render` from engine/render.py.
# - evaluator(val_dataset, model, loss_fn, swriter, epoch): renders val_dataset[0] with
#   layered_batchify_ray under torch.no_grad(), logs per-layer stage1/stage2 rendered /
#   depth / alpha images to `swriter` (layer 0 is tagged "bkgd"), logs the composite and
#   GT images, and returns loss_fn(masked render, masked GT).item(). Depth maps are
#   min-max normalised twice (the second pass is redundant); a constant depth map
#   would divide 0 by 0 and produce NaN.
# - do_train(...): epochs run over range(1 + resume_epoch, max_epochs), so epoch
#   max_epochs itself is never trained -- TODO confirm intended. In epochs < 3 with
#   cfg.MODEL.REMOVE_OUTLIERS set, an alpha-mask loss is built: alpha (index 2) of every
#   layer i != 0 on rays labelled 0 (outliers, weight `penalty`), plus 1 - alpha of
#   layer i on rays labelled i (inliers). It is divided by 100000 when it exceeds
#   rays.shape[0] * 0.0005 and replaced by 0 otherwise.
#   NOTE(review): do_train is CORRUPTED in this dump. The text between `if epoch` and
#   `label-0.5)` (original lines ~191-198) and between `if epoch` and `psnr_thres:`
#   (original lines ~278-327) is missing, apparently stripped as a `<...>` tag. Restore
#   it from the upstream repository before treating this listing as the real code.
#   NOTE(review): torch.cat of the inlier lists happens only inside the first
#   `if outliers_1 != []:` (original lines 234-238). If no layer i != 0 exists, the else
#   branch (lines 252-253) computes `1-inliers_1` on a Python list, which would raise
#   TypeError. Otherwise the second `if outliers_1 != []:` (line 248) compares a Tensor
#   with a list -- verify both paths.
# - val_vis(...): logs the value returned by evaluator -- a single-image loss on
#   val_dataset[0], although the message says "Avg Loss" -- and records it as
#   'Loss/val_loss'.
# - ModelCheckpoint(...): creates output_dir if missing and saves the model / optimizer /
#   scheduler state dicts to layered_rfnr_checkpoint_<epoch>.pt, or to
#   layered_rfnr_checkpoint_<epoch>_<global_step>.pt when global_step != 0.
# - do_evaluate(model, val_dataset): computes and prints MAE / PSNR / SSIM for
#   val_dataset.get_fixed_image(i, j) with i in {0, 1} and j in 1..50, then writes them
#   with np.savetxt to hard-coded absolute paths under
#   /new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/. That directory
#   must exist on the running machine.
# - render(model, K, T, img_size, ROI=None, bboxes=None, only_coarse=False, near_far=None):
#   samples rays inside an ROI mask (the whole image when ROI is None), runs batchify_ray
#   under torch.no_grad(), and scatters the stage2 / stage1 color, depth and alpha back
#   into img_size-shaped maps. It returns (stage2_final, stage1_final). `only_coarse` is
#   never read.
#   NOTE(review): the module comment describes ROI as corner (x, y) then (h, w), but the
#   mask indexes the first (row) axis with ROI[0]/ROI[2] and the second axis with
#   ROI[1]/ROI[3] -- confirm the convention.
-------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .render import render -------------------------------------------------------------------------------- /engine/layered_trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import imageio 9 | import torch 10 | 11 | from utils import layered_batchify_ray, vis_density, metrics 12 | from utils.metrics import * 13 | import numpy as np 14 | import os 15 | import time 16 | 17 | def evaluator(val_dataset, model, loss_fn, swriter, epoch): 18 | 19 | model.eval() 20 | rays, rgbs, labels, image, label, mask, bbox, near_far = val_dataset[0] 21 | 22 | rays = rays.cuda() 23 | rgbs = rgbs.cuda() 24 | bbox = bbox.cuda() 25 | labels = labels.cuda() 26 | color_gt = image.cuda() 27 | mask = mask.cuda() 28 | near_far = near_far.cuda() 29 | 30 | # uv_list = (mask).squeeze().nonzero() 31 | # u_list = uv_list[:,0] 32 | # v_list = uv_list[:,1] 33 | 34 | with torch.no_grad(): 35 | # TODO: Use mask to gain less query of space 36 | stage2, stage1, stage2_layer, stage1_layer, _ = layered_batchify_ray(model, rays, labels, bbox, near_far=near_far) 37 | for i in range(len(stage2_layer)): 38 | color_1 = stage2_layer[i][0] 39 | depth_1 = stage2_layer[i][1] 40 | acc_map_1 = stage2_layer[i][2] 41 | #print(color_1.shape) 42 | #print(depth_1.shape) 43 | #print(acc_map_1.shape) 44 | 45 | color_0 = stage1_layer[i][0] 46 | depth_0 = stage1_layer[i][1] 47 | acc_map_00 = stage1_layer[i][2] 48 | 49 | 50 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1) 51 | depth_img = depth_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 52 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min())
53 | acc_map = acc_map_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 54 | 55 | 56 | color_img_0 = color_0.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1) 57 | depth_img_0 = depth_0.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 58 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min()) 59 | acc_map_0 = acc_map_00.reshape((color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 60 | 61 | 62 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min()) 63 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min()) 64 | 65 | color_img = color_img*((mask).permute(2,0,1).repeat(3,1,1)) 66 | color_gt = color_gt*((mask).permute(2,0,1).repeat(3,1,1)) 67 | 68 | if i == 0: 69 | swriter.add_image('stage2_bkgd/rendered', color_img, epoch) 70 | swriter.add_image('stage2_bkgd/depth', depth_img, epoch) 71 | swriter.add_image('stage2_bkgd/alpha', acc_map, epoch) 72 | 73 | swriter.add_image('stage1_bkgd/rendered', color_img_0, epoch) 74 | swriter.add_image('stage1_bkgd/depth', depth_img_0, epoch) 75 | swriter.add_image('stage1_bkgd/alpha', acc_map_0, epoch) 76 | 77 | else: 78 | swriter.add_image('stage2_layer' +str(i)+ '/rendered', color_img, epoch) 79 | swriter.add_image('stage2_layer' +str(i)+ '/depth', depth_img, epoch) 80 | swriter.add_image('stage2_layer' +str(i)+ '/alpha', acc_map, epoch) 81 | 82 | swriter.add_image('stage1_layer' +str(i)+ '/rendered', color_img_0, epoch) 83 | swriter.add_image('stage1_layer' +str(i)+ '/depth', depth_img_0, epoch) 84 | swriter.add_image('stage1_layer' +str(i)+ '/alpha', acc_map_0, epoch) 85 | 86 | 87 | color_1 = stage2[0] 88 | depth_1 = stage2[1] 89 | acc_map_1 = stage2[2] 90 | #print(color_1.shape) 91 | #print(depth_1.shape) 92 | #print(acc_map_1.shape) 93 | 94 | color_0 = stage1[0] 95 | depth_0 = stage1[1] 96 | acc_map_00 = stage1[2] 97 | 98 | 99 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2),
3) ).permute(2,0,1) 100 | depth_img = depth_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 101 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min()) 102 | acc_map = acc_map_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 103 | 104 | 105 | color_img_0 = color_0.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1) 106 | depth_img_0 = depth_0.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 107 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min()) 108 | acc_map_0 = acc_map_00.reshape((color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1) 109 | 110 | 111 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min()) 112 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min()) 113 | 114 | color_img = color_img*((mask).permute(2,0,1).repeat(3,1,1)) 115 | color_gt = color_gt*((mask).permute(2,0,1).repeat(3,1,1)) 116 | 117 | 118 | swriter.add_image('GT/Label', label * 50, epoch) 119 | swriter.add_image('GT/Image', color_gt, epoch) 120 | 121 | swriter.add_image('stage2/rendered', color_img, epoch) 122 | swriter.add_image('stage2/depth', depth_img, epoch) 123 | swriter.add_image('stage2/alpha', acc_map, epoch) 124 | 125 | swriter.add_image('stage1/rendered', color_img_0, epoch) 126 | swriter.add_image('stage1/depth', depth_img_0, epoch) 127 | swriter.add_image('stage1/alpha', acc_map_0, epoch) 128 | 129 | 130 | return loss_fn(color_img, color_gt).item() 131 | 132 | 133 | def do_train( 134 | cfg, 135 | model, 136 | train_loader, 137 | val_loader, 138 | optimizer, 139 | scheduler, 140 | loss_fn, 141 | swriter, 142 | resume_epoch = 0, 143 | psnr_thres = 100 144 | ): 145 | log_period = cfg.SOLVER.LOG_PERIOD 146 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 147 | output_dir = cfg.OUTPUT_DIR 148 | max_epochs = cfg.SOLVER.MAX_EPOCHS 149 | train_by_pointcloud = cfg.MODEL.TRAIN_BY_POINTCLOUD 150 | use_label =
cfg.DATASETS.USE_LABEL 151 | coarse_stage = cfg.SOLVER.COARSE_STAGE 152 | remove_outliers = cfg.MODEL.REMOVE_OUTLIERS 153 | 154 | 155 | logger = logging.getLogger("LayeredRFRender.%s.train" % cfg.OUTPUT_DIR.split('/')[-1]) 156 | logger.info("Start training") 157 | #global step 158 | global_step = 0 159 | 160 | torch.autograd.set_detect_anomaly(True) 161 | 162 | 163 | for epoch in range(1+resume_epoch,max_epochs): 164 | print('Training Epoch %d...' % epoch) 165 | model.cuda() 166 | 167 | #psnr monitor 168 | psnr_monitor = [] 169 | 170 | #epoch time recordingbatchify_ray 171 | epoch_start = time.time() 172 | for batch_idx, batch in enumerate(train_loader): 173 | 174 | #iteration time recording 175 | iters_start = time.time() 176 | global_step = (epoch -1) * len(train_loader) + batch_idx 177 | 178 | model.train() 179 | optimizer.zero_grad() 180 | 181 | rays, rgbs, labels, bbox_labels, bboxes, near_far = batch 182 | bbox_labels = bbox_labels.cuda() 183 | labels = labels.cuda() 184 | rays = rays.cuda() 185 | rgbs = rgbs.cuda() 186 | bboxes = bboxes.cuda() 187 | near_far = near_far.cuda() 188 | 189 | loss = 0 190 | 191 | if epoch label-0.5) 199 | # print(torch.sum(outliers)) 200 | # print(torch.sum(inliers)) 201 | # out_threshold = 0.5 202 | 203 | predict_rgb_0 = stage1[0] 204 | predict_rgb_1 = stage2[0] 205 | 206 | # predict_rgb_0 = stage1_layer[0][labels.repeat(1,3) != 0] 207 | # predict_rgb_1 = stage2_layer[0][labels.repeat(1,3) != 0] 208 | # rgbs = rgbs[labels.repeat(1,3) != 0] 209 | 210 | # print('ray number is %d' % torch.sum(labels != 0)) 211 | 212 | 213 | # print('layer ray number is %d, bbox layer ray number is %d, outlier number is %d, total number is %d' % (torch.sum(labels != 0), torch.sum(bbox_labels != 0), outliers_1.shape[0], rays.size(0))) 214 | 215 | 216 | loss1 = loss_fn(predict_rgb_0, rgbs) 217 | loss2 = loss_fn(predict_rgb_1, rgbs) 218 | if epoch < 3 and remove_outliers: 219 | outliers_1 = [] 220 | outliers_2 = [] 221 | inliers_1 =[] 222 | inliers_2 =
[] 223 | for i in range(len(stage1_layer)): 224 | 225 | if i != 0: #i!=3 for spiderman basket 226 | outliers_1.append(stage1_layer[i][2][labels == 0]) 227 | outliers_2.append(stage2_layer[i][2][labels == 0]) 228 | # else: 229 | # outliers_1.append(stage1_layer[i][2][labels == 0]) 230 | # outliers_2.append(stage2_layer[i][2][labels == 0]) 231 | inliers_1.append(stage1_layer[i][2][labels == i]) 232 | inliers_2.append(stage2_layer[i][2][labels == i]) 233 | 234 | if outliers_1 != []: 235 | outliers_1 = torch.cat(outliers_1,0) 236 | outliers_2 = torch.cat(outliers_2,0) 237 | inliers_1 = torch.cat(inliers_1,0) 238 | inliers_2 = torch.cat(inliers_2,0) 239 | # print('total ray number is ', stage2[1].shape, ', the inliers number is ',predict_rgb_1.shape) 240 | # loss1 = loss_fn(predict_rgb_0, rgbs) 241 | # loss2 = loss_fn(predict_rgb_1, rgbs) 242 | 243 | #TODO: 100000 should be adapted 244 | scalar_max = 100000 245 | scalar = scalar_max 246 | #penalty 100 will make mask be smaller, 20 will be better, try 10 247 | penalty = 1 248 | if outliers_1 != []: 249 | loss_mask_0 = torch.sum(torch.abs(outliers_1)) * penalty + torch.sum(torch.abs(1-inliers_1)) 250 | loss_mask_1 = torch.sum(torch.abs(outliers_2)) * penalty + torch.sum(torch.abs(1-inliers_2)) 251 | else: 252 | loss_mask_0 = torch.sum(torch.abs(1-inliers_1)) 253 | loss_mask_1 = torch.sum(torch.abs(1-inliers_2)) 254 | 255 | # while loss_mask_1 / scalar < rays.shape[0]/(scalar_max * 2) and loss_mask_1 > 1: 256 | # scalar /= 2 257 | # if scalar <= 1: 258 | # scalar = 1.0 259 | # break 260 | 261 | # num_ray_mask = torch.sum(ray_mask.view(1,-1)).item() 262 | # print('This batch has %d rays in bbox' % num_ray_mask) 263 | 264 | if loss_mask_0 > rays.shape[0] * 0.0005 and remove_outliers: 265 | loss_mask_0 = loss_mask_0 / scalar 266 | else: 267 | loss_mask_0 = torch.Tensor([0]).cuda() 268 | 269 | if loss_mask_1 > rays.shape[0] * 0.0005 and remove_outliers: 270 | loss_mask_1 = loss_mask_1 / scalar 271 | else: 272 | loss_mask_1 =
torch.Tensor([0]).cuda() 273 | else: 274 | loss_mask_0 = torch.Tensor([0]).cuda() 275 | loss_mask_1 = torch.Tensor([0]).cuda() 276 | 277 | 278 | if epoch psnr_thres: 328 | logger.info("The Mean Psnr of Epoch: {:.3f}, greater than threshold: {:.3f}, Training Stopped".format(psnr_monitor, psnr_thres)) 329 | break 330 | else: 331 | logger.info("The Mean Psnr of Epoch: {:.3f}, less than threshold: {:.3f}, Continue to Training".format(psnr_monitor, psnr_thres)) 332 | 333 | def val_vis(val_loader,model ,loss_fn, swriter, logger, epoch): 334 | 335 | 336 | avg_loss = evaluator(val_loader, model, loss_fn, swriter,epoch) 337 | logger.info("Validation Results - Epoch: {} Avg Loss: {:.3f}" 338 | .format(epoch, avg_loss) 339 | ) 340 | swriter.add_scalar('Loss/val_loss',avg_loss, epoch) 341 | 342 | def ModelCheckpoint(model, optimizer, scheduler, output_dir, epoch, global_step = 0): 343 | # model,optimizer,scheduler saving 344 | if not os.path.exists(output_dir): 345 | os.makedirs(output_dir) 346 | if global_step == 0: 347 | torch.save({'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict()}, 348 | os.path.join(output_dir,'layered_rfnr_checkpoint_%d.pt' % epoch)) 349 | else: 350 | torch.save({'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict()}, 351 | os.path.join(output_dir,'layered_rfnr_checkpoint_%d_%d.pt' % (epoch,global_step))) 352 | # torch.save(model.state_dict(), os.path.join(output_dir, 'spacenet_epoch_%d.pth'%epoch)) 353 | # torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer_epoch_%d.pth'%epoch)) 354 | # torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler_epoch_%d.pth'%epoch)) 355 | 356 | 357 | def do_evaluate( model,val_dataset): 358 | mae_list = [] 359 | psnr_list = [] 360 | ssim_list = [] 361 | 362 | 363 | model.eval() 364 | with torch.no_grad(): 365 | for i in range(2): 366 | for j in range(50): 367 | rays, rgbs, labels, image, label, mask,
bbox, near_far = val_dataset.get_fixed_image(i,j+1) 368 | 369 | rays = rays.cuda() 370 | rgbs = rgbs.cuda() 371 | bbox = bbox.cuda() 372 | labels = labels.cuda() 373 | color_gt = image.cuda() 374 | mask = mask.cuda() 375 | near_far = near_far.cuda() 376 | 377 | # uv_list = (mask).squeeze().nonzero() 378 | # u_list = uv_list[:,0] 379 | # v_list = uv_list[:,1] 380 | 381 | 382 | # TODO: Use mask to gain less query of space 383 | stage2, _, _, _, _ = layered_batchify_ray(model, rays, labels, bbox, near_far=near_far) 384 | 385 | color_1 = stage2[0] 386 | depth_1 = stage2[1] 387 | acc_map_1 = stage2[2] 388 | #print(color_1.shape) 389 | #print(depth_1.shape) 390 | #print(acc_map_1.shape) 391 | 392 | 393 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1) 394 | 395 | mae = metrics.mae(color_img,color_gt) 396 | psnr = metrics.psnr(color_img,color_gt) 397 | ssim = metrics.ssim(color_img,color_gt) 398 | print(color_img.shape) 399 | print(color_gt.shape) 400 | 401 | #imageio.imwrite("/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/walking/"+str(j+1)+".png", color_img.transpose(0,2).transpose(0,1).cpu()) 402 | 403 | 404 | print("mae:",mae) 405 | print("psnr:",psnr) 406 | print("ssim:",ssim) 407 | mae_list.append(mae) 408 | psnr_list.append(psnr) 409 | ssim_list.append(ssim) 410 | mae_list = np.array(mae_list) 411 | psnr_list = np.array(psnr_list) 412 | ssim_list = np.array(ssim_list) 413 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/mae.out',mae_list) 414 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/psnr.out',psnr_list) 415 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/ssim.out',ssim_list) 416 | avg_mae = np.mean(np.array(mae_list)) 417 | avg_psnr = np.mean(np.array(psnr_list)) 418 | avg_ssim = np.mean(np.array(ssim_list)) 419 | print("avg_mae:",avg_mae) 420 | print("avg_psnr:",avg_psnr) 421 | print("avg_ssim:",avg_ssim) 422 |
#print(color_1.shape) 423 | #print(color_gt.shape) 424 | #print(metrics.psnr(color_img, color_gt)) 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | -------------------------------------------------------------------------------- /engine/render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import batchify_ray, vis_density, ray_sampling 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | 9 | ''' 10 | Sample rays from views (and images) with/without masks 11 | 12 | -------------------------- 13 | INPUT Tensors 14 | K: intrinsics of camera (3,3) 15 | T: extrinsic of camera (4,4) 16 | image_size: the size of image [H,W] 17 | 18 | ROI: 2D ROI bboxes (4) left up corner(x,y) followed the height and width (h,w) 19 | 20 | masks:(M,H,W) 21 | ------------------- 22 | OUPUT: 23 | list of rays: (N,6) dirs(3) + pos(3) 24 | RGB: (N,C) 25 | ''' 26 | 27 | 28 | 29 | 30 | def render(model, K,T,img_size,ROI = None, bboxes = None,only_coarse = False,near_far=None): 31 | model.eval() 32 | assert not (bboxes is None and near_far is None), ' either bbox or near_far should not be None.'
33 | mask = torch.ones(img_size[0],img_size[1]) 34 | if ROI is not None: 35 | mask = torch.zeros(img_size[0],img_size[1]) 36 | mask[ROI[0]:ROI[0]+ROI[2], ROI[1]:ROI[1]+ROI[3]] = 1.0 37 | rays,_ = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), img_size, masks=mask.unsqueeze(0)) 38 | 39 | if bboxes is not None: 40 | bboxes = bboxes.unsqueeze(0).repeat(rays.size(0),1,1) 41 | 42 | with torch.no_grad(): 43 | stage2, stage1,_ = batchify_ray(model, rays, bboxes,near_far = near_far) 44 | 45 | 46 | rgb = torch.zeros(img_size[0],img_size[1], 3, device = stage2[0].device) 47 | rgb[mask>0.5,:] = stage2[0] 48 | 49 | depth = torch.zeros(img_size[0],img_size[1],1, device = stage2[1].device) 50 | depth[mask>0.5,:] = stage2[1] 51 | 52 | alpha = torch.zeros(img_size[0],img_size[1],1, device = stage2[2].device) 53 | alpha[mask>0.5,:] = stage2[2] 54 | 55 | stage2_final = [None]*3 56 | stage2_final[0] = rgb.reshape(img_size[0],img_size[1], 3) 57 | stage2_final[1] = depth.reshape(img_size[0],img_size[1]) 58 | stage2_final[2] = alpha.reshape(img_size[0],img_size[1]) 59 | 60 | 61 | rgb = torch.zeros(img_size[0],img_size[1], 3, device = stage1[0].device) 62 | rgb[mask>0.5,:] = stage1[0] 63 | 64 | depth = torch.zeros(img_size[0],img_size[1],1, device = stage1[1].device) 65 | depth[mask>0.5,:] = stage1[1] 66 | 67 | alpha = torch.zeros(img_size[0],img_size[1],1, device = stage1[2].device) 68 | alpha[mask>0.5,:] = stage1[2] 69 | 70 | stage1_final = [None]*3 71 | stage1_final[0] = rgb.reshape(img_size[0],img_size[1], 3) 72 | stage1_final[1] = depth.reshape(img_size[0],img_size[1]) 73 | stage1_final[2] = alpha.reshape(img_size[0],img_size[1]) 74 | 75 | 76 | 77 | return stage2_final, stage1_final 78 | -------------------------------------------------------------------------------- /images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/images/teaser.jpg
# ---- /layers/RaySamplePoint-1.py ----
"""Ray sampling layers: single-bounding-box variant of layers/RaySamplePoint.py.

NOTE(review): the hyphen in the file name ``RaySamplePoint-1.py`` means this
module cannot be imported with a plain ``import`` statement. It looks like a
leftover copy of ``layers/RaySamplePoint.py``, which takes N*L*8*3 boxes
instead of N*8*3. Confirm whether it is still needed.
"""
import numpy as np
import torch.nn.functional as F
from torch import nn
import torch
from layers.render_layer import gen_weight
import pdb

# Added to each direction component before dividing. A ray parallel to a face
# then gets a huge (effectively infinite) hit distance instead of a
# division-by-zero.
_EPS = np.finfo(float).eps.item()


def intersection(rays, bbox):
    """Intersect every ray with its axis-aligned bounding box.

    :param rays: N*6 tensor; columns 0-2 are the origin, columns 3-5 the direction.
    :param bbox: N*8*3 tensor of box corners. Corner 0 gives the lower plane and
        corner 6 the upper plane on each axis; the other corners bound the face
        rectangles exactly as indexed below.
    :return: N*2 tensor holding the two largest face-hit distances per ray, in
        descending order. Faces that are not hit contribute -1e3.
    """
    n = rays.shape[0]
    rays_o = rays[:, :3]
    rays_d = rays[:, 3:6]

    def face_t(plane, axis):
        # Ray parameter t at which each ray crosses the plane coord[axis] == plane.
        return ((plane - rays[:, axis]) / (rays[:, 3 + axis] + _EPS)).reshape((n, 1))

    left_t = face_t(bbox[:, 0, 0], 0)
    right_t = face_t(bbox[:, 6, 0], 0)
    front_t = face_t(bbox[:, 0, 1], 1)
    back_t = face_t(bbox[:, 6, 1], 1)
    bottom_t = face_t(bbox[:, 0, 2], 2)
    up_t = face_t(bbox[:, 6, 2], 2)

    left_point = left_t * rays_d + rays_o
    right_point = right_t * rays_d + rays_o
    front_point = front_t * rays_d + rays_o
    back_point = back_t * rays_d + rays_o
    bottom_point = bottom_t * rays_d + rays_o
    up_point = up_t * rays_d + rays_o

    # x faces: the hit point must lie inside the face's (y, z) rectangle.
    left_mask = (left_point[:, 1] >= bbox[:, 0, 1]) & (left_point[:, 1] <= bbox[:, 7, 1]) \
        & (left_point[:, 2] >= bbox[:, 0, 2]) & (left_point[:, 2] <= bbox[:, 7, 2])
    right_mask = (right_point[:, 1] >= bbox[:, 1, 1]) & (right_point[:, 1] <= bbox[:, 6, 1]) \
        & (right_point[:, 2] >= bbox[:, 1, 2]) & (right_point[:, 2] <= bbox[:, 6, 2])

    # y faces: compare x and z.
    front_mask = (front_point[:, 0] >= bbox[:, 0, 0]) & (front_point[:, 0] <= bbox[:, 5, 0]) \
        & (front_point[:, 2] >= bbox[:, 0, 2]) & (front_point[:, 2] <= bbox[:, 5, 2])
    back_mask = (back_point[:, 0] >= bbox[:, 3, 0]) & (back_point[:, 0] <= bbox[:, 6, 0]) \
        & (back_point[:, 2] >= bbox[:, 3, 2]) & (back_point[:, 2] <= bbox[:, 6, 2])

    # z faces: compare x and y.
    bottom_mask = (bottom_point[:, 0] >= bbox[:, 0, 0]) & (bottom_point[:, 0] <= bbox[:, 2, 0]) \
        & (bottom_point[:, 1] >= bbox[:, 0, 1]) & (bottom_point[:, 1] <= bbox[:, 2, 1])
    up_mask = (up_point[:, 0] >= bbox[:, 4, 0]) & (up_point[:, 0] <= bbox[:, 6, 0]) \
        & (up_point[:, 1] >= bbox[:, 4, 1]) & (up_point[:, 1] <= bbox[:, 6, 1])

    # One column per face; -1e3 marks "not hit" so it never wins topk.
    tlist = torch.full((n, 6), -1e3, dtype=rays.dtype, device=rays.device)
    tlist[left_mask, 0] = left_t[left_mask].reshape((-1,))
    tlist[right_mask, 1] = right_t[right_mask].reshape((-1,))
    tlist[front_mask, 2] = front_t[front_mask].reshape((-1,))
    tlist[back_mask, 3] = back_t[back_mask].reshape((-1,))
    tlist[bottom_mask, 4] = bottom_t[bottom_mask].reshape((-1,))
    tlist[up_mask, 5] = up_t[up_mask].reshape((-1,))
    return tlist.topk(k=2, dim=-1)[0]


class RaySamplePoint(nn.Module):
    """Stratified sampling of `coarse_num` depths between a ray's two box hits."""

    def __init__(self, coarse_num=64):
        super(RaySamplePoint, self).__init__()
        self.coarse_num = coarse_num

    def forward(self, rays, bbox, pdf=None, method='coarse'):
        '''
        :param rays: N*6
        :param bbox: N*8*3, corners 0,1,2,3 bottom and 4,5,6,7 up
        :param pdf: N*coarse_num weights (not used by this implementation)
        :param method: sampling method name (not used by this implementation)
        :return: sample depths N*C*1, sample points N*C*3, and a length-N bool
            mask that is False where the box interval is (near-)empty
        '''
        n = rays.shape[0]
        sample_num = self.coarse_num
        bin_range = torch.arange(0, sample_num, device=rays.device).reshape((1, sample_num)).float()

        tlist = intersection(rays, bbox)
        # topk sorts descending: column 1 is the nearer hit, column 0 the farther.
        start = tlist[:, 1].reshape((n, 1))
        end = tlist[:, 0].reshape((n, 1))

        # One uniform jitter per bin gives stratified samples between start and end.
        bin_sample = torch.rand((n, sample_num), device=rays.device)
        bin_width = (end - start) / sample_num
        sample_t = (bin_range + bin_sample) * bin_width + start
        sample_point = sample_t.unsqueeze(-1) * rays[:, 3:6].unsqueeze(1) + rays[:, :3].unsqueeze(1)
        # squeeze(-1) rather than squeeze(): keeps the mask 1-D even when n == 1.
        mask = (torch.abs(bin_width) > 1e-5).squeeze(-1)
        return sample_t.unsqueeze(-1), sample_point, mask


class RayDistributedSamplePoint(nn.Module):
    """Inverse-CDF sampling of `fine_num` depths from the weights of a coarse pass."""

    def __init__(self, fine_num=10):
        super(RayDistributedSamplePoint, self).__init__()
        self.fine_num = fine_num

    def forward(self, rays, depth, density, noise=0.0):
        '''
        :param rays: origin read from rays[:, :3], direction from rays[:, 3:6]
            (the original docstring said N*L*6; NOTE(review): the indexing
            suggests N*6 -- confirm)
        :param depth: N*L*1 coarse sample depths
        :param density: N*L*1 coarse densities
        :param noise: forwarded to gen_weight
        :return: fine depths N*fine_num and fine sample points N*fine_num*3
        '''
        sample_num = self.fine_num
        n = density.shape[0]

        weights = gen_weight(depth, density, noise=noise)  # N*L
        weights += 1e-5
        # reshape(n, -1) instead of squeeze() so a batch of one keeps its batch dim.
        bins = depth.reshape(n, -1)

        weights = weights[:, 1:].reshape(n, -1)  # N*(L-1)
        pdf = weights / torch.sum(weights, dim=1, keepdim=True)
        cdf = torch.cumsum(pdf, dim=1)
        # The leading zero must share cdf's device/dtype; a bare torch.zeros((n, 1))
        # is created on the CPU and torch.cat then fails for CUDA inputs.
        cdf_s = torch.cat((torch.zeros((n, 1), dtype=cdf.dtype, device=cdf.device), cdf), dim=1)
        fine_bin = torch.linspace(0, 1, sample_num, device=density.device).reshape((1, sample_num)).repeat((n, 1))
        # dtype=torch.long keeps the index on fine_bin's device; the former
        # `.type(torch.LongTensor)` moved it to the CPU and broke torch.gather on CUDA.
        above_index = torch.ones_like(fine_bin, dtype=torch.long)
        for i in range(cdf.shape[1]):
            in_bin = (fine_bin > cdf_s[:, i].reshape((n, 1))) & (fine_bin <= cdf[:, i].reshape((n, 1)))
            above_index[in_bin] = i + 1
        below_index = (above_index - 1).clamp(min=0)
        sn_below = torch.gather(bins, dim=1, index=below_index)
        sn_above = torch.gather(bins, dim=1, index=above_index)
        cdf_below = torch.gather(cdf_s, dim=1, index=below_index)
        cdf_above = torch.gather(cdf_s, dim=1, index=above_index)
        dnorm = cdf_above - cdf_below
        # Avoid dividing by a (near-)zero-width CDF interval.
        dnorm = torch.where(dnorm < 1e-5, torch.ones_like(dnorm), dnorm)
        d = (fine_bin - cdf_below) / dnorm
        fine_t = (sn_above - sn_below) * d + sn_below
        fine_sample_point = fine_t.unsqueeze(-1) * rays[:, 3:6].unsqueeze(1) + rays[:, :3].unsqueeze(1)
        return fine_t, fine_sample_point


class RaySamplePoint_Near_Far(nn.Module):
    """Stratified sampling of `sample_num` depths between per-ray near/far bounds."""

    def __init__(self, sample_num=75):
        super(RaySamplePoint_Near_Far, self).__init__()
        self.sample_num = sample_num

    def forward(self, rays, near_far):
        '''
        :param rays: N*6, origin in columns 0-2 and direction in columns 3-5
        :param near_far: N*2 (at least two columns): near in column 0, far in column 1
        :return: sample depths N*S*1 and sample points N*S*3 (S = sample_num)
        '''
        ray_o = rays[:, :3]
        ray_d = rays[:, 3:6]

        t_vals = torch.linspace(0., 1., steps=self.sample_num, device=rays.device)
        near = near_far[:, 0:1]
        far = near_far[:, 1:2]
        # Evenly spaced depths between near and far, broadcast to N*S.
        z_vals = near * (1. - t_vals) + far * t_vals

        # Jitter each depth uniformly inside its stratum (midpoint to midpoint).
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        t_rand = torch.rand(z_vals.size(), device=rays.device)
        z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[..., None, :] + ray_d[..., None, :] * z_vals[..., :, None]
        return z_vals.unsqueeze(-1), pts



# --------------------------------------------------------------------------------
# /layers/RaySamplePoint.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch.nn.functional as F
from torch import nn
import torch
from layers.render_layer import gen_weight
import pdb

# Added to every ray-direction component before dividing, so a ray parallel to a
# face plane gets a huge (effectively infinite) t instead of a division by zero.
# A single Python-float constant keeps all six faces consistent.
_EPS = np.finfo(float).eps.item()


def intersection(rays, bbox):
    """Intersect each ray with its axis-aligned box (slab test).

    :param rays: (N, 6) tensor; columns 0:3 are ray origins, 3:6 directions.
    :param bbox: (N, 8, 3) box corners. Corner 0 supplies the minimum face
        coordinates and corner 6 the maximum ones; other corners are only used
        for the in-face range checks below.
    :return: (N, 2) tensor holding the two largest face-hit parameters ``t`` in
        descending order: column 0 is the far hit, column 1 the near hit.
        Faces the ray does not hit contribute ``-1e3``.
    """
    n = rays.shape[0]
    rays_o = rays[:, :3]
    rays_d = rays[:, 3:6]

    # Face planes: x = const (left/right), y = const (front/back),
    # z = const (bottom/up).
    left_face = bbox[:, 0, 0]
    right_face = bbox[:, 6, 0]
    front_face = bbox[:, 0, 1]
    back_face = bbox[:, 6, 1]
    bottom_face = bbox[:, 0, 2]
    up_face = bbox[:, 6, 2]

    # Ray parameter at which each face plane is crossed, shaped (N, 1).
    left_t = ((left_face - rays[:, 0]) / (rays[:, 3] + _EPS)).reshape((n, 1))
    right_t = ((right_face - rays[:, 0]) / (rays[:, 3] + _EPS)).reshape((n, 1))
    front_t = ((front_face - rays[:, 1]) / (rays[:, 4] + _EPS)).reshape((n, 1))
    back_t = ((back_face - rays[:, 1]) / (rays[:, 4] + _EPS)).reshape((n, 1))
    bottom_t = ((bottom_face - rays[:, 2]) / (rays[:, 5] + _EPS)).reshape((n, 1))
    up_t = ((up_face - rays[:, 2]) / (rays[:, 5] + _EPS)).reshape((n, 1))

    left_point = left_t * rays_d + rays_o
    right_point = right_t * rays_d + rays_o
    front_point = front_t * rays_d + rays_o
    back_point = back_t * rays_d + rays_o
    bottom_point = bottom_t * rays_d + rays_o
    up_point = up_t * rays_d + rays_o

    # A plane crossing only counts when the crossing point lies on the face.
    # x-faces: compare y, z
    left_mask = (left_point[:, 1] >= bbox[:, 0, 1]) & (left_point[:, 1] <= bbox[:, 7, 1]) \
        & (left_point[:, 2] >= bbox[:, 0, 2]) & (left_point[:, 2] <= bbox[:, 7, 2])
    right_mask = (right_point[:, 1] >= bbox[:, 1, 1]) & (right_point[:, 1] <= bbox[:, 6, 1]) \
        & (right_point[:, 2] >= bbox[:, 1, 2]) & (right_point[:, 2] <= bbox[:, 6, 2])

    # y-faces: compare x, z
    front_mask = (front_point[:, 0] >= bbox[:, 0, 0]) & (front_point[:, 0] <= bbox[:, 5, 0]) \
        & (front_point[:, 2] >= bbox[:, 0, 2]) & (front_point[:, 2] <= bbox[:, 5, 2])
    back_mask = (back_point[:, 0] >= bbox[:, 3, 0]) & (back_point[:, 0] <= bbox[:, 6, 0]) \
        & (back_point[:, 2] >= bbox[:, 3, 2]) & (back_point[:, 2] <= bbox[:, 6, 2])

    # z-faces: compare x, y
    bottom_mask = (bottom_point[:, 0] >= bbox[:, 0, 0]) & (bottom_point[:, 0] <= bbox[:, 2, 0]) \
        & (bottom_point[:, 1] >= bbox[:, 0, 1]) & (bottom_point[:, 1] <= bbox[:, 2, 1])
    up_mask = (up_point[:, 0] >= bbox[:, 4, 0]) & (up_point[:, 0] <= bbox[:, 6, 0]) \
        & (up_point[:, 1] >= bbox[:, 4, 1]) & (up_point[:, 1] <= bbox[:, 6, 1])

    # One column per face; -1e3 marks "not hit" so it sorts below real hits.
    tlist = torch.full((n, 6), -1e3, dtype=rays.dtype, device=rays.device)
    tlist[left_mask, 0] = left_t[left_mask].reshape((-1,))
    tlist[right_mask, 1] = right_t[right_mask].reshape((-1,))
    tlist[front_mask, 2] = front_t[front_mask].reshape((-1,))
    tlist[back_mask, 3] = back_t[back_mask].reshape((-1,))
    tlist[bottom_mask, 4] = bottom_t[bottom_mask].reshape((-1,))
    tlist[up_mask, 5] = up_t[up_mask].reshape((-1,))

    return tlist.topk(k=2, dim=-1)[0]


class RaySamplePoint(nn.Module):
    """Stratified sampling of ``coarse_num`` points inside each layer's box."""

    def __init__(self, coarse_num=64):
        super(RaySamplePoint, self).__init__()
        self.coarse_num = coarse_num

    def forward(self, rays, bbox, pdf=None, method='coarse'):
        '''
        :param rays: N*6, origins in columns 0:3 and directions in 3:6
        :param bbox: N*L*8*3 per-layer box corners, 0,1,2,3 bottom 4,5,6,7 up
        :param pdf: n*coarse_num sample weights (unused; coarse sampling only)
        :param method: unused; only coarse sampling is implemented
        :return: three lists with one entry per layer:
            sample_t      N*C*1 ray parameters of the samples
            sample_point  N*C*3 sample positions
            mask          N     True where the ray overlaps the box
        '''
        n = rays.shape[0]
        num_layers = bbox.shape[1]
        sample_num = self.coarse_num

        # Bin offsets 0..C-1 are identical for every layer, so build them once.
        bin_range = torch.arange(0, sample_num, device=rays.device).reshape((1, sample_num)).float()
        ray_o = rays[:, :3].unsqueeze(1)
        ray_d = rays[:, 3:6].unsqueeze(1)

        sample_t = []
        sample_point = []
        mask = []
        for i in range(num_layers):
            tlist = intersection(rays, bbox[:, i, :, :])
            start = tlist[:, 1].reshape((n, 1))
            if i == 0:
                # Layer 0: rays whose origin is inside the box start at t = 0.
                start[start <= 0] = 0
            end = tlist[:, 0].reshape((n, 1))

            # One uniform jitter per bin -> stratified samples in [start, end].
            bin_sample = torch.rand((n, sample_num), device=rays.device)
            bin_width = (end - start) / sample_num

            t = ((bin_range + bin_sample) * bin_width + start).unsqueeze(-1)
            sample_t.append(t)
            sample_point.append(t * ray_d + ray_o)

            # squeeze(-1) (not squeeze()) so a single ray still yields shape (1,).
            mask.append((torch.abs(bin_width) > 1e-5).squeeze(-1))

        return sample_t, sample_point, mask


class RayDistributedSamplePoint(nn.Module):
    """Importance (inverse-CDF) sampling of ``fine_num`` depths per ray."""

    def __init__(self, fine_num=10):
        super(RayDistributedSamplePoint, self).__init__()
        self.fine_num = fine_num

    def forward(self, rays, depth, density, noise=0.0):
        '''
        :param rays: N*6, origins in columns 0:3 and directions in 3:6
        :param depth: N*L*1 depths of the coarse samples (assumed sorted)
        :param density: N*L*1 raw densities of the coarse samples
        :param noise: std of Gaussian noise added to the raw densities
        :return: fine_t N*F sampled depths, fine_sample_point N*F*3 positions
        '''
        sample_num = self.fine_num
        n = density.shape[0]
        device = density.device

        # gen_weight takes (raw density, per-sample distance). Derive the
        # distances from the depths the same way VolumeRenderer does, padding
        # the last interval with a very large value.
        # NOTE(review): the previous call gen_weight(depth, density, noise=...)
        # targeted an older signature and raised TypeError; confirm this
        # adaptation matches the intended weighting.
        bins = depth.reshape(n, -1)  # N*L
        delta = bins[:, 1:] - bins[:, :-1]
        delta = torch.cat([delta, torch.full_like(delta[:, :1], 1e10)], dim=-1)
        sigma = density
        if noise > 0.:
            sigma = sigma + torch.randn_like(sigma) * noise

        weights = gen_weight(sigma, delta) + 1e-5  # N*L

        weights = weights[:, 1:]  # N*(L-1)
        pdf = weights / torch.sum(weights, dim=1, keepdim=True)
        cdf = torch.cumsum(pdf, dim=1)
        # Keep every helper tensor on the input device/dtype; CPU-only tensors
        # here used to break CUDA inputs in cat/gather.
        cdf_s = torch.cat((torch.zeros((n, 1), dtype=cdf.dtype, device=cdf.device), cdf), dim=1)
        fine_bin = torch.linspace(0, 1, sample_num, device=device).reshape((1, sample_num)).repeat((n, 1))

        # For each fine sample, find the CDF interval it falls into.
        above_index = torch.ones_like(fine_bin, dtype=torch.long)
        for i in range(cdf.shape[1]):
            in_bin = (fine_bin > cdf_s[:, i:i + 1]) & (fine_bin <= cdf[:, i:i + 1])
            above_index[in_bin] = i + 1
        below_index = (above_index - 1).clamp(min=0)

        sn_below = torch.gather(bins, dim=1, index=below_index)
        sn_above = torch.gather(bins, dim=1, index=above_index)
        cdf_below = torch.gather(cdf_s, dim=1, index=below_index)
        cdf_above = torch.gather(cdf_s, dim=1, index=above_index)

        # Linear interpolation inside the interval; guard near-empty intervals.
        dnorm = cdf_above - cdf_below
        dnorm = torch.where(dnorm < 1e-5, torch.ones_like(dnorm), dnorm)
        d = (fine_bin - cdf_below) / dnorm
        fine_t = (sn_above - sn_below) * d + sn_below
        fine_sample_point = fine_t.unsqueeze(-1) * rays[:, 3:6].unsqueeze(1) + rays[:, :3].unsqueeze(1)
        return fine_t, fine_sample_point


class RaySamplePoint_Near_Far(nn.Module):
    """Stratified sampling of ``sample_num`` depths between per-ray near/far."""

    def __init__(self, sample_num=75):
        super(RaySamplePoint_Near_Far, self).__init__()
        self.sample_num = sample_num

    def forward(self, rays, near_far):
        '''
        :param rays: N*6, origins in columns 0:3 and directions in 3:6
        :param near_far: N*2, per-ray near (column 0) and far (column 1) bounds
        :return: z_vals N*C*1 sample depths, pts N*C*3 sample positions
        '''
        ray_o = rays[:, :3]
        ray_d = rays[:, 3:6]
        near = near_far[:, 0:1]
        far = near_far[:, 1:2]

        # Evenly spaced depths between near and far, broadcast to N*C ...
        t_vals = torch.linspace(0., 1., steps=self.sample_num, device=rays.device)
        z_vals = near * (1. - t_vals) + far * t_vals

        # ... then jittered uniformly inside each stratum.
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        t_rand = torch.rand(z_vals.size(), device=rays.device)
        z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[..., None, :] + ray_d[..., None, :] * z_vals[..., :, None]
        return z_vals.unsqueeze(-1), pts


# --------------------------------------------------------------------------------
# /layers/__init__.py:
# --------------------------------------------------------------------------------
from .RaySamplePoint import RaySamplePoint,RaySamplePoint_Near_Far
from .render_layer import VolumeRenderer
from .loss import make_loss
# --------------------------------------------------------------------------------
# /layers/__pycache__/RaySamplePoint.cpython-36.pyc:
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint.cpython-36.pyc
# /layers/__pycache__/RaySamplePoint.cpython-38.pyc:
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint.cpython-38.pyc
# /layers/__pycache__/RaySamplePoint1.cpython-38.pyc:
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint1.cpython-38.pyc
# /layers/__pycache__/__init__.cpython-36.pyc:
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /layers/__pycache__/camera_transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/camera_transform.cpython-38.pyc -------------------------------------------------------------------------------- /layers/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /layers/__pycache__/render_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/render_layer.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/render_layer.cpython-38.pyc: -------------------------------------------------------------------------------- 
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/render_layer.cpython-38.pyc
# --------------------------------------------------------------------------------
# /layers/camera_transform.py:
# --------------------------------------------------------------------------------
import torch
# NOTE(review): this enables autograd anomaly detection globally as soon as the
# module is imported, which slows every backward pass; it looks like a
# debugging leftover but is kept because other code may rely on it.
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import numpy as np


def corrupt_cameras(cam_poses, offset=(-0.1, 0.1), rotation=(-5, 5)):
    """Randomly perturb camera poses with uniform translation/rotation noise.

    :param cam_poses: array of poses indexed as [N, 3+, 4+]; rotation is read
        from [:, :3, :3] and translation from [:, :3, 3].
    :param offset: (low, high) range of the per-axis translation noise.
    :param rotation: (low, high) range, in degrees, of the rotation noise about
        each of the x, y and z axes.
    :return: (N, 3, 4) array ``[R | t]`` of perturbed poses.
    """
    rand_t = np.random.rand(cam_poses.shape[0], 3)
    perturb_t = (1 - rand_t) * offset[0] + rand_t * offset[1]
    tr = cam_poses[:, :3, 3] + perturb_t
    tr = tr[..., None]  # [N, 3, 1]

    rand_r = np.random.rand(cam_poses.shape[0], 3)
    rand_r = (1 - rand_r) * rotation[0] + rand_r * rotation[1]
    rand_r = np.deg2rad(rand_r)

    # Pre-compute rotation matrices
    Rx = np.stack((
        np.ones_like(rand_r[:, 0]), np.zeros_like(rand_r[:, 0]), np.zeros_like(rand_r[:, 0]),
        np.zeros_like(rand_r[:, 0]), np.cos(rand_r[:, 0]), -np.sin(rand_r[:, 0]),
        np.zeros_like(rand_r[:, 0]), np.sin(rand_r[:, 0]), np.cos(rand_r[:, 0])
    ), axis=1).reshape(-1, 3, 3)

    Ry = np.stack((
        np.cos(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]), np.sin(rand_r[:, 1]),
        np.zeros_like(rand_r[:, 1]), np.ones_like(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]),
        -np.sin(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]), np.cos(rand_r[:, 1])
    ), axis=1).reshape(-1, 3, 3)

    Rz = np.stack((
        np.cos(rand_r[:, 2]), -np.sin(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]),
        np.sin(rand_r[:, 2]), np.cos(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]),
        np.zeros_like(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]), np.ones_like(rand_r[:, 2])
    ), axis=1).reshape(-1, 3, 3)

    # Apply rotation sequentially (Rz first, then Ry, then Rx on the left).
    rot = cam_poses[:, :3, :3]  # [N, 3, 3]
    for perturb_r in [Rz, Ry, Rx]:
        rot = np.matmul(perturb_r, rot)

    return np.concatenate([rot, tr], axis=-1)


# Camera Transformation Layer
class CameraTransformer(nn.Module):
    """Per-camera rotation (quaternion) and translation correction of rays."""

    def __init__(self, num_cams, trainable=False):
        """
        num_cams: number of training cameras
        trainable: whether the pose corrections are optimised (nn.Parameter)
            or kept fixed (registered buffers)
        """
        super(CameraTransformer, self).__init__()

        self.trainable = trainable

        # Identity quaternion stored as (x, y, z, w) and zero offset per camera.
        identity_quat = torch.Tensor([0, 0, 0, 1]).repeat((num_cams, 1))
        identity_off = torch.Tensor([0, 0, 0]).repeat((num_cams, 1))
        if self.trainable:
            self.rvec = nn.Parameter(torch.Tensor(identity_quat))  # [N_cameras, 4]
            self.tvec = nn.Parameter(torch.Tensor(identity_off))  # [N_cameras, 3]
        else:
            self.register_buffer('rvec', torch.Tensor(identity_quat))  # [N_cameras, 4]
            self.register_buffer('tvec', torch.Tensor(identity_off))  # [N_cameras, 3]

        print("Create %d %s camera transformer" % (num_cams, 'trainable' if self.rvec.requires_grad else 'non-trainable'))

    def rot_mats(self):
        """Return [N_cameras, 3, 3] rotation matrices from the (x, y, z, w) quaternions.

        The quaternions are normalised first (the 1e-5 keeps the norm non-zero).
        """
        theta = torch.sqrt(1e-5 + torch.sum(self.rvec ** 2, dim=1))
        rvec = self.rvec / theta[:, None]
        return torch.stack((
            1. - 2. * rvec[:, 1] ** 2 - 2. * rvec[:, 2] ** 2,
            2. * (rvec[:, 0] * rvec[:, 1] - rvec[:, 2] * rvec[:, 3]),
            2. * (rvec[:, 0] * rvec[:, 2] + rvec[:, 1] * rvec[:, 3]),

            2. * (rvec[:, 0] * rvec[:, 1] + rvec[:, 2] * rvec[:, 3]),
            1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 2] ** 2,
            2. * (rvec[:, 1] * rvec[:, 2] - rvec[:, 0] * rvec[:, 3]),

            2. * (rvec[:, 0] * rvec[:, 2] - rvec[:, 1] * rvec[:, 3]),
            2. * (rvec[:, 0] * rvec[:, 3] + rvec[:, 1] * rvec[:, 2]),
            1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 1] ** 2
        ), dim=1).view(-1, 3, 3)

    def forward(self, rays_o, rays_d, **render_kwargs):
        """Apply each ray's camera correction.

        Args:
            rays_o: [N_rays, 3+1] origin points of rays with camera id
            rays_d: [N_rays, 3+1] directions of rays with camera id
            render_kwargs: other render parameters (unused)

        Return:
            rays_o: [N_rays, 3] Transformed origin points
            rays_d: [N_rays, 3] Transformed directions of rays
        """
        assert rays_o.shape[-1] == 4
        assert (rays_o[:, 3] == rays_d[:, 3]).all()
        # .long() keeps the index on the rays' device; .type(torch.LongTensor)
        # forced a device-to-host copy on every call.
        indx = rays_o[:, 3].long()

        # Rotate ray directions w.r.t. rvec
        c2w = self.rot_mats()[indx]
        rays_d = torch.sum(rays_d[..., None, :3] * c2w[:, :3, :3], -1)  # per-ray R @ d

        # Translate camera w.r.t. tvec
        rays_o = rays_o[..., :3] + self.tvec[indx]

        return rays_o, rays_d
# --------------------------------------------------------------------------------
# /layers/loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


def make_loss(cfg):
    """Return the training loss (plain MSE; ``cfg`` is currently unused)."""
    return nn.MSELoss()
# --------------------------------------------------------------------------------
# /layers/render_layer.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


def gen_weight(sigma, delta, act_fn=F.relu):
    """Convert raw densities into per-sample volume-rendering weights.

    alpha_i  = 1 - exp(-act_fn(sigma_i) * delta_i)
    weight_i = alpha_i * prod_{j<i} (1 - alpha_j)

    :param sigma: raw densities, [N, L, 1] (or [N, L]).
    :param delta: [N, L] distances between consecutive samples.
    :param act_fn: activation applied to ``sigma`` (ReLU by default).
    :return: [N, L] weights.
    """
    alpha = 1. - torch.exp(-act_fn(sigma.squeeze(-1)) * delta)
    # Per-sample transmittance; the epsilon keeps the product from hitting 0.
    transmittance = 1. - alpha + 1e-10
    # Exclusive cumulative product: prepend ones and drop the last column so
    # sample i only accumulates transmittance of the samples before it.
    ones = torch.ones_like(alpha[:, :1])
    return alpha * torch.cumprod(torch.cat([ones, transmittance], -1), -1)[:, :-1]


class VolumeRenderer(nn.Module):
    """Alpha-composite per-sample colours/densities into per-ray outputs."""

    def __init__(self, use_mask=False, boarder_weight=1e10):
        super(VolumeRenderer, self).__init__()
        # Distance assigned to the last sample of every ray.
        self.boarder_weight = boarder_weight
        self.use_mask = use_mask

    def forward(self, depth, rgb, sigma, noise=0):
        """
        N - num rays; L - num samples;
        :param depth: torch.tensor, depth for each sample along the ray. [N, L, 1]
        :param rgb: torch.tensor, raw rgb output from the network. [N, L, 3]
        :param sigma: torch.tensor, raw density (without activation). [N, L, 1]
        :param noise: std of Gaussian noise added to ``sigma`` (0 disables)

        :return:
            color: torch.tensor [N, 3]
            depth: torch.tensor [N, 1]
            acc_map: torch.tensor [N, 1] accumulated opacity
            weights: torch.tensor [N, L, 1]
        """
        # squeeze(-1) only drops the trailing channel, so N == 1 or L == 2
        # no longer collapse the ray/sample axes.
        delta = (depth[:, 1:] - depth[:, :-1]).squeeze(-1)  # [N, L-1]
        pad = torch.full_like(delta[..., :1], self.boarder_weight)
        delta = torch.cat([delta, pad], dim=-1)  # [N, L]

        if noise > 0.:
            # Out-of-place so the caller's sigma tensor is not modified.
            sigma = sigma + torch.randn_like(sigma) * noise

        weights = gen_weight(sigma, delta).unsqueeze(-1)  # [N, L, 1]

        color = torch.sum(torch.sigmoid(rgb) * weights, dim=1)  # [N, 3]
        depth = torch.sum(weights * depth, dim=1)  # [N, 1]
        acc_map = torch.sum(weights, dim=1)  # [N, 1]
        # TODO: Normalising acc_map by its max crashes (NaN when acc_map ~ 0).
        # if acc_map.max() > 0.0001:
        #     acc_map = acc_map / acc_map.max()

        if self.use_mask:
            # acc_map is already [N, 1], so it broadcasts against color [N, 3];
            # the previous acc_map[..., None] produced an [N, N, 3] tensor.
            # TODO: confirm the intended background term (it multiplies by color).
            color = color + (1. - acc_map) * color

        return color, depth, acc_map, weights


if __name__ == "__main__":
    N_rays = 1024
    N_samples = 64

    depth = torch.randn(N_rays, N_samples, 1)
    raw = torch.randn(N_rays, N_samples, 3)
    sigma = torch.randn(N_rays, N_samples, 1)

    renderer = VolumeRenderer()

    # forward() returns four values: (color, depth, acc_map, weights).
    color, dpt, acc, weights = renderer(depth, raw, sigma)
    print('Predicted [CPU]: ', color.shape, dpt.shape, weights.shape)

    if torch.cuda.is_available():
        depth = depth.cuda()
        raw = raw.cuda()
        sigma = sigma.cuda()
        renderer = renderer.cuda()

        color, dpt, acc, weights = renderer(depth, raw, sigma)
        print('Predicted [GPU]: ', color.shape, dpt.shape, weights.shape)

    print('Test load data')
    tf_depth = np.load('layers/test_output/depth_map.npy')
    tf_color = np.load('layers/test_output/rgb_map.npy')
    tf_weights = np.load('layers/test_output/weights.npy')
    print('TF output = ', tf_depth.shape, tf_color.shape, tf_weights.shape)

    raws = torch.from_numpy(np.load('layers/test_output/raws.npy'))
    ray_d = torch.from_numpy(np.load('layers/test_output/ray_d.npy'))
    z_val = torch.from_numpy(np.load('layers/test_output/z_vals.npy'))

    print('TF input = ', raws.shape, ray_d.shape, z_val.shape)

    in_depth = z_val
    print('in_depth = ', in_depth.shape)
    in_raw = raws[:, :, :3]
    print('in_raw = ', in_raw.shape)
    in_sigma = raws[:, :, 3:]
    print('in_sigma = ', in_sigma.shape)

    color, dpt, acc, weights = renderer(in_depth.unsqueeze(-1).cuda(), in_raw.cuda(), in_sigma.cuda())
    print('Predicted-TF [GPU]: ', color.shape, dpt.shape, weights.shape)

    print('ERROR [GPU]: ',
          np.mean(tf_color - color.detach().cpu().numpy()),
          np.mean(tf_depth - dpt.squeeze(-1).detach().cpu().numpy()),
          np.mean(tf_weights - weights.squeeze(-1).detach().cpu().numpy()))
# --------------------------------------------------------------------------------
# /modeling/__init__.py:
# --------------------------------------------------------------------------------
# encoding: utf-8

from .layered_rfrender import LayeredRFRender


def build_layered_model(cfg, camera_num=0, scale=None, shift=None):
    """Construct the layered radiance-field model from ``cfg``."""
    model = LayeredRFRender(cfg, camera_num=camera_num, scale=scale, shift=shift)
    return model
# --------------------------------------------------------------------------------
# /modeling/motion_net.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

from utils import Trigonometric_kernel


class MotionNet(nn.Module):
    """MLP predicting a 3-D flow vector from positionally-encoded inputs."""

    # (x,y,z,t)
    def __init__(self, c_input=5, include_input=True, input_time=False):
        """
        c_input: number of input channels per point (last one is time when
            ``input_time`` is set)
        include_input: forwarded to the positional encoding
        input_time: interpolate the encoding between integer time steps
        """
        super(MotionNet, self).__init__()
        self.c_input = c_input
        self.input_time = input_time
        # Positional Encoding
        self.tri_kernel_pos = Trigonometric_kernel(L=10, input_dim=c_input, include_input=include_input)

        self.pos_dim = self.tri_kernel_pos.calc_dim(c_input)
        backbone_dim = 128
        head_dim = 128

        self.motion_net = nn.Sequential(
            nn.Linear(self.pos_dim, head_dim),
            nn.ReLU(inplace=False),
            nn.Linear(head_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, head_dim),
            nn.ReLU(inplace=True),
            nn.Linear(head_dim, 3)
        )

    def forward(self, input_0):
        """Predict scene flow.

        Input:
            input_0: [N, c_input] or [N, L, c_input] points (plus time)

        Output:
            flow: [N, 3] or [N, L, 3] scene flow
        """
        bins_mode = False
        if len(input_0.size()) > 2:
            bins_mode = True
            L = input_0.size(1)
            input_0 = input_0.reshape((-1, self.c_input))  # (N,input)

        if self.input_time:
            xyz = input_0[:, :-1]
            time = input_0[:, -1:]
            lower = torch.floor(time)
            if not torch.all(torch.eq(lower, time)):
                # Fractional times: blend the encodings of the two
                # neighbouring integer time steps linearly.
                upper = lower + 1
                weight = time - lower
                i_lower = torch.cat([xyz, lower], -1)
                i_upper = torch.cat([xyz, upper], -1)
                i_lower = self.tri_kernel_pos(i_lower)
                i_upper = self.tri_kernel_pos(i_upper)
                input_0 = (1 - weight) * i_lower + weight * i_upper
            else:
                input_0 = self.tri_kernel_pos(input_0)
        else:
            input_0 = self.tri_kernel_pos(input_0)

        flow = self.motion_net(input_0)

        if bins_mode:
            flow = flow.reshape(-1, L, 3)

        return flow
# --------------------------------------------------------------------------------
# /modeling/spacenet.py:
# --------------------------------------------------------------------------------

import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
import time

from utils import Trigonometric_kernel


class SpaceNet(nn.Module):
    """NeRF-style MLP: position (+ direction, + time) -> (raw rgb, raw density)."""

    def __init__(self, c_pos=3, include_input=True, use_dir=True, use_time=False, deep_rgb=False):
        super(SpaceNet, self).__init__()

        self.tri_kernel_pos = Trigonometric_kernel(L=10, include_input=include_input)
        if use_dir:
            self.tri_kernel_dir = Trigonometric_kernel(L=4, include_input=include_input)
        if use_time:
            self.tri_kernel_time = Trigonometric_kernel(L=10, input_dim=1, include_input=include_input)

        self.c_pos = c_pos

        self.pos_dim = self.tri_kernel_pos.calc_dim(c_pos)
        if use_dir:
            self.dir_dim = self.tri_kernel_dir.calc_dim(3)
        else:
            self.dir_dim = 0

        if use_time:
            self.time_dim = self.tri_kernel_time.calc_dim(1)
        else:
            self.time_dim = 0

        self.use_dir = use_dir
        self.use_time = use_time
        backbone_dim = 256
        head_dim = 128

        self.stage1 = nn.Sequential(
            nn.Linear(self.pos_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
        )

        # Skip connection: the encoded position is fed again to stage 2.
        self.stage2 = nn.Sequential(
            nn.Linear(backbone_dim + self.pos_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(inplace=True),
        )

        self.density_net = nn.Sequential(
            nn.Linear(backbone_dim, 1)
        )
        if deep_rgb:
            print("deep")
            self.rgb_net = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(backbone_dim + self.dir_dim + self.time_dim, head_dim),
                nn.ReLU(inplace=True),
                nn.Linear(head_dim, head_dim),
                nn.ReLU(inplace=True),
                nn.Linear(head_dim, head_dim),
                nn.ReLU(inplace=True),
                nn.Linear(head_dim, 3)
            )
        else:
            self.rgb_net = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(backbone_dim + self.dir_dim + self.time_dim, head_dim),
                nn.ReLU(inplace=True),
                nn.Linear(head_dim, 3)
            )

    def forward(self, pos, rays, times=None, maxs=None, mins=None):
        '''
        INPUT
            pos: 3D positions (N,L,c_pos) or (N,c_pos)
            rays: corresponding rays (N,6), or None to skip view directions
            times: corresponding time (N,1); required when use_time is set
            maxs, mins: optional bounds used to normalise pos to [-1, 1]

        OUTPUT
            rgb: color (N,L,3) or (N,3)
            density: (N,L,1) or (N,1)
        '''
        use_dirs = rays is not None and self.use_dir
        if use_dirs:
            dirs = rays[..., 3:6]

        bins_mode = len(pos.size()) > 2
        if bins_mode:
            L = pos.size(1)
            pos = pos.reshape((-1, self.c_pos))  # (N*L,c_pos)
            if use_dirs:
                # Directions are always 3-D, independent of c_pos.
                dirs = dirs.unsqueeze(1).repeat(1, L, 1).reshape((-1, 3))  # (N*L,3)
            if self.use_time:
                # Expand times whenever they are used, not only when rays are
                # given, so they always line up with the flattened positions.
                times = times.unsqueeze(1).repeat(1, L, 1).reshape((-1, 1))  # (N*L,1)

        if maxs is not None:
            pos = ((pos - mins) / (maxs - mins) - 0.5) * 2

        pos = self.tri_kernel_pos(pos)
        if use_dirs:
            dirs = self.tri_kernel_dir(dirs)
        if self.use_time:
            times = self.tri_kernel_time(times)

        x = self.stage1(pos)
        x = self.stage2(torch.cat([x, pos], dim=1))

        density = self.density_net(x)

        # rgb_net begins with an in-place ReLU; x is still needed by
        # density_net's backward pass, so the head must get a fresh tensor.
        if use_dirs:
            features = torch.cat([x, dirs], dim=1)
        else:
            features = x.clone()
        if self.use_time:
            features = torch.cat([features, times], dim=1)
        rgbs = self.rgb_net(features)

        if bins_mode:
            density = density.reshape((-1, L, 1))
            rgbs = rgbs.reshape((-1, L, 3))

        return rgbs, density


# --------------------------------------------------------------------------------
# /outputs/taekwondo/layered_rfnr_checkpoint_1.pt:
#   https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/outputs/taekwondo/layered_rfnr_checkpoint_1.pt
# --------------------------------------------------------------------------------
# /outputs/walking/layered_rfnr_checkpoint_1.pt:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/outputs/walking/layered_rfnr_checkpoint_1.pt -------------------------------------------------------------------------------- /render/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .render_functions import * 4 | from .neural_renderer import NeuralRenderer 5 | from .layered_neural_renderer import LayeredNeuralRenderer -------------------------------------------------------------------------------- /render/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /render/__pycache__/bkgd_renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/bkgd_renderer.cpython-38.pyc -------------------------------------------------------------------------------- /render/__pycache__/layered_neural_renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/layered_neural_renderer.cpython-38.pyc -------------------------------------------------------------------------------- /render/__pycache__/neural_renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/neural_renderer.cpython-38.pyc 
-------------------------------------------------------------------------------- /render/__pycache__/render_functions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/render_functions.cpython-38.pyc -------------------------------------------------------------------------------- /render/bkgd_renderer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Optional 3 | from collections import namedtuple 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import pyrender as pr 8 | 9 | # skew is generally not supported 10 | Pinhole = namedtuple('Pinhole', ['fx', 'fy', 'cx', 'cy']) 11 | 12 | 13 | class MeshRender(ABC): 14 | @abstractmethod 15 | def load_mesh(self, fn: str) -> None: 16 | pass 17 | 18 | @abstractmethod 19 | def render(self, pinhole: Optional[Pinhole] = None, 20 | pose: Optional[np.ndarray] = None) -> Image.Image: 21 | pass 22 | 23 | 24 | class PrRender(MeshRender): 25 | _gl_cv = np.array([ 26 | [1, 0, 0, 0], 27 | [0, -1, 0, 0], 28 | [0, 0, -1, 0], 29 | [0, 0, 0, 1], 30 | ]) 31 | 32 | def __init__(self, resolution: Tuple[int, int]): 33 | self._scene = pr.Scene(ambient_light=np.ones(3)) 34 | self._mesh = None 35 | self._cam = None 36 | 37 | self._width, self._height = resolution 38 | self._render = pr.OffscreenRenderer(self._width, self._height) 39 | 40 | def load_mesh(self, fn: str) -> None: 41 | tm = pr.mesh.trimesh.load_mesh(fn) 42 | mesh = pr.Mesh.from_trimesh(tm) 43 | if mesh: 44 | if self._mesh is not None: 45 | self._scene.remove_node(self._mesh) 46 | self._scene.add(mesh) 47 | self._mesh = mesh 48 | 49 | def render(self, pinhole: Optional[Pinhole] = None, 50 | pose: Optional[np.ndarray] = None, 51 | znear=pr.constants.DEFAULT_Z_NEAR, 52 | zfar=pr.constants.DEFAULT_Z_FAR) -> Image.Image: 
53 | if pinhole is not None: 54 | # update intrinsics 55 | if self._cam is None: 56 | cam = pr.IntrinsicsCamera(*pinhole, znear, zfar) 57 | self._cam = self._scene.add(cam) 58 | else: 59 | cam = self._cam.camera 60 | cam.fx, cam.fy, cam.cx, cam.cy = pinhole 61 | cam.znear, cam.zfar = znear, zfar 62 | if self._cam is None: 63 | raise ValueError('Empty intrinsics while previous camera not set') 64 | 65 | # camera is not None 66 | if pose is not None: 67 | # update camera pose 68 | # gl_to_world = cv_to_world @ gl_to_cv 69 | self._scene.set_pose(self._cam, pose @ self._gl_cv) 70 | 71 | color, _ = self._render.render(self._scene) 72 | return color -------------------------------------------------------------------------------- /render/layered_neural_renderer.py: -------------------------------------------------------------------------------- 1 | from config import cfg 2 | import imageio 3 | import os 4 | import numpy as np 5 | import torch 6 | from data import make_ray_data_loader_render, get_iteration_path 7 | from modeling import build_layered_model 8 | from utils import layered_batchify_ray, add_two_dim_dict 9 | from .render_functions import * 10 | from robopy import * 11 | 12 | from scipy.spatial.transform import Rotation as R 13 | from scipy.spatial.transform import Slerp 14 | from scipy.interpolate import splprep, splev 15 | import pdb 16 | 17 | class LayeredNeuralRenderer: 18 | 19 | def __init__(self,cfg,scale=None,shift=None,rotation = None,s_shift = None,s_scale=None,s_alpha=None): 20 | self.alpha = None 21 | self.cfg = cfg 22 | self.scale = scale 23 | self.shift = shift 24 | self.rotation = rotation 25 | self.s_shift = s_shift 26 | self.s_scale = s_scale 27 | self.s_alpha = s_alpha 28 | 29 | 30 | if s_shift != None: 31 | self.shift = self.s_shift[0] 32 | 33 | if s_scale != None: 34 | self.scale = self.s_scale[0] 35 | 36 | if s_alpha != None: 37 | self.alpha = self.s_alpha[0] 38 | 39 | 40 | # The dictionary save all rendered images and videos 41 | 
self.dataset_dir = self.cfg.OUTPUT_DIR 42 | self.output_dir = os.path.join(self.cfg.OUTPUT_DIR,'rendered') 43 | 44 | self.dataset, self.model = self.load_dataset_model() 45 | 46 | # {0,1} dictionary, 1 means display, 0 means hide, key is [layer_id] 47 | self.display_layers = {} 48 | 49 | # (0,1,2,...,LAYER_NUM) 50 | for layer_id in range(cfg.DATASETS.LAYER_NUM+1): 51 | self.display_layers[layer_id] = 1 52 | 53 | # Intrinsic for all rendered image, update when firstly load dataset 54 | self.gt_poses = self.dataset.poses 55 | self.gt_Ks = self.dataset.Ks 56 | 57 | # self.near = 0.1 58 | self.far = 20.0 59 | 60 | # Each layer will have a min-max frame range 61 | self.min_frame = [1+cfg.DATASETS.FRAME_OFFSET for i in range(cfg.DATASETS.LAYER_NUM+1)] 62 | self.max_frame = [cfg.DATASETS.FRAME_NUM+cfg.DATASETS.FRAME_OFFSET for i in range(cfg.DATASETS.LAYER_NUM+1)] 63 | 64 | self.images = [] 65 | self.depths = [] 66 | 67 | # Total image number rendered and saved in renderer 68 | self.image_num = 0 69 | # Total frame number and layer number, use it carefully, because it may not be all loaded into model 70 | self.frame_num = cfg.DATASETS.FRAME_NUM 71 | self.layer_num = cfg.DATASETS.LAYER_NUM 72 | self.camera_num = self.dataset.camera_num 73 | self.min_camera_id = 0 74 | self.max_camera_id = self.camera_num-1 75 | 76 | self.fps = 25 77 | self.height = cfg.INPUT.SIZE_TEST[1] 78 | self.width = cfg.INPUT.SIZE_TEST[0] 79 | 80 | #Count for save multiple videos 81 | self.save_count = 0 82 | 83 | # All rendered poses and intrinsics aligned with images 84 | self.poses = [] 85 | self.Ks = [] 86 | # Corresponding to each pose, we will have mutiple (layer_id, frame_id) pairs to identify the visible layers and frames. 87 | # Example [[(0,1),(1,1)],[(0,1),(1,2)],...] 
represent [(layer_0, frame_1), (layer_1,frame_1)] for poses[0] and so on 88 | self.layer_frame_pairs = [] 89 | 90 | # Trace one layer (lookat to the center of layer), -1 means no trace layer 91 | self.trace_layer = -1 92 | 93 | # auto saving dir 94 | self.dir_name = '' 95 | 96 | def load_dataset_model(self): 97 | para_file = get_iteration_path(self.dataset_dir) 98 | print(para_file) 99 | 100 | if para_file is None: 101 | assert 'training model does not exist' 102 | 103 | _, dataset = make_ray_data_loader_render(cfg) 104 | 105 | model = build_layered_model(cfg, dataset.camera_num, scale = self.scale, shift=self.shift) 106 | 107 | model.set_bkgd_bbox(dataset.datasets[0][0].bbox) 108 | model.set_bboxes(dataset.bboxes) 109 | model_dict = model.state_dict() 110 | dict_0 = torch.load(os.path.join(para_file),map_location='cuda') 111 | 112 | model_dict = dict_0['model'] 113 | model_new_dict = model.state_dict() 114 | offset = {k: v for k, v in model_new_dict.items() if k not in model_dict} 115 | for k,v in offset.items(): 116 | model_dict[k] = v 117 | model.load_state_dict(model_dict) 118 | 119 | model.cuda() 120 | 121 | return dataset, model 122 | 123 | 124 | def check_label(self): 125 | output = os.path.join(self.output_dir,'masked_images') 126 | if not os.path.exists(output): 127 | os.makedirs(output) 128 | for i in range(self.frame_num): 129 | output_f = os.path.join(output, 'frame%d' % i) 130 | if not os.path.exists(output_f): 131 | os.makedirs(output_f) 132 | for j in range(self.camera_num): 133 | image, label = self.dataset.get_image_label(j, i) 134 | image = image.permute(1,2,0) 135 | image[label[0,...]==0] = 0 136 | imageio.imwrite(os.path.join(output_f,'%d.jpg'% j), image) 137 | 138 | return 139 | 140 | 141 | 142 | 143 | # The function set the pose, before using it, set the right frame duration for each layer 144 | def set_path_lookat(self, start,end,step_num,center,up): 145 | 146 | # Generate poses 147 | if self.trace_layer == -1: 148 | poses = 
generate_poses_by_path(start,end,step_num,center,up) 149 | else: 150 | centers = [] 151 | temp = center 152 | for idx in range(step_num): 153 | frame_id = int((self.max_frame-self.min_frame)/step_num*(idx+1)) + self.min_frame 154 | frame_dic = self.datasets[frame_id] 155 | for layer_id in frame_dic: 156 | if layer_id == self.trace_layer: 157 | temp = frame_dic[layer_id].center 158 | centers.append(temp) 159 | poses = generate_poses_by_path_center(start,end,step_num,centers,up) 160 | 161 | self.poses = self.poses + poses 162 | 163 | for idx in range(len(poses)+1): 164 | layer_frame_pair = [] 165 | for layer_id in range(self.layer_num+1): 166 | if self.is_shown_layer(layer_id): 167 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id] 168 | layer_frame_pair.append((layer_id,frame_id)) 169 | self.layer_frame_pairs.append(layer_frame_pair) 170 | 171 | def set_path_gt_poses(self): 172 | poses = [] 173 | for i in range(self.dataset.poses.shape[0]): 174 | poses.append(self.dataset.poses[i]) 175 | 176 | self.poses = self.poses + poses 177 | self.Ks = self.Ks + self.gt_Ks 178 | 179 | for idx in range(len(poses)+1): 180 | layer_frame_pair = [] 181 | for layer_id in range(self.layer_num+1): 182 | if self.is_shown_layer(layer_id): 183 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id] 184 | layer_frame_pair.append((layer_id,frame_id)) 185 | self.layer_frame_pairs.append(layer_frame_pair) 186 | 187 | 188 | def set_path_fixed_gt_poses(self,id,num=None): 189 | poses = [] 190 | Ks = [] 191 | if self.s_shift != None: 192 | s_shift_start = np.array(self.s_shift[0]) 193 | s_shift_end = np.array(self.s_shift[1]) 194 | s_shift_step = (s_shift_end-s_shift_start)/(num-1) 195 | self.s_shift_frame = [] 196 | 197 | if self.s_scale != None: 198 | s_scale_start = np.array(self.s_scale[0]) 199 | s_scale_end = np.array(self.s_scale[1]) 200 | s_scale_step = 
(s_scale_end-s_scale_start)/(num-1) 201 | self.s_scale_frame = [] 202 | 203 | 204 | for i in range(num): 205 | poses.append(self.dataset.poses[id]) 206 | K = self.dataset.Ks[id] 207 | 208 | # EXPEDIENCY 209 | if K == None: 210 | K = self.dataset.Ks[id+1] 211 | Ks.append(K) 212 | if self.s_shift != None: 213 | self.s_shift_frame.append((s_shift_start+i*s_shift_step).tolist()) 214 | 215 | if self.s_scale != None: 216 | self.s_scale_frame.append((s_scale_start+i*s_scale_step).tolist()) 217 | 218 | self.poses = self.poses + poses 219 | self.Ks = self.Ks + Ks 220 | 221 | for idx in range(len(poses)+1): 222 | layer_frame_pair = [] 223 | for layer_id in range(self.layer_num+1): 224 | if self.is_shown_layer(layer_id): 225 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id] 226 | layer_frame_pair.append((layer_id,frame_id)) 227 | self.layer_frame_pairs.append(layer_frame_pair) 228 | 229 | 230 | def set_smooth_path_poses(self,step_num, around=False, smooth_time = False): 231 | 232 | if self.s_shift != None: 233 | s_shift_start = np.array(self.s_shift[0]) 234 | s_shift_end = np.array(self.s_shift[1]) 235 | s_shift_step = (s_shift_end-s_shift_start)/(step_num-1) 236 | self.s_shift_frame = [] 237 | 238 | if self.s_alpha != None: 239 | s_alpha_start = self.s_alpha[0] 240 | s_alpha_end = self.s_alpha[1] 241 | s_alpha_step = (s_alpha_end-s_alpha_start)/(step_num-1) 242 | self.s_alpha_frame = [] 243 | 244 | poses = [] 245 | Rs = self.gt_poses[self.min_camera_id:self.max_camera_id+1,:3,:3].cpu().numpy() 246 | Ts = self.gt_poses[self.min_camera_id:self.max_camera_id+1,:3,3].cpu().numpy() 247 | #print(Ts) 248 | 249 | key_frames = [i for i in range(self.min_camera_id,self.max_camera_id+1)] 250 | # Only use the first and the last 251 | if not around: 252 | temp = [Rs[0],Rs[-1]] 253 | Rs = np.array(temp) 254 | key_frames = [self.min_camera_id,self.max_camera_id] 255 | 256 | # key_frames = [i for i in 
range(self.min_camera_id,self.max_camera_id)] 257 | # key_frames = [self.min_camera_id,self.max_camera_id-1] 258 | 259 | interp_frames = [(i * (self.max_camera_id-self.min_camera_id) / (step_num-1) + self.min_camera_id) for i in range(step_num)] 260 | #print(interp_frames) 261 | # print(interp_frames) 262 | Rs = R.from_matrix(Rs) 263 | slerp = Slerp(key_frames, Rs) 264 | interp_Rs = slerp(interp_frames).as_matrix() 265 | #print(interp_Rs) 266 | 267 | x = Ts[:,0] 268 | y = Ts[:,1] 269 | z = Ts[:,2] 270 | 271 | tck, u0 = splprep([x,y,z]) 272 | u_new = [i / (step_num-1) for i in range(step_num)] 273 | new_points = splev(u_new,tck) 274 | 275 | new_points = np.stack(new_points, axis=1) 276 | 277 | K0 = self.gt_Ks[self.min_camera_id] 278 | K1 = self.gt_Ks[self.max_camera_id] 279 | 280 | if self.s_scale != None: 281 | s_scale_start = np.array(self.s_scale[0]) 282 | s_scale_end = np.array(self.s_scale[1]) 283 | s_scale_step = (s_scale_end-s_scale_start)/(step_num-1) 284 | self.s_scale_frame = [] 285 | for i in range(step_num): 286 | pose = np.zeros((4,4)) 287 | pose[:3,:3] = interp_Rs[i] 288 | pose[:3,3] = new_points[i] 289 | pose[3,3] = 1 290 | poses.append(pose) 291 | 292 | K = (K1 - K0) * i / (step_num - 1) + K0 293 | 294 | # print(K) 295 | 296 | self.Ks.append(K) 297 | if self.s_scale != None: 298 | self.s_scale_frame.append((s_scale_start+i*s_scale_step).tolist()) 299 | 300 | if self.s_shift != None: 301 | self.s_shift_frame.append((s_shift_start+i*s_shift_step).tolist()) 302 | 303 | if self.s_alpha != None: 304 | self.s_alpha_frame.append((s_alpha_start+i*s_alpha_step)) 305 | 306 | self.poses = self.poses + poses 307 | 308 | # Generate corresponding layer id and frame id for poses 309 | for idx in range(len(poses)+1): 310 | layer_frame_pair = [] 311 | for layer_id in range(self.layer_num+1): 312 | if self.is_shown_layer(layer_id): 313 | 314 | if not smooth_time: 315 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + 
self.min_frame[layer_id] 316 | else: 317 | frame_id = (self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx) + self.min_frame[layer_id] 318 | layer_frame_pair.append((layer_id,frame_id)) 319 | self.layer_frame_pairs.append(layer_frame_pair) 320 | 321 | def load_path_poses(self,poses): 322 | self.poses = poses 323 | step_num = len(poses) 324 | K0 = self.gt_Ks[self.min_camera_id] 325 | K1 = self.gt_Ks[self.max_camera_id-1] 326 | for i in range(step_num): 327 | K = (K1 - K0) * i / (step_num - 1) + K0 328 | self.Ks.append(K) 329 | 330 | for idx in range(len(poses)+1): 331 | layer_frame_pair = [] 332 | for layer_id in range(self.layer_num+1): 333 | if self.is_shown_layer(layer_id): 334 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id] 335 | layer_frame_pair.append((layer_id,frame_id)) 336 | self.layer_frame_pairs.append(layer_frame_pair) 337 | 338 | 339 | def load_cams_from_path(self, path): 340 | 341 | campose = np.load(os.path.join(path, 'RT_c2w.npy')) 342 | Ts = np.zeros((campose.shape[0],4,4)) 343 | Ts[:,:3,:] = campose.reshape(-1, 3, 4) 344 | Ts[:,3,3] = 1. 
345 | 346 | #scale 347 | Ts[:,:3,3] = self.cfg.DATASETS.SCALE * Ts[:,:3,3] 348 | 349 | Ks = np.load(os.path.join(path, 'K.npy')) 350 | Ks = Ks.reshape(-1, 3, 3) 351 | 352 | self.poses = Ts 353 | self.Ks = torch.from_numpy(Ks.astype(np.float32)) 354 | 355 | for idx in range(len(self.poses)+1): 356 | layer_frame_pair = [] 357 | for layer_id in range(self.layer_num+1): 358 | if self.is_shown_layer(layer_id): 359 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(self.poses)*(idx)) + self.min_frame[layer_id] 360 | layer_frame_pair.append((layer_id,frame_id)) 361 | self.layer_frame_pairs.append(layer_frame_pair) 362 | 363 | 364 | def render_pose(self, pose, K, layer_frame_pair, density_threshold=0,bkgd_density_threshold=0): 365 | print(K) 366 | print(pose) 367 | #print(K) 368 | H = self.dataset.height 369 | W = self.dataset.width 370 | rays, labels, bbox, near_far = self.dataset.get_rays_by_pose_and_K(pose, K, layer_frame_pair) 371 | 372 | rays = rays.cuda() 373 | bbox = bbox.cuda() 374 | labels = labels.cuda() 375 | near_far = near_far.cuda() 376 | 377 | with torch.no_grad(): 378 | stage2, stage1, stage2_layer, stage1_layer, _ = layered_batchify_ray(self.model, rays, labels, bbox, near_far=near_far, density_threshold=density_threshold,bkgd_density_threshold=bkgd_density_threshold) 379 | 380 | color = stage2[0].reshape(H,W,3) 381 | depth = stage2[1].reshape(H,W,1) 382 | depth[depth < 0] = 0 383 | depth = depth / self.far 384 | color_layer = [i[0].reshape(H,W,3) for i in stage2_layer] 385 | depth_layer = [] 386 | for temp in stage2_layer: 387 | depth_1 = temp[1].reshape(H,W,1) 388 | depth_1[depth < 0] = 0 389 | depth_1 = depth_1 / self.far 390 | depth_layer.append(depth_1) 391 | 392 | return color,depth,color_layer,depth_layer 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | def render_path(self,inverse_y_axis=False,density_threshold=0,bkgd_density_threshold=0, auto_save=True): 402 | 403 | if self.dir_name == '': 404 | save_dir = 
os.path.join(self.output_dir,'video_%d' % self.save_count,'mixed') 405 | else: 406 | save_dir = os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,'mixed') 407 | if not os.path.exists(save_dir): 408 | os.makedirs(save_dir) 409 | os.mkdir(os.path.join(save_dir,'color')) 410 | os.mkdir(os.path.join(save_dir,'depth')) 411 | 412 | file_poses = open(os.path.join(save_dir,'poses'),mode='w') 413 | for pose in self.poses: 414 | file_poses.write(str(pose)+"\n") 415 | file_poses.close() 416 | 417 | file_Ks = open(os.path.join(save_dir,'Ks'),mode='w') 418 | for K in self.Ks: 419 | file_Ks.write(str(K)+"\n") 420 | file_Ks.close() 421 | 422 | 423 | self.images = [] 424 | self.depths = [] 425 | self.images_layer = [[] for i in range(self.layer_num+1)] 426 | self.depths_layer = [[] for i in range(self.layer_num+1)] 427 | 428 | self.image_num = 0 429 | 430 | for idx in range(len(self.poses)): 431 | print('Rendering image %d' % idx) 432 | K = self.Ks[idx] 433 | pose = self.poses[idx] 434 | layer_frame_pair = self.layer_frame_pairs[idx] 435 | if self.s_shift != None: 436 | self.model.shift = self.s_shift_frame[idx] 437 | if self.s_scale != None: 438 | self.model.scale = self.s_scale_frame[idx] 439 | if self.s_alpha != None: 440 | self.model.alpha = self.s_alpha_frame[idx] 441 | 442 | color,depth,color_layer,depth_layer = self.render_pose(pose, K, layer_frame_pair, density_threshold,bkgd_density_threshold) 443 | 444 | 445 | if inverse_y_axis: 446 | color = torch.flip(color,[0]) 447 | depth = torch.flip(depth,[0]) 448 | color_layer = [torch.flip(i,[0]) for i in color_layer] 449 | depth_layer = [torch.flip(i,[0]) for i in depth_layer] 450 | 451 | color = color.cpu() 452 | depth = depth.cpu() 453 | color_layer = [i.cpu() for i in color_layer] 454 | depth_layer = [i.cpu() for i in depth_layer] 455 | 456 | if auto_save: 457 | if self.dir_name == '': 458 | save_dir = os.path.join(self.output_dir,'video_%d' % self.save_count,'mixed') 459 | else: 460 | save_dir = 
os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,'mixed') 461 | if not os.path.exists(save_dir): 462 | os.makedirs(save_dir) 463 | os.mkdir(os.path.join(save_dir,'color')) 464 | os.mkdir(os.path.join(save_dir,'depth')) 465 | 466 | #print(rgb.shape) 467 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color) 468 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth) 469 | self.images.append(color) 470 | self.depths.append(depth) 471 | for layer_id in range(self.layer_num+1): 472 | if self.is_shown_layer(layer_id): 473 | if self.dir_name == '': 474 | save_dir = os.path.join(self.output_dir,'video_%d' % self.save_count,str(layer_id)) 475 | else: 476 | save_dir = os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,str(layer_id)) 477 | if not os.path.exists(save_dir): 478 | os.makedirs(save_dir) 479 | os.mkdir(os.path.join(save_dir,'color')) 480 | os.mkdir(os.path.join(save_dir,'depth')) 481 | 482 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color_layer[layer_id]) 483 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth_layer[layer_id]) 484 | self.images_layer[layer_id].append(color) 485 | self.depths_layer[layer_id].append(depth) 486 | 487 | 488 | self.image_num += 1 489 | 490 | 491 | 492 | 493 | 494 | 495 | def retime_by_key_frames(self, layer_id, key_frames_layer, key_frames): 496 | 497 | assert (len(key_frames_layer) == len(key_frames)) 498 | 499 | for i in range(len(self.layer_frame_pairs)): 500 | for j in range(len(self.layer_frame_pairs[i])): 501 | layer, frame = self.layer_frame_pairs[i][j] 502 | #Retiming the corresponding layer 503 | if layer == layer_id: 504 | idx_start = -1 505 | idx_end = -1 506 | weight = 0 507 | for idx in range(len(key_frames)): 508 | if frame <= key_frames[idx]: 509 | idx_end = idx 510 | idx_start = idx_end-1 511 | end = key_frames[idx] 512 | start = 0 513 | if idx == 0: 514 | start = 
self.min_frame[layer] 515 | else: 516 | start = key_frames[idx-1] 517 | weight = (frame-start) / (end-start) 518 | # print('frame %d, start %d, end %d' % (frame,start,end)) 519 | # print('idx_end %d, idx_start %d' % (idx_end, idx_start)) 520 | break 521 | 522 | new_end = 0 523 | new_start = 0 524 | # print('123') 525 | # print('idx_end %d, idx_start %d' % (idx_end, idx_start)) 526 | if idx_start == -1 and idx_end == 0: 527 | weight = (frame-self.min_frame[layer]) / (key_frames[0] - self.min_frame[layer]) 528 | new_start = self.min_frame[layer] 529 | new_end = key_frames_layer[0] 530 | elif idx_start >= -1 and idx_end != -1: 531 | new_start = key_frames_layer[idx_start] 532 | new_end = key_frames_layer[idx_start+1] 533 | elif idx_start == -1 and idx_end == -1: 534 | weight = (frame-key_frames[-1]) / (self.max_frame[layer] - key_frames[-1]) 535 | new_start = key_frames_layer[-1] 536 | new_end = self.max_frame[layer] 537 | else: 538 | print('Undefined branch', 'start idx is %d, end idx is %d' % (idx_start,idx_end)) 539 | exit(-1) 540 | 541 | new_frame = round(weight * (new_end - new_start) + new_start) 542 | # print('new end is %d, new start is %d' % (new_end,new_start)) 543 | # print('layer %d: old frame is %d, new is %d, weight %f' % (layer,frame,new_frame,weight)) 544 | self.layer_frame_pairs[i][j] = (layer, new_frame) 545 | 546 | # exit(0) 547 | 548 | 549 | 550 | def render_path_walking(self,inverse_y_axis=False,density_threshold=0,bkgd_density_threshold=0, auto_save=True): 551 | 552 | self.images = [] 553 | self.depths = [] 554 | self.images_layer = [[] for i in range(self.layer_num+1)] 555 | self.depths_layer = [[] for i in range(self.layer_num+1)] 556 | 557 | self.image_num = 0 558 | 559 | for idx in range(len(self.poses)): 560 | print('Rendering image %d' % idx) 561 | K = self.Ks[idx] 562 | pose = self.poses[idx] 563 | layer_frame_pair = self.layer_frame_pairs[idx] 564 | 565 | color,depth,color_layer,depth_layer = self.render_pose(pose, K, layer_frame_pair, 
density_threshold,bkgd_density_threshold) 566 | 567 | if inverse_y_axis: 568 | color = torch.flip(color,[0]) 569 | depth = torch.flip(depth,[0]) 570 | color_layer = [torch.flip(i,[0]) for i in color_layer] 571 | depth_layer = [torch.flip(i,[0]) for i in depth_layer] 572 | 573 | color = color.cpu() 574 | depth = depth.cpu() 575 | color_layer = [i.cpu() for i in color_layer] 576 | depth_layer = [i.cpu() for i in depth_layer] 577 | 578 | if auto_save: 579 | save_dir = os.path.join(self.output_dir,'mixed') 580 | if not os.path.exists(save_dir): 581 | os.makedirs(save_dir) 582 | os.mkdir(os.path.join(save_dir,'color')) 583 | os.mkdir(os.path.join(save_dir,'depth')) 584 | 585 | #print(rgb.shape) 586 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color) 587 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth) 588 | self.images.append(color) 589 | self.depths.append(depth) 590 | for layer_id in range(self.layer_num+1): 591 | save_dir = os.path.join(self.output_dir,str(layer_id)) 592 | if not os.path.exists(save_dir): 593 | os.makedirs(save_dir) 594 | os.mkdir(os.path.join(save_dir,'color')) 595 | os.mkdir(os.path.join(save_dir,'depth')) 596 | 597 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color_layer[layer_id]) 598 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth_layer[layer_id]) 599 | self.images_layer[layer_id].append(color) 600 | self.depths_layer[layer_id].append(depth) 601 | 602 | color_hide = color_layer[0].clone() 603 | index = depth_layer[2]=start_epoches: 66 | return (1.0-scale)*math.exp(-(epoch0-start_epoches)/(end_epoches-start_epoches)) + scale 67 | 68 | 69 | return 1.0 70 | return torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=[scheduler]*len(optimizer.param_groups)) -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .dimension_kernel import Trigonometric_kernel 8 | from .ray_sampling import ray_sampling,ray_sampling_label_bbox,ray_sampling_label_label 9 | from .batchify_rays import batchify_ray, layered_batchify_ray,layered_batchify_ray_big 10 | from .vis_density import vis_density 11 | from .sample_pdf import sample_pdf 12 | from .high_dim_dics import add_two_dim_dict, add_three_dim_dict 13 | from .render_helpers import * 14 | -------------------------------------------------------------------------------- /utils/batchify_rays.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batchify_ray(model, rays, bboxes, chuncks = 1024*7, near_far=None, near_far_points = [], density_threshold=0,bkgd_density_threshold=0): 5 | N = rays.size(0) 6 | if N 0: 46 | ray_masks = torch.cat(ray_masks, dim=0) 47 | 48 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]), ray_masks 49 | 50 | 51 | def layered_batchify_ray(model, rays, labels, bboxes, chuncks = 512*7, near_far=None, near_far_points = [], density_threshold=0,bkgd_density_threshold=0): 52 | N = rays.size(0) 53 | if N 0: 130 | for i in range(len(stage2_layer)): 131 | ray_masks[i] = torch.cat(ray_masks[i], dim=0) 132 | 133 | stage1_layer_final = [] 134 | stage2_layer_final = [] 135 | 136 | for i in range(len(stage2_layer)): 137 | stage1_layer_final.append((colors[2+i*2], depths[2+i*2], acc_maps[2+i*2])) 138 | stage2_layer_final.append((colors[3+i*2], depths[3+i*2], acc_maps[3+i*2])) 139 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]),\ 140 | stage2_layer_final, stage1_layer_final, ray_masks 141 | 142 | 143 | 144 | def layered_batchify_ray_big(layer_big,scale,model, rays, labels, bboxes, chuncks = 512*7, near_far=None, 
near_far_points = [], density_threshold=0): 145 | N = rays.size(0) 146 | if N 0: 223 | for i in range(len(stage2_layer)): 224 | ray_masks[i] = torch.cat(ray_masks[i], dim=0) 225 | 226 | stage1_layer_final = [] 227 | stage2_layer_final = [] 228 | 229 | for i in range(len(stage2_layer)): 230 | stage1_layer_final.append((colors[2+i*2], depths[2+i*2], acc_maps[2+i*2])) 231 | stage2_layer_final.append((colors[3+i*2], depths[3+i*2], acc_maps[3+i*2])) 232 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]),\ 233 | stage2_layer_final, stage1_layer_final, ray_masks -------------------------------------------------------------------------------- /utils/dimension_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Embedder: 4 | def __init__(self, **kwargs): 5 | self.kwargs = kwargs 6 | self.create_embedding_fn() 7 | 8 | def create_embedding_fn(self): 9 | embed_fns = [] 10 | d = self.kwargs['input_dims'] 11 | out_dim = 0 12 | if self.kwargs['include_input']: 13 | embed_fns.append(lambda x : x) 14 | out_dim += d 15 | 16 | max_freq = self.kwargs['max_freq_log2'] 17 | N_freqs = self.kwargs['num_freqs'] 18 | 19 | if self.kwargs['log_sampling']: 20 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 21 | else: 22 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 23 | 24 | for freq in freq_bands: 25 | for p_fn in self.kwargs['periodic_fns']: 26 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 27 | out_dim += d 28 | 29 | self.embed_fns = embed_fns 30 | self.out_dim = out_dim 31 | 32 | def embed(self, inputs): 33 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 34 | 35 | 36 | def get_embedder(multires, i=0,include_input = True,input_dim=3): 37 | if i == -1: 38 | return nn.Identity(), 3 39 | 40 | embed_kwargs = { 41 | 'include_input' : include_input, 42 | 'input_dims' : input_dim, 43 | 'max_freq_log2' : multires-1, 44 | 
'num_freqs' : multires, 45 | 'log_sampling' : True, 46 | 'periodic_fns' : [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | embed = lambda x, eo=embedder_obj : eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | 53 | # Positional encoding 54 | class Trigonometric_kernel: 55 | def __init__(self, L = 10, input_dim = 3, include_input=True): 56 | 57 | self.L = L 58 | 59 | self.embed_fn, self.out_ch= get_embedder(L,include_input = include_input, input_dim=input_dim) 60 | 61 | ''' 62 | INPUT 63 | x: input vectors (N,C) 64 | 65 | OUTPUT 66 | 67 | pos_kernel: (N, calc_dim(C) ) 68 | ''' 69 | def __call__(self, x): 70 | return self.embed_fn(x) 71 | 72 | def calc_dim(self, dims=0): 73 | return self.out_ch -------------------------------------------------------------------------------- /utils/high_dim_dics.py: -------------------------------------------------------------------------------- 1 | 2 | def add_two_dim_dict(adic, key_a, key_b, val): 3 | if key_a in adic: 4 | adic[key_a].update({key_b: val}) 5 | else: 6 | adic.update({key_a:{key_b: val}}) 7 | 8 | def add_three_dim_dict(adic, key_a, key_b, key_c, val): 9 | if key_a in adic: 10 | if key_b in adic[key_a]: 11 | adic[key_a][key_b].update({key_c: val}) 12 | else: 13 | adic[key_a].update({key_b:{key_c: val}}) 14 | else: 15 | adic.update({key_a: {key_b: {key_c: val}}}) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 
| ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value=(image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | def mae(image_pred, image_gt): 13 | value=torch.abs(image_pred-image_gt) 14 | return torch.mean(value) 15 | 16 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 17 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 18 | 19 | def ssim(image_pred, image_gt, reduction='mean'): 20 | """ 21 | image_pred and image_gt: (3, H, W) 22 | """ 23 | dssim_ = dssim(image_pred.unsqueeze(0), image_gt.unsqueeze(0), 3, reduction) # dissimilarity in [0, 1] 24 | return 1-2*dssim_ # in [-1, 1] -------------------------------------------------------------------------------- /utils/ray_sampling.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | ''' 6 | Sample rays from views (and images) with/without masks 7 | 8 | -------------------------- 9 | INPUT Tensors 10 | Ks: intrinsics of cameras (M,3,3) 11 | Ts: extrinsic of cameras (M,4,4) 12 | image_size: the size of image [H,W] 13 | images: (M,C,H,W) 14 | mask_threshold: a float threshold to mask rays 15 | masks:(M,H,W) 16 | ------------------- 17 | OUPUT: 18 | 
# list of rays: (N,6) pos(3) + dirs(3)   (ray origin first, then direction --
#     this is the order produced by torch.cat([pos, dirs]) in the code below)
# RGB: (N,C)


def _pixel_coordinates(h, w, M, device, dtype):
    """Return homogeneous pixel coordinates of shape (M, h, w, 3, 1).

    Entry [m, i, j] is (j, i, 1): the column index first and the row index
    second (integer pixel positions), ready to be multiplied by K^-1.
    """
    rows = torch.linspace(0, h - 1, steps=h, device=device, dtype=dtype)
    cols = torch.linspace(0, w - 1, steps=w, device=device, dtype=dtype)
    grid_r, grid_c = torch.meshgrid(rows, cols)  # (h, w) each, row-major
    coords = torch.stack([grid_c, grid_r, torch.ones_like(grid_r)], dim=-1)  # (h, w, 3)
    return coords.expand(M, h, w, 3).unsqueeze(-1)


def _camera_rays(Ks, Ts, h, w):
    """Return per-pixel world-space rays for M pinhole cameras, shape (M, h, w, 6).

    Ks: (M, 3, 3) intrinsics; Ts: (M, 4, 4) camera-to-world poses (the
    translation column is used as the ray origin). Each ray is origin(3)
    followed by its direction(3), normalised in camera space before Ts is applied.
    """
    M = Ks.size(0)
    coords = _pixel_coordinates(h, w, M, Ks.device, Ks.dtype)
    # [:, None, None] gives every camera its own matrix; a plain (M,3,3) operand
    # would broadcast the camera axis against the image-width axis instead.
    inv_Ks = torch.inverse(Ks)[:, None, None]  # (M,1,1,3,3)
    dirs = torch.matmul(inv_Ks, coords)  # (M,h,w,3,1) camera space
    dirs = dirs / torch.norm(dirs, dim=3, keepdim=True)
    # Homogeneous w = 0: the pose rotates the direction but does not translate it.
    dirs = torch.cat([dirs, torch.zeros_like(dirs[..., :1, :])], dim=3)  # (M,h,w,4,1)
    dirs = torch.matmul(Ts[:, None, None], dirs)[..., 0:3, 0]  # (M,h,w,3) world space
    pos = Ts[:, 0:3, 3][:, None, None, :].expand(M, h, w, 3)  # camera centre per pixel
    return torch.cat([pos, dirs], dim=3)


def _bbox_pixel_window(K, T, bbox, H, W):
    """Project 8 bbox corners into the image and return the clamped crop window.

    Returns (minh, minw, maxh, maxw), usable as ``[minh:maxh, minw:maxw]``.
    Coordinates are clamped to [0, H-1] / [0, W-1] before the +1 on the maxima.
    """
    corners = bbox.reshape(8, 3).t()  # (3, 8)
    ones = torch.ones(1, corners.shape[1], dtype=corners.dtype, device=corners.device)
    corners = torch.cat([corners, ones], 0)  # homogeneous (4, 8)
    pts = torch.mm(torch.inverse(T), corners)[:3, :]  # world -> camera
    pixels = torch.mm(K, pts)
    pixels = pixels / pixels[2, :]
    pixels = pixels[[1, 0], :]  # reorder (x, y) into (row, col)
    lo = torch.min(pixels, dim=1)[0]
    hi = torch.max(pixels, dim=1)[0]

    def _clip(value, upper):
        # Negative values go to 0 first, then the result is capped at `upper`.
        return min(max(float(value), 0.0), upper)

    minh = int(_clip(lo[0], H - 1))
    minw = int(_clip(lo[1], W - 1))
    maxh = int(_clip(hi[0], H - 1)) + 1
    maxw = int(_clip(hi[1], W - 1)) + 1
    return minh, minw, maxh, maxw


def ray_sampling(Ks, Ts, image_size, masks=None, mask_threshold=0.5, images=None, outlier_map=None):
    """Generate one world-space ray per pixel for a batch of M pinhole cameras.

    Args:
        Ks: (M, 3, 3) intrinsic matrices.
        Ts: (M, 4, 4) camera-to-world poses; the translation column is the ray origin.
        image_size: (height, width).
        masks: optional (M, H, W); only pixels with ``masks > mask_threshold`` are kept.
        mask_threshold: threshold applied to ``masks``.
        images: optional (M, C, H, W) colours, returned channels-last per ray.
        outlier_map: optional tensor with M*H*W elements, appended as a 7th ray column.

    Returns:
        (rays, rgbs): rays is (N, 6) = origin(3) + direction(3), or (N, 7) with
        the outlier value appended; rgbs is (N, C) or None. N is M*H*W without
        masks, otherwise the number of kept pixels.
    """
    h, w = image_size[0], image_size[1]
    M = Ks.size(0)

    rays = _camera_rays(Ks, Ts, h, w)  # (M,H,W,6)
    if outlier_map is not None:
        rays = torch.cat([rays, outlier_map.reshape([M, h, w, 1])], dim=3)  # (M,H,W,7)

    rgbs = images.permute(0, 2, 3, 1) if images is not None else None  # (M,H,W,C)

    if masks is not None:
        keep = masks > mask_threshold
        rays = rays[keep, :]
        if rgbs is not None:
            rgbs = rgbs[keep, :]
    else:
        rays = rays.reshape((-1, rays.size(3)))
        if rgbs is not None:
            rgbs = rgbs.reshape((-1, rgbs.size(3)))

    return rays, rgbs


# Sample rays and labels with K,T and bbox
def ray_sampling_label_bbox(image, label, K, T, bbox=None, bboxes=None):
    """Sample rays, labels and colours inside the image window covered by ``bbox``.

    Args:
        image: (C, H, W) image; flattened to (N, 3), so C is expected to be 3.
        label: (1, H, W) label map (one channel keeps it aligned with the rays).
        K: (3, 3) intrinsics.
        T: (4, 4) camera-to-world pose.
        bbox: optional tensor with 8 xyz corners; when given, only pixels inside
            its projected, image-clamped bounding rectangle are sampled.
        bboxes: optional sequence of (8, 3) boxes indexed by label value.

    Returns:
        rays (N, 6), label (N, 1), image (N, 3), ray_mask (H, W, 1), plus
        layered_bboxes (N, 8, 3) when ``bboxes`` is given. layered_bboxes is
        a CPU float32 tensor holding each ray's box chosen by its label, with
        zeros where no box matches.
    """
    _, H, W = image.shape

    if bbox is not None:
        minh, minw, maxh, maxw = _bbox_pixel_window(K, T, bbox, H, W)
    else:
        minh, minw, maxh, maxw = 0, 0, H, W

    if minh == maxh or minw == maxw:
        print('Warning: there is a pointcloud cannot find right bbox')

    rays = _camera_rays(K.unsqueeze(0), T.unsqueeze(0), H, W)  # (1,H,W,6)
    rays = rays[:, minh:maxh, minw:maxw, :].reshape(-1, 6)  # (H'*W',6)

    ray_mask = torch.zeros_like(label)
    ray_mask[:, minh:maxh, minw:maxw] = 1.0
    ray_mask = ray_mask.permute(1, 2, 0)  # (H,W,1)

    label = label[:, minh:maxh, minw:maxw].permute(1, 2, 0).reshape(-1, 1)  # (N,1)
    image = image[:, minh:maxh, minw:maxw].permute(1, 2, 0).reshape(-1, 3)  # (N,3)

    if bboxes is None:
        return rays, label, image, ray_mask

    layered_bboxes = torch.zeros(rays.size(0), 8, 3)
    for i, box in enumerate(bboxes):
        # label[:, 0] keeps a 1-D selector even when only one ray is sampled.
        layered_bboxes[label[:, 0] == i] = box
    return rays, label, image, ray_mask, layered_bboxes


def ray_sampling_label_label(image, label, K, T, label0):
    """Return the rays, labels and colours of the pixels whose label equals ``label0``.

    Args:
        image: (C, H, W) image.
        label: (1, H, W) label map.
        K: (3, 3) intrinsics.
        T: (4, 4) camera-to-world pose.
        label0: label value to select.

    Returns:
        rays (N, 6), label (N, 1), image (N, C), ray_mask (H, W, 1) with 1.0
        at the selected pixels.
    """
    _, H, W = image.shape
    rays = _camera_rays(K.unsqueeze(0), T.unsqueeze(0), H, W)  # (1,H,W,6)

    idx = label == label0  # (1,H,W) boolean selector
    ray_mask = torch.zeros_like(label)
    ray_mask[idx] = 1.0
    ray_mask = ray_mask.permute(1, 2, 0)

    rays = rays[idx, :]  # (N,6)
    label = label[idx].reshape(-1, 1)  # (N,1)
    # reshape (not squeeze) so 1-pixel-high or 1-pixel-wide images keep both axes.
    image = image[:, idx.reshape(H, W)].permute(1, 0)  # (N,C)
    return rays, label, image, ray_mask


# --------------------------------------------------------------------------------
# /utils/render_helpers.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch


def lookat(eye, center, up):
    """Build a 4x4 camera-to-world pose at ``eye`` looking toward ``center``.

    A look-at (world-to-camera) matrix is built with its z axis pointing from
    ``center`` to ``eye``. It is inverted to obtain the camera pose, and then
    the camera y and z axes are negated. The negation is presumably there so
    the camera looks along +z, as in the ray generators, where K^-1 (u, v, 1)
    has positive z. Inputs are converted to float arrays, so lists and integer
    arrays are accepted.
    """
    eye = np.asarray(eye, dtype=float)
    center = np.asarray(center, dtype=float)
    up = np.asarray(up, dtype=float)

    z = eye - center
    z = z / np.linalg.norm(z)
    x = np.cross(up, z)
    y = np.cross(z, x)
    x = x / np.linalg.norm(x)
    y = y / np.linalg.norm(y)

    view = np.identity(4)
    view[0, :3] = x
    view[1, :3] = y
    view[2, :3] = z
    view[:3, 3] = -np.array([x.dot(eye), y.dot(eye), z.dot(eye)])

    # What we need is camera pose
    T = np.linalg.inv(view)
    T[:3, 1] = -T[:3, 1]
    T[:3, 2] = -T[:3, 2]
    return T


# degree is True means using degree measure, else means using radian system
def getSphericalPosition(r, theta, phi, degree=True):
    """Convert spherical coordinates to a Cartesian point ``np.array([x, y, z])``.

    ``theta`` is the elevation (y = r*sin(theta)) and ``phi`` the azimuth in
    the x-z plane (x = r*cos(theta)*sin(phi), z = r*cos(theta)*cos(phi)).
    Angles are in degrees when ``degree`` is True, otherwise in radians.
    """
    if degree:
        theta = theta / 180 * np.pi
        phi = phi / 180 * np.pi
    x = r * np.cos(theta) * np.sin(phi)
    z = r * np.cos(theta) * np.cos(phi)
    y = r * np.sin(theta)
    return np.array([x, y, z])


def generate_rays(K, T, bbox, h, w):
    """Generate world-space rays for the pixels covered by a projected bbox.

    Args:
        K: (3, 3) intrinsics.
        T: (4, 4) camera-to-world pose (its translation column is the ray origin).
        bbox: tensor with 8 xyz corners (reshaped to (8, 3)), or None for all pixels.
        h, w: image height and width.

    Returns:
        rays: (N, 6) origin(3) + direction(3) for the pixels inside the window.
        ray_mask: (h, w, 1) CPU float tensor, 1.0 inside the window, 0 elsewhere.
    """
    if bbox is not None:
        corners = bbox.reshape(8, 3).t()  # (3, 8)
        ones = torch.ones(1, corners.shape[1], dtype=corners.dtype, device=corners.device)
        corners = torch.cat([corners, ones], 0)  # homogeneous (4, 8)
        pts = torch.mm(torch.inverse(T), corners)[:3, :]  # world -> camera
        pixels = torch.mm(K, pts)
        pixels = (pixels / pixels[2, :])[[1, 0], :]  # (row, col)
        lo = torch.min(pixels, dim=1)[0]
        hi = torch.max(pixels, dim=1)[0]
        # Clamp to the image (negatives -> 0, then cap), as integer pixel bounds.
        minh = int(min(max(float(lo[0]), 0.0), h - 1))
        minw = int(min(max(float(lo[1]), 0.0), w - 1))
        maxh = int(min(max(float(hi[0]), 0.0), h - 1)) + 1
        maxw = int(min(max(float(hi[1]), 0.0), w - 1)) + 1
    else:
        minh, minw, maxh, maxw = 0, 0, h, w

    if minh == maxh or minw == maxw:
        print('Warning: there is a pointcloud cannot find right bbox')

    Ks = K.unsqueeze(0)
    Ts = T.unsqueeze(0)
    rows = torch.linspace(0, h - 1, steps=h, device=K.device, dtype=K.dtype)
    cols = torch.linspace(0, w - 1, steps=w, device=K.device, dtype=K.dtype)
    grid_r, grid_c = torch.meshgrid(rows, cols)
    # Homogeneous (col, row, 1) per pixel -> (1, h, w, 3, 1).
    coords = torch.stack([grid_c, grid_r, torch.ones_like(grid_r)], dim=-1)
    coords = coords.unsqueeze(0).unsqueeze(-1)

    dirs = torch.matmul(torch.inverse(Ks), coords)  # (1,h,w,3,1)
    dirs = dirs / torch.norm(dirs, dim=3, keepdim=True)
    # w = 0 so the pose only rotates the direction.
    dirs = torch.cat([dirs, torch.zeros_like(dirs[..., :1, :])], dim=3)  # (1,h,w,4,1)
    dirs = torch.matmul(Ts, dirs)[..., 0:3, 0]  # (1,h,w,3)
    pos = Ts[:, 0:3, 3][:, None, None, :].expand(1, h, w, 3)

    rays = torch.cat([pos, dirs], dim=3)[:, minh:maxh, minw:maxw, :]  # (1,h',w',6)
    rays = rays.reshape((-1, rays.size(3)))

    ray_mask = torch.zeros(h, w, 1)
    ray_mask[minh:maxh, minw:maxw, :] = 1.0
    return rays, ray_mask


# --------------------------------------------------------------------------------
# /utils/sample_pdf.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
'''
INPUT:

z_vals: (N,L)
weights: (N,L-2)

OUTPUT:

samples_z: (N,N_samples)
'''
# NOTE(review): this enables autograd anomaly detection for the whole process as
# soon as the module is imported; PyTorch documents the mode as a debugging aid
# that slows execution. Kept to preserve existing behaviour -- consider removing
# it or gating it behind a debug flag.
torch.autograd.set_detect_anomaly(True)


def sample_pdf(z_vals, weights, N_samples, det=False, pytest=False):
    """Draw ``N_samples`` depths per ray by inverse-transform sampling of ``weights``.

    The bins are the midpoints of consecutive ``z_vals``. ``weights`` (plus
    1e-5) is normalised into a piecewise-constant PDF whose CDF is inverted
    with linear interpolation inside each bin.

    Args:
        z_vals: (..., L) depths along each ray.
        weights: (..., L - 2) non-negative weights. The CDF built from them has
            L - 1 entries and must line up with the L - 1 bin midpoints.
        N_samples: number of samples to draw per ray.
        det: use evenly spaced quantiles in [0, 1] instead of uniform random ones.
        pytest: replace the quantiles with numpy ones drawn after
            ``np.random.seed(0)`` (reproducible tests).

    Returns:
        (..., N_samples) sampled depths.
    """
    # Get pdf
    bins = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (..., len(bins))

    # Take uniform samples
    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=z_vals.device, dtype=cdf.dtype)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, device=z_vals.device, dtype=cdf.dtype)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        if det:
            u_np = np.broadcast_to(np.linspace(0., 1., N_samples), sample_shape)
        else:
            u_np = np.random.rand(*sample_shape)
        # np.array copies the (possibly read-only) broadcast view; keep u on
        # the same device/dtype as cdf so searchsorted accepts it.
        u = torch.as_tensor(np.array(u_np), dtype=cdf.dtype, device=cdf.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (..., N_samples, 2)

    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # Near-empty bins: avoid dividing by ~0 (t then becomes ~0).
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples_z = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples_z


# --------------------------------------------------------------------------------
# /utils/vis_density.py:
# --------------------------------------------------------------------------------
import torch


def vis_density(model, bbox, L=32, chunk=5000, device=None):
    """Evaluate ``model.spacenet_fine`` densities on an L x L x L grid spanning ``bbox``.

    Args:
        model: object exposing ``spacenet_fine(pts, None, maxs, mins)``, which
            returns ``(_, density)``, plus ``maxs`` / ``mins`` attributes.
        bbox: (K, 3) points; the grid spans their per-axis min/max.
        L: grid resolution per axis.
        chunk: number of query points per network call.
        device: device for the query points; None keeps the previous behaviour
            of using the current CUDA device.

    Returns:
        ReLU-clamped densities on the CPU, concatenated over chunks (L**3 rows,
        presumably shape (L**3, 1)). Rows follow the grid in x-major order.
    """
    if device is None:
        device = torch.device('cuda')

    maxs = torch.max(bbox, dim=0).values
    mins = torch.min(bbox, dim=0).values

    axes = [torch.linspace(float(mins[d]), float(maxs[d]), steps=L).to(device) for d in range(3)]
    grid_x, grid_y, grid_z = torch.meshgrid(*axes)
    xyz = torch.stack([grid_x, grid_y, grid_z], dim=-1).reshape((-1, 3))  # (L*L*L,3)

    sigmas = []
    with torch.no_grad():
        for pts in xyz.split(chunk, dim=0):
            _, density = model.spacenet_fine(pts, None, model.maxs, model.mins)
            sigmas.append(torch.nn.functional.relu(density).detach().cpu())
    return torch.cat(sigmas, dim=0)
# --------------------------------------------------------------------------------