├── README.md
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── config_taekwondo.yml
│   └── config_walking.yml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── frame_dataset.cpython-36.pyc
│   │   │   ├── frame_dataset.cpython-38.pyc
│   │   │   ├── ibr_dynamic.cpython-36.pyc
│   │   │   ├── ibr_dynamic.cpython-38.pyc
│   │   │   ├── ray_dataset.cpython-36.pyc
│   │   │   ├── ray_dataset.cpython-38.pyc
│   │   │   ├── ray_source.cpython-36.pyc
│   │   │   ├── ray_source.cpython-38.pyc
│   │   │   ├── utils.cpython-36.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   ├── frame_dataset.py
│   │   ├── ray_dataset.py
│   │   └── utils.py
│   └── transforms
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── build.cpython-36.pyc
│       │   ├── build.cpython-38.pyc
│       │   ├── random_transforms.cpython-36.pyc
│       │   └── random_transforms.cpython-38.pyc
│       ├── build.py
│       └── random_transforms.py
├── demo
│   ├── taekwondo_demo.py
│   └── walking_demo.py
├── engine
│   ├── __init__.py
│   ├── layered_trainer.py
│   └── render.py
├── images
│   └── teaser.jpg
├── layers
│   ├── RaySamplePoint-1.py
│   ├── RaySamplePoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── RaySamplePoint.cpython-36.pyc
│   │   ├── RaySamplePoint.cpython-38.pyc
│   │   ├── RaySamplePoint1.cpython-38.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── camera_transform.cpython-38.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-38.pyc
│   │   ├── render_layer.cpython-36.pyc
│   │   └── render_layer.cpython-38.pyc
│   ├── camera_transform.py
│   ├── loss.py
│   └── render_layer.py
├── modeling
│   ├── __init__.py
│   ├── layered_rfrender.py
│   ├── motion_net.py
│   └── spacenet.py
├── outputs
│   ├── taekwondo
│   │   └── layered_rfnr_checkpoint_1.pt
│   └── walking
│       └── layered_rfnr_checkpoint_1.pt
├── render
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── bkgd_renderer.cpython-38.pyc
│   │   ├── layered_neural_renderer.cpython-38.pyc
│   │   ├── neural_renderer.cpython-38.pyc
│   │   └── render_functions.cpython-38.pyc
│   ├── bkgd_renderer.py
│   ├── layered_neural_renderer.py
│   ├── neural_renderer.py
│   └── render_functions.py
├── solver
│   ├── __init__.py
│   ├── build.py
│   └── lr_scheduler.py
└── utils
    ├── __init__.py
    ├── batchify_rays.py
    ├── dimension_kernel.py
    ├── high_dim_dics.py
    ├── logger.py
    ├── metrics.py
    ├── ray_sampling.py
    ├── render_helpers.py
    ├── sample_pdf.py
    └── vis_density.py
/README.md:
--------------------------------------------------------------------------------
1 | # st-nerf
2 |
3 | We provide PyTorch implementations for our paper:
4 | [Editable Free-viewpoint Video Using a Layered Neural Representation](https://arxiv.org/abs/2104.14786)
5 |
6 | SIGGRAPH 2021
7 |
8 | Jiakai Zhang, Xinhang Liu, Xinyi Ye, Fuqiang Zhao, Yanshun Zhang, Minye Wu, Yingliang Zhang, Lan Xu and Jingyi Yu
9 |
10 |
11 |
12 |
13 | **st-nerf: [Project](https://jiakai-zhang.github.io/st-nerf/) | [Paper](https://arxiv.org/abs/2104.14786)**
14 |
15 | ## Getting Started
16 | ### Installation
17 |
18 | - Clone this repo:
19 | ```bash
20 | git clone https://github.com/DarlingHang/st-nerf
21 | cd st-nerf
22 | ```
23 |
24 | - Install [PyTorch](http://pytorch.org) and other dependencies using:
25 | ```
26 | conda create -n st-nerf python=3.8.5
27 | conda activate st-nerf
28 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
29 | conda install imageio matplotlib
30 | pip install yacs kornia robpy
31 | ```
32 |
33 |
34 | ### Datasets
35 | The walking and taekwondo datasets can be downloaded from [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/xliufe_connect_ust_hk/EjqArjZxmmtDplj_IrwlUq0BMUyG69zr5YqXFBxgku4rRQ?e=n2fSBs).
36 |
37 | ### Apply a pre-trained model to render demo videos
38 | - We provide our pretrained models which can be found under the `outputs` folder.
39 | - We provide some example scripts under the `demo` folder.
40 | - To run our demo scripts, you need to first download the corresponding dataset and put it under the folder specified by `DATASETS` -> `TRAIN` in `configs/config_taekwondo.yml` and `configs/config_walking.yml` (see the config excerpt below).
41 | - For the walking sequence, you can render videos where some performers are hidden by running:
42 | ```
43 | python demo/walking_demo.py -c configs/config_walking.yml
44 | ```
45 | - For the taekwondo sequence, you can render videos where performers are translated and scaled by running:
46 | ```
47 | python demo/taekwondo_demo.py -c configs/config_taekwondo.yml
48 | ```
49 | - The rendered images and videos will be under `outputs/taekwondo/rendered` and `outputs/walking/rendered`
50 |
51 | ## Acknowledgements
52 | We borrowed some code from [Multi-view Neural Human Rendering (NHR)](https://github.com/wuminye/NHR).
53 |
54 | ## Citation
55 | If you use this code for your research, please cite our paper.
56 | ```
57 | @article{zhang2021editable,
58 | title={Editable free-viewpoint video using a layered neural representation},
59 | author={Zhang, Jiakai and Liu, Xinhang and Ye, Xinyi and Zhao, Fuqiang and Zhang, Yanshun and Wu, Minye and Zhang, Yingliang and Xu, Lan and Yu, Jingyi},
60 | journal={ACM Transactions on Graphics (TOG)},
61 | volume={40},
62 | number={4},
63 | pages={1--18},
64 | year={2021},
65 | publisher={ACM New York, NY, USA}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
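
The demo instructions in the README above point `DATASETS` -> `TRAIN` at the downloaded data. A minimal excerpt of what that entry might look like; the path below is a placeholder (it is not shipped with the repo), and the expected sub-folders follow `data/datasets/frame_dataset.py`:

```yaml
# excerpt of configs/config_walking.yml with a placeholder dataset path
DATASETS:
  TRAIN: "/path/to/walking"   # expects frame<N>/{images,labels,pointclouds}, background/ and pose/ inside
```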
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be used either for training or for testing, the
7 | # corresponding name will be suffixed with _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.deep_rgb = True
20 |
21 | _C.MODEL = CN()
22 | _C.MODEL.DEVICE = "cuda"
23 | _C.MODEL.COARSE_RAY_SAMPLING = 64
24 | _C.MODEL.FINE_RAY_SAMPLING = 80
25 | _C.MODEL.SAMPLE_METHOD = "NEAR_FAR"
26 | _C.MODEL.BOARDER_WEIGHT = 1e10
27 | _C.MODEL.SAME_SPACENET = False
28 |
29 | _C.MODEL.TKERNEL_INC_RAW = True
30 | _C.MODEL.POSE_REFINEMENT = True
31 | _C.MODEL.USE_DIR = True
32 | _C.MODEL.REMOVE_OUTLIERS = False
33 | _C.MODEL.TRAIN_BY_POINTCLOUD = False
34 | _C.MODEL.USE_DEFORM_VIEW = False # Use deformnet to deform view inconsistency
35 | _C.MODEL.USE_DEFORM_TIME = False # Use deformnet to deform time inconsistency
36 | _C.MODEL.BKGD_USE_DEFORM_TIME = False
37 | _C.MODEL.BKGD_USE_SPACE_TIME = False
38 | _C.MODEL.USE_SPACE_TIME = False
39 | _C.MODEL.DEEP_RGB = True
40 |
41 |
42 |
43 |
44 | # -----------------------------------------------------------------------------
45 | # INPUT
46 | # -----------------------------------------------------------------------------
47 | _C.INPUT = CN()
48 | # Size of the image during training
49 | _C.INPUT.SIZE_TRAIN = [400,250]
50 | # Size of the image during test
51 | _C.INPUT.SIZE_TEST = [400,250]
52 | # Size of the image during sample layer
53 | _C.INPUT.SIZE_LAYER = [400,250]
54 | # Minimum scale for the image during training
55 | _C.INPUT.MIN_SCALE_TRAIN = 0.5
56 | # Maximum scale for the image during test
57 | _C.INPUT.MAX_SCALE_TRAIN = 1.2
58 | # Random probability for image horizontal flip
59 | _C.INPUT.PROB = 0.5
60 | # Values to be used for image normalization
61 | _C.INPUT.PIXEL_MEAN = [0.1307, ]
62 | # Values to be used for image normalization
63 | _C.INPUT.PIXEL_STD = [0.3081, ]
64 |
65 | # -----------------------------------------------------------------------------
66 | # Dataset
67 | # -----------------------------------------------------------------------------
68 | _C.DATASETS = CN()
69 | # List of the dataset names for training, as present in paths_catalog.py
70 | _C.DATASETS.TRAIN = ""
71 | _C.DATASETS.TMP_RAYS = "rays_tmp"
72 | # List of the dataset names for testing, as present in paths_catalog.py
73 | _C.DATASETS.TEST = ()
74 | _C.DATASETS.SHIFT = 0.0
75 | _C.DATASETS.MAXRATION = 0.0
76 | _C.DATASETS.ROTATION = 0.0
77 | _C.DATASETS.USE_MASK = False
78 | _C.DATASETS.NUM_FRAME = 1
79 | _C.DATASETS.FACTOR = 1
80 | _C.DATASETS.FIXED_NEAR = -1.0
81 | _C.DATASETS.FIXED_FAR = -1.0
82 |
83 | _C.DATASETS.CENTER_X = 0.0
84 | _C.DATASETS.CENTER_Y = 0.0
85 | _C.DATASETS.CENTER_Z = 0.0
86 | _C.DATASETS.SCALE = 1.0
87 | _C.DATASETS.FILE_OFFSET = 0
88 | _C.DATASETS.FRAME_OFFSET = 0
89 | _C.DATASETS.FRAME_NUM = 0
90 | _C.DATASETS.LAYER_NUM = 0
91 | _C.DATASETS.CAMERA_NUM = 0
92 | _C.DATASETS.BKGD_SAMPLE_RATE = 0.1
93 | _C.DATASETS.CAMERA_STEPSIZE = 1
94 |
95 | _C.DATASETS.USE_LABEL = False
96 | _C.DATASETS.VIEW_MASK = None
97 | _C.DATASETS.FIXED_LAYER = []
98 |
99 | # -----------------------------------------------------------------------------
100 | # DataLoader
101 | # -----------------------------------------------------------------------------
102 | _C.DATALOADER = CN()
103 | # Number of data loading threads
104 | _C.DATALOADER.NUM_WORKERS = 8
105 |
106 | # ---------------------------------------------------------------------------- #
107 | # Solver
108 | # ---------------------------------------------------------------------------- #
109 | _C.SOLVER = CN()
110 | _C.SOLVER.OPTIMIZER_NAME = "SGD"
111 |
112 | _C.SOLVER.MAX_EPOCHS = 50
113 |
114 | _C.SOLVER.BASE_LR = 0.001
115 | _C.SOLVER.BIAS_LR_FACTOR = 2
116 |
117 | _C.SOLVER.MOMENTUM = 0.9
118 |
119 | _C.SOLVER.WEIGHT_DECAY = 0.0005
120 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
121 |
122 | _C.SOLVER.GAMMA = 0.1
123 | _C.SOLVER.STEPS = (30000,)
124 |
125 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
126 | _C.SOLVER.WARMUP_ITERS = 500
127 | _C.SOLVER.WARMUP_METHOD = "linear"
128 |
129 | _C.SOLVER.CHECKPOINT_PERIOD = 10
130 | _C.SOLVER.LOG_PERIOD = 100
131 | _C.SOLVER.BUNCH = 4096
132 | _C.SOLVER.START_ITERS=50
133 | _C.SOLVER.END_ITERS=200
134 | _C.SOLVER.LR_SCALE=0.1
135 | _C.SOLVER.COARSE_STAGE = 10
136 |
137 | # Number of images per batch
138 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
139 | # see 2 images per batch
140 | _C.SOLVER.IMS_PER_BATCH = 16
141 |
142 | _C.SOLVER.BBOX_ID = 0
143 |
144 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
145 | # see 2 images per batch
146 | _C.TEST = CN()
147 | _C.TEST.IMS_PER_BATCH = 8
148 | _C.TEST.WEIGHT = ""
149 |
150 | # ---------------------------------------------------------------------------- #
151 | # Misc options
152 | # ---------------------------------------------------------------------------- #
153 | _C.OUTPUT_DIR = ""
154 |
--------------------------------------------------------------------------------
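
A minimal sketch (not a script from the repo) of how these defaults are typically consumed with yacs; it assumes the repo root is on `PYTHONPATH` and that `DATASETS.TRAIN` has been filled in:

```python
from config import cfg  # the _C node defined in config/defaults.py

# Values from the YAML override the defaults above; unknown keys raise an error.
cfg.merge_from_file("configs/config_walking.yml")
cfg.freeze()

print(cfg.MODEL.SAMPLE_METHOD)   # "BBOX" after the merge (the default was "NEAR_FAR")
print(cfg.DATASETS.LAYER_NUM)    # 2 for the walking sequence
```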
/configs/config_taekwondo.yml:
--------------------------------------------------------------------------------
1 |
2 | SOLVER:
3 | OPTIMIZER_NAME: "Adam"
4 | BASE_LR: 0.0004
5 | WEIGHT_DECAY: 0.0000000
6 | IMS_PER_BATCH: 2000
7 | START_ITERS: 3000
8 | END_ITERS: 60000
9 | LR_SCALE: 0.09
10 | WARMUP_ITERS: 1000
11 |
12 | MAX_EPOCHS: 100
13 | CHECKPOINT_PERIOD: 3000
14 | LOG_PERIOD: 30
15 | BUNCH: 3000
16 | COARSE_STAGE: 1
17 |
18 | BBOX_ID: 0
19 |
20 | INPUT:
21 | # (4130,2202) / 4 = (1033,551)
22 | SIZE_TRAIN: [1920,1080]
23 | SIZE_LAYER: [1920,1080]
24 | SIZE_TEST: [1920,1080]
25 |
26 | DATASETS:
27 | TRAIN:
28 | TMP_RAYS: "rays_tmp_1920"
29 | NUM_FRAME: 1
30 | SHIFT: 0.0
31 | MAXRATION: 0.0
32 | ROTATION: 0.0
33 | FACTOR: 8
34 | FIXED_NEAR: -1.0
35 | FIXED_FAR: -1.0
36 | SCALE: 0.1 # scale in xyz position of world coordinate
37 | FILE_OFFSET: 0
38 | FRAME_OFFSET: 0
39 | BKGD_SAMPLE_RATE: 0.05
40 |
41 | USE_LABEL: True
42 |
43 | USE_MASK: False
44 |
45 | FRAME_NUM: 101
46 | LAYER_NUM: 2
47 |
48 |
49 |
50 |
51 | MODEL:
52 | COARSE_RAY_SAMPLING: 90
53 | FINE_RAY_SAMPLING: 30
54 | SAMPLE_METHOD: "BBOX" # "NEAR_FAR" "BBOX"
55 | BOARDER_WEIGHT: 1e10
56 | SAME_SPACENET: False
57 | TKERNEL_INC_RAW: True
58 |   POSE_REFINEMENT: False # Whether to refine camera poses
59 |   USE_DIR: True
60 |   REMOVE_OUTLIERS: True # Use masks to remove density outside the mask
61 |   USE_DEFORM_VIEW: False # Use deformnet to deform view inconsistency
62 |   USE_DEFORM_TIME: True # Use deformnet to deform time inconsistency
63 | USE_SPACE_TIME: True
64 | BKGD_USE_DEFORM_TIME: False
65 | BKGD_USE_SPACE_TIME: False
66 | DEEP_RGB: False
67 |
68 |
69 | TEST:
70 | IMS_PER_BATCH: 1
71 |
72 | OUTPUT_DIR: "outputs/taekwondo"
73 |
--------------------------------------------------------------------------------
/configs/config_walking.yml:
--------------------------------------------------------------------------------
1 |
2 | SOLVER:
3 | OPTIMIZER_NAME: "Adam"
4 | BASE_LR: 0.0004
5 | WEIGHT_DECAY: 0.0000000
6 | IMS_PER_BATCH: 2000
7 | START_ITERS: 3000
8 | END_ITERS: 60000
9 | LR_SCALE: 0.09
10 | WARMUP_ITERS: 1000
11 |
12 | MAX_EPOCHS: 100
13 | CHECKPOINT_PERIOD: 3000
14 | LOG_PERIOD: 30
15 | BUNCH: 3000
16 | COARSE_STAGE: 1
17 |
18 | INPUT:
19 | SIZE_TRAIN: [1920,1080]
20 | SIZE_LAYER: [1920,1080]
21 | SIZE_TEST: [1920,1080]
22 |
23 | DATASETS:
24 | TRAIN:
25 | TMP_RAYS: "rays_tmp_1920_BBOX"
26 | NUM_FRAME: 1
27 | SHIFT: 0.0
28 | MAXRATION: 0.0
29 | ROTATION: 0.0
30 | FACTOR: 8
31 | FIXED_NEAR: -1.0
32 | FIXED_FAR: -1.0
33 | SCALE: 1.0 # scale in xyz position of world coordinate
34 | FILE_OFFSET: 0
35 | FRAME_OFFSET: 25
36 | BKGD_SAMPLE_RATE: 0.0
37 |
38 | USE_LABEL: False
39 |
40 | USE_MASK: False
41 |
42 | FRAME_NUM: 50
43 | LAYER_NUM: 2
44 |
45 |
46 | MODEL:
47 | COARSE_RAY_SAMPLING: 90
48 | FINE_RAY_SAMPLING: 30
49 | SAMPLE_METHOD: "BBOX" # "NEAR_FAR" "BBOX"
50 | BOARDER_WEIGHT: 1e10
51 | SAME_SPACENET: False
52 | TKERNEL_INC_RAW: True
53 |   POSE_REFINEMENT: False # Whether to refine camera poses
54 |   USE_DIR: True
55 |   REMOVE_OUTLIERS: False # Use masks to remove density outside the mask
56 |   USE_DEFORM_VIEW: False # Use deformnet to deform view inconsistency
57 |   USE_DEFORM_TIME: True # Use deformnet to deform time inconsistency
58 | USE_SPACE_TIME: False
59 | BKGD_USE_DEFORM_TIME: False
60 | BKGD_USE_SPACE_TIME: False
61 |
62 | TEST:
63 | IMS_PER_BATCH: 1
64 |
65 | OUTPUT_DIR: "outputs/walking"
66 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_ray_data_loader, make_ray_data_loader_view, make_ray_data_loader_render
8 | from .datasets.utils import get_iteration_path, get_iteration_path_and_iter
9 |
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: Minye Wu
4 | @GITHUB: wuminye
5 | """
6 |
7 | from torch.utils import data
8 | import numpy as np
9 | from .datasets.ray_dataset import Ray_Dataset, Ray_Dataset_View, Ray_Dataset_Render, Ray_Frame_Layer_Dataset
10 | from .transforms import build_transforms, build_layered_transforms
11 |
12 |
13 | def make_ray_data_loader(cfg, is_train=True):
14 |
15 | batch_size = cfg.SOLVER.IMS_PER_BATCH
16 |
17 | transforms_bkgd = build_layered_transforms(cfg, is_train=is_train, is_layer=False)
18 | transforms_layer = build_layered_transforms(cfg, is_train=is_train, is_layer=True)
19 |
20 | datasets = Ray_Dataset(cfg, transforms_bkgd, transforms_layer)
21 |
22 | num_workers = cfg.DATALOADER.NUM_WORKERS
23 | data_loader = data.DataLoader(
24 | datasets, batch_size=batch_size, shuffle=True, num_workers=num_workers
25 | )
26 |
27 | return data_loader, datasets
28 |
29 | def make_ray_data_loader_view(cfg, is_train=False):
30 |
31 | batch_size = cfg.SOLVER.IMS_PER_BATCH
32 |
33 | transforms = build_transforms(cfg, is_train)
34 |
35 | datasets = Ray_Dataset_View(cfg, transforms)
36 |
37 | num_workers = cfg.DATALOADER.NUM_WORKERS
38 | data_loader = data.DataLoader(
39 | datasets, batch_size=batch_size, shuffle=True, num_workers=num_workers
40 | )
41 |
42 | return data_loader, datasets
43 |
44 | def make_ray_data_loader_render(cfg, is_train=False):
45 |
46 | batch_size = cfg.SOLVER.IMS_PER_BATCH
47 |
48 |
49 | transforms = build_transforms(cfg, is_train)
50 |
51 | datasets = Ray_Dataset_Render(cfg, transforms)
52 |
53 | num_workers = cfg.DATALOADER.NUM_WORKERS
54 | data_loader = data.DataLoader(
55 | datasets, batch_size=batch_size, shuffle=False, num_workers=num_workers
56 | )
57 |
58 | return data_loader, datasets
--------------------------------------------------------------------------------
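
A hypothetical usage sketch for the loader factory above (the variable names on the left are mine; the tuple layout follows `Ray_Frame_Layer_Dataset.__getitem__` in `data/datasets/ray_dataset.py`, and `DATASETS.TRAIN` must point at the downloaded data):

```python
from config import cfg
from data import make_ray_data_loader

cfg.merge_from_file("configs/config_walking.yml")

train_loader, train_dataset = make_ray_data_loader(cfg, is_train=True)

# Each batch holds SOLVER.IMS_PER_BATCH rays drawn across all layers and frames.
for rays, rgbs, labels, bbox_labels, bboxes, near_fars in train_loader:
    print(rays.shape, rgbs.shape, near_fars.shape)
    break
```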
/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: Minye Wu
4 | @GITHUB: wuminye
5 | """
6 |
7 | # from .ibr_dynamic import IBRDynamicDataset
--------------------------------------------------------------------------------
/data/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/frame_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/frame_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/frame_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/frame_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ibr_dynamic.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ibr_dynamic.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ibr_dynamic.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ibr_dynamic.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ray_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ray_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ray_source.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_source.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/ray_source.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/ray_source.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/datasets/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/data/datasets/frame_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from .utils import campose_to_extrinsic, read_intrinsics, read_mask
5 | from PIL import Image
6 | import torchvision
7 | import torch.distributions as tdist
8 | import open3d as o3d
9 |
10 | class FrameDataset(torch.utils.data.Dataset):
11 |
12 | def __init__(self, dataset_path, transform, frame_id, layer_num, file_offset):
13 |
14 | super(FrameDataset, self).__init__()
15 |
16 | # 1. Set the dataset path for the next loading
17 | self.file_offset = file_offset
18 | self.frame_id = frame_id # The frame number
19 | self.layer_num = layer_num # The number of layers
20 | self.image_path = os.path.join(dataset_path,str(self.frame_id),'images')
21 | self.label_path = os.path.join(dataset_path,str(self.frame_id),'labels')
22 | self.pointcloud_path = os.path.join(dataset_path,str(self.frame_id),'pointclouds')
23 | self.pose_path = os.path.join(dataset_path, 'pose')
24 | self.transform = transform
25 | # 2. Loading Intrinsics & Camera poses
26 | camposes = np.loadtxt(os.path.join(self.pose_path,'RT_c2w.txt'))
27 | # Ts are camera poses
28 | self.Ts = torch.Tensor(campose_to_extrinsic(camposes))
29 | self.Ks = torch.Tensor(read_intrinsics(os.path.join(self.pose_path,'K.txt')))
30 | # 3. Load pointclouds for different layers
31 | self.pointclouds = [] # Finally (layer_num,)
32 | self.bboxs = [] # Finally (layer_num,)
33 |
34 |
35 |
36 | for i in range(layer_num):
37 | # Start from 1.ply to layer_num.ply
38 | pointcloud_name = os.path.join(self.pointcloud_path, '%d.ply' % (i+1))
39 |
40 | if not os.path.exists(pointcloud_name):
41 | pointcloud_name = os.path.join(self.pointcloud_path1, '%d.ply' % (i+1))
42 |
43 | if not os.path.exists(pointcloud_name):
44 | print('Cannot find corresponding pointcloud in path: ', pointcloud_name)
45 | pointcloud = o3d.io.read_point_cloud(pointcloud_name)
46 | xyz = np.asarray(pointcloud.points)
47 |
48 | xyz = torch.Tensor(xyz)
49 | self.pointclouds.append(xyz)
50 |
51 | max_xyz = torch.max(xyz, dim=0)[0]
52 | min_xyz = torch.min(xyz, dim=0)[0]
53 |
 54 |             # Bbox padding factor (the default scalar is 0.3; disabled here with 0.0)
55 | tmp = (max_xyz - min_xyz) * 0.0
56 |
57 | max_xyz = max_xyz + tmp
58 | min_xyz = min_xyz - tmp
59 |
60 | minx, miny, minz = min_xyz[0],min_xyz[1],min_xyz[2]
61 | maxx, maxy, maxz = max_xyz[0],max_xyz[1],max_xyz[2]
62 | bbox = torch.Tensor([[minx,miny,minz],[maxx,miny,minz],[maxx,maxy,minz],[minx,maxy,minz],
63 | [minx,miny,maxz],[maxx,miny,maxz],[maxx,maxy,maxz],[minx,maxy,maxz]])
64 |
65 | bbox = bbox.reshape(1,8,3)
66 |
67 | self.bboxs.append(torch.Tensor(bbox))
68 |
 69 |         print('Frame %d dataset loaded with %d layers in total' % (frame_id, layer_num))
70 |
71 | def __len__(self):
 72 |         return self.Ts.shape[0] * self.layer_num  # one entry per (camera, layer) pair
73 |
74 | def get_data(self, camera_id, layer_id):
75 | # Find K,T, bbox
76 | T = self.Ts[camera_id]
77 | K = self.Ks[camera_id]
78 | bbox = self.bboxs[layer_id-1]
79 | # Load image
80 | image_path = os.path.join(self.image_path, '%03d.png' % (camera_id + self.file_offset))
81 | image = Image.open(image_path)
82 | # Load label
83 | label_path = os.path.join(self.label_path, '%03d.npy' % (camera_id + self.file_offset))
84 | if not os.path.exists(label_path):
85 | label_path = os.path.join(self.label_path, '%03d_label.npy' % (camera_id + self.file_offset))
86 | label = np.load(label_path)
87 |
88 | # Transform image label K T to right scale
89 | image, label, K, T, ROI = self.transform(image, Ks=K, Ts=T, label=label)
90 |
91 | return image, label, K, T, ROI, bbox
92 |
93 |
94 | class FrameLayerDataset(torch.utils.data.Dataset):
95 |
96 | def __init__(self, cfg, transform, frame_id, layer_id):
97 |
98 | super(FrameLayerDataset, self).__init__()
99 |
100 | # 1. Set the dataset path for the next loading
101 | dataset_path = cfg.DATASETS.TRAIN
102 | fixed_near, fixed_far = cfg.DATASETS.FIXED_NEAR, cfg.DATASETS.FIXED_FAR
103 | scale=cfg.DATASETS.SCALE
104 | camera_stepsize=cfg.DATASETS.CAMERA_STEPSIZE
105 | self.file_offset = cfg.DATASETS.FILE_OFFSET
106 |
107 | self.frame_id = frame_id # The frame id
108 | self.layer_id = layer_id # The layer id
109 | self.image_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'images')
110 | self.label_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'labels')
111 | #TODO: Need to fix when background is deformable
112 | self.pointcloud_path1 = "None"
113 |
114 | if layer_id != 0:
115 | self.pointcloud_path = os.path.join(dataset_path,'frame'+str(self.frame_id),'pointclouds')
116 | self.pointcloud_path1 = os.path.join(dataset_path,'background')
117 | else:
118 | self.pointcloud_path = os.path.join(dataset_path,'background')
119 | self.pose_path = os.path.join(dataset_path, 'pose')
120 | self.transform = transform
121 | # 2. Loading Intrinsics & Camera poses
122 | camposes = np.loadtxt(os.path.join(self.pose_path,'RT_c2w.txt'))
123 |
124 | # Ts are camera poses
125 | self.Ts = torch.Tensor(campose_to_extrinsic(camposes))
126 | self.Ts[:,0:3,3] = self.Ts[:,0:3,3] * scale
127 | print('scale is ', scale)
128 |
129 | self.Ks = torch.Tensor(read_intrinsics(os.path.join(self.pose_path,'K.txt')))
130 |
131 | self.cfg = cfg
132 | if cfg.DATASETS.CAMERA_NUM == 0:
133 | self.cam_num = self.Ts.shape[0]
134 | else:
135 | self.cam_num = cfg.DATASETS.CAMERA_NUM
136 |
137 | self.mask = np.ones(self.Ts.shape[0])
138 | self.mask_path = cfg.DATASETS.VIEW_MASK
139 | if self.mask_path != None:
140 | if os.path.exists(self.mask_path):
141 | self.mask = read_mask(self.mask_path)
142 |
143 | pointcloud_name = os.path.join(self.pointcloud_path, '%d.ply' % (layer_id))
144 |
145 | self.pointcloud = None
146 | if not os.path.exists(pointcloud_name):
147 | pointcloud_name = os.path.join(self.pointcloud_path1, '%d.ply' % (layer_id))
148 |
149 | bbox_name = 'bbox_tmp'
150 | if not os.path.exists(pointcloud_name):
151 | print('Warning: Cannot find corresponding pointcloud in path: ', pointcloud_name)
152 | self.bbox = None
153 | self.center = torch.Tensor([0,0,0])
154 | tmp_bbox_path = os.path.join(dataset_path,bbox_name,'frame'+str(frame_id),'layer'+str(layer_id))
155 | if os.path.exists(os.path.join(tmp_bbox_path,'center.pt')):
156 |                 print('A bbox was already generated for layer %d, frame %d, loading bbox...' % (layer_id, frame_id))
157 | # pointcloud = o3d.io.read_point_cloud(pointcloud_name)
158 | # xyz = np.asarray(pointcloud.points)
159 |
160 | # xyz = torch.Tensor(xyz)
161 | # self.pointcloud = xyz
162 | self.center = torch.load(os.path.join(tmp_bbox_path,'center.pt'))
163 | self.bbox = torch.load(os.path.join(tmp_bbox_path,'bbox.pt'))
164 | else:
165 | tmp_bbox_path = os.path.join(dataset_path,bbox_name,'frame'+str(frame_id),'layer'+str(layer_id))
166 | if not os.path.exists(os.path.join(tmp_bbox_path,'center.pt')):
167 |                 print('No bbox generated before, generating bbox...')
168 | if not os.path.exists(tmp_bbox_path):
169 | os.makedirs(tmp_bbox_path)
170 | pointcloud = o3d.io.read_point_cloud(pointcloud_name)
171 | xyz = np.asarray(pointcloud.points)
172 |
173 | xyz = torch.Tensor(xyz)
174 | self.pointcloud = xyz * scale
175 |
176 | max_xyz = torch.max(self.pointcloud, dim=0)[0]
177 | min_xyz = torch.min(self.pointcloud, dim=0)[0]
178 |
179 |                 # Bbox padding factor (the default scalar is 0.3; disabled here with 0.0)
180 | tmp = (max_xyz - min_xyz) * 0.0
181 |
182 | max_xyz = max_xyz + tmp
183 | min_xyz = min_xyz - tmp
184 |
185 | minx, miny, minz = min_xyz[0],min_xyz[1],min_xyz[2]
186 | maxx, maxy, maxz = max_xyz[0],max_xyz[1],max_xyz[2]
187 | bbox = torch.Tensor([[minx,miny,minz],[maxx,miny,minz],[maxx,maxy,minz],[minx,maxy,minz],
188 | [minx,miny,maxz],[maxx,miny,maxz],[maxx,maxy,maxz],[minx,maxy,maxz]])
189 |
190 | bbox = bbox.reshape(1,8,3)
191 |
192 | self.center = np.array([(min_xyz[0]+max_xyz[0])/2, (min_xyz[1]+max_xyz[1])/2, (min_xyz[2]+max_xyz[2])/2])
193 | self.bbox = torch.Tensor(bbox)
194 | if not os.path.exists(os.path.join(tmp_bbox_path,'center.pt')):
195 | torch.save(self.center, os.path.join(tmp_bbox_path,'center.pt'))
196 | if not os.path.exists(os.path.join(tmp_bbox_path,'bbox.pt')):
197 | torch.save(self.bbox, os.path.join(tmp_bbox_path,'bbox.pt'))
198 | else:
199 |                 print('A bbox was already generated for layer %d, frame %d, loading bbox...' % (layer_id, frame_id))
200 | # pointcloud = o3d.io.read_point_cloud(pointcloud_name)
201 | # xyz = np.asarray(pointcloud.points)
202 |
203 | # xyz = torch.Tensor(xyz)
204 | # self.pointcloud = xyz
205 | self.center = torch.load(os.path.join(tmp_bbox_path,'center.pt'))
206 | self.bbox = torch.load(os.path.join(tmp_bbox_path,'bbox.pt'))
207 |
208 |
209 | if fixed_near == -1.0 and fixed_far == -1.0:
210 | near_far_name = 'near_far_tmp'
211 | tmp_near_far_path = os.path.join(dataset_path,near_far_name,'frame'+str(frame_id),'layer'+str(layer_id))
212 | if not os.path.exists(os.path.join(tmp_near_far_path,'near.pt')):
213 | if not os.path.exists(os.path.join(tmp_near_far_path)):
214 | os.makedirs(tmp_near_far_path)
215 | inv_Ts = torch.inverse(self.Ts).unsqueeze(1) #(M,1,4,4)
216 |
217 | if self.pointcloud is None:
218 | pointcloud = o3d.io.read_point_cloud(pointcloud_name)
219 | xyz = np.asarray(pointcloud.points)
220 |
221 | xyz = torch.Tensor(xyz)
222 | self.pointcloud = xyz * scale
223 | vs = self.pointcloud.clone().unsqueeze(-1) #(N,3,1)
224 | vs = torch.cat([vs,torch.ones(vs.size(0),1,vs.size(2)) ],dim=1) #(N,4,1)
225 |
226 | pts = torch.matmul(inv_Ts,vs) #(M,N,4,1)
227 |
228 | pts_max = torch.max(pts, dim=1)[0].squeeze() #(M,4)
229 | pts_min = torch.min(pts, dim=1)[0].squeeze() #(M,4)
230 |
231 | pts_max = pts_max[:,2] #(M)
232 | pts_min = pts_min[:,2] #(M)
233 |
234 | self.near = pts_min
235 | # self.near[self.near<(pts_max*0.1)] = pts_max[self.near<(pts_max*0.1)]*0.1
236 |
237 | self.far = pts_max
238 | torch.save(self.near,os.path.join(tmp_near_far_path,'near.pt'))
239 | torch.save(self.far,os.path.join(tmp_near_far_path,'far.pt'))
240 | else:
241 | self.near = torch.load(os.path.join(tmp_near_far_path,'near.pt'))
242 | self.far = torch.load(os.path.join(tmp_near_far_path,'far.pt'))
243 | else:
244 | self.near = torch.ones(self.Ts.shape[0]) * fixed_near
245 | self.far = torch.ones(self.Ts.shape[0]) * fixed_far
246 |
247 | print('Layer %d, Frame %d dataset loaded' %(layer_id,frame_id))
248 |
249 | def __len__(self):
250 | return self.cam_num
251 |
252 | def get_data(self, camera_id):
253 |         # When CAMERA_NUM is non-zero, the file offset is applied to everything (camera parameters and images); otherwise it only applies to the image files
254 | if self.cfg.DATASETS.CAMERA_NUM != 0:
255 | camera_id = camera_id + self.file_offset
256 | if self.mask[camera_id] == 0:
257 | return None, None, None, None, None, None, None, 0
258 | # Find K,T, bbox
259 |
260 | T = self.Ts[camera_id]
261 | K = self.Ks[camera_id]
262 | bbox = self.bbox
263 | # Load image
264 | image_path = os.path.join(self.image_path, '%03d.png' % (camera_id))
265 | if not os.path.exists(image_path):
266 | image_path = os.path.join(self.image_path, '%d.png' % (camera_id))
267 | if not os.path.exists(image_path):
268 | image = None
269 | else:
270 | image = Image.open(image_path)
271 | # Load label
272 | label = None
273 | label_path = os.path.join(self.label_path, '%03d.npy' % (camera_id))
274 | if not os.path.exists(label_path):
275 | label_path = os.path.join(self.label_path, '%03d_label.npy' % (camera_id))
276 | if not os.path.exists(label_path):
277 | label_path = os.path.join(self.label_path, '%d.npy' % (camera_id))
278 | if not os.path.exists(label_path):
279 |             if image is None:
280 | label = None
281 | else:
282 | width, height = image.size
283 | label = np.ones((height, width)) * self.layer_id
284 |                 print('Warning: no label map found for this dataset; training layer %d for frame %d, so a full label map is generated for it' % (self.layer_id, self.frame_id))
285 | else:
286 | label = np.load(label_path)
287 |
288 | # Transform image label K T to right scale
289 | image, label, K, T, ROI = self.transform(image, Ks=K, Ts=T, label=label)
290 |
291 | return image, label, K, T, ROI, bbox, torch.tensor([self.near[camera_id],self.far[camera_id]]).unsqueeze(0), self.mask[camera_id]
292 |
293 | def get_original_size(self):
294 |
295 | image_path = os.path.join(self.image_path, '%03d.png' % (0))
296 | if not os.path.exists(image_path):
297 | image_path = os.path.join(self.image_path, '%d.png' % (0))
298 | if not os.path.exists(image_path):
299 | image = None
300 | else:
301 | image = Image.open(image_path)
302 |
303 | return image.size
304 |
305 |
306 |
--------------------------------------------------------------------------------
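
A self-contained illustration of the per-camera near/far computation in `FrameLayerDataset` above: every point is taken into each camera's frame with the inverted pose, and the z extremes become that camera's near and far. The poses and points below are made up:

```python
import torch

M, N = 4, 500
Ts = torch.eye(4).repeat(M, 1, 1)                      # stand-in camera-to-world poses
Ts[:, 2, 3] = torch.arange(M, dtype=torch.float32)     # cameras spread along z
points = torch.rand(N, 3) * 2.0 + 5.0                  # stand-in point cloud

inv_Ts = torch.inverse(Ts).unsqueeze(1)                # (M,1,4,4) world-to-camera
vs = points.unsqueeze(-1)                              # (N,3,1)
vs = torch.cat([vs, torch.ones(N, 1, 1)], dim=1)       # (N,4,1) homogeneous points
pts = torch.matmul(inv_Ts, vs)                         # (M,N,4,1) points in each camera frame

near = torch.min(pts, dim=1)[0].squeeze()[:, 2]        # (M,) smallest camera-space z per camera
far = torch.max(pts, dim=1)[0].squeeze()[:, 2]         # (M,) largest camera-space z per camera
print(near, far)
```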
/data/datasets/ray_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from math import sin, cos, pi
4 | import os
5 | from .utils import campose_to_extrinsic, read_intrinsics
6 | from PIL import Image
7 | import torchvision
8 | import torch.distributions as tdist
9 |
10 | from .frame_dataset import FrameDataset, FrameLayerDataset
11 | from utils import ray_sampling, ray_sampling_label_bbox, lookat, getSphericalPosition, generate_rays, ray_sampling_label_label
12 |
13 | class Ray_Dataset(torch.utils.data.Dataset):
14 |
15 | def __init__(self, cfg, transforms_bkgd, transforms_layer):
16 |
17 | super(Ray_Dataset, self).__init__()
18 |
19 | frame_num=cfg.DATASETS.FRAME_NUM
20 | layer_num=cfg.DATASETS.LAYER_NUM
21 |
22 | frame_offset=cfg.DATASETS.FRAME_OFFSET
23 | bkgd_sample_rate = cfg.DATASETS.BKGD_SAMPLE_RATE
24 |
25 | # [[bkgd_frame1,bkgd_frame2,...,],[layer1_frame_1,layer1_frame2,...,],[layer2_frame1,layer2_frame2,...,],...,]
26 | self.datasets = []
27 | self.bboxes = torch.zeros(frame_num+frame_offset, layer_num, 8, 3)
28 | for layer_id in range(layer_num+1):
29 | datasets_layer = []
30 | for frame_id in range(1+frame_offset,frame_offset+frame_num+1):
31 | if layer_id == 0:
32 | sample_rate=bkgd_sample_rate
33 | use_label_map=True
34 | transform=transforms_bkgd
35 | else:
36 | sample_rate=1
37 | for i in range(len(cfg.DATASETS.FIXED_LAYER)):
38 | if cfg.DATASETS.FIXED_LAYER[i] == layer_id:
39 | sample_rate = 0
40 | use_label_map=cfg.DATASETS.USE_LABEL
41 | transform=transforms_layer
42 | dataset_frame_layer = Ray_Frame_Layer_Dataset(cfg, transform, frame_id, layer_id, use_label_map, sample_rate)
43 | if layer_id != 0:
44 | self.bboxes[frame_id-1, layer_id-1] = dataset_frame_layer.layer_bbox
45 | datasets_layer.append(dataset_frame_layer)
46 | self.datasets.append(datasets_layer)
47 |
48 | self.frame_num = frame_num
49 | self.layer_num = layer_num
50 |
51 | self.bkgd_sample_rate = bkgd_sample_rate
52 |
53 | self.ray_length = np.zeros(layer_num+1)
54 |
55 | for l in range(len(self.datasets)):
56 | layer_datasets = self.datasets[l]
57 | for layer_frame_dataset in layer_datasets:
58 | self.ray_length[l] += len(layer_frame_dataset)
59 |
60 | for l in range(len(self.datasets)):
61 | print('Layer %d has %d rays' % (l, int(self.ray_length[l])))
62 | self.length = int(sum(self.ray_length))
63 |
64 | print('The whole ray number is %d' % self.length)
65 | self.camera_num = self.datasets[0][0].camera_num
66 |
67 |
68 | def __len__(self):
69 |
70 | return self.length
71 |
72 | def __getitem__(self, index):
73 | # if index < self.bkgd_length:
74 | # index = int(index / self.bkgd_sample_rate)
75 | # else:
76 | # index = (index-self.bkgd_length) + self.original_bkgd_length
77 | temp = 0
78 | for layer_datasets in self.datasets:
79 | for layer_frame_dataset in layer_datasets:
80 | if temp + len(layer_frame_dataset) > index:
81 | return layer_frame_dataset[index-temp]
82 | else:
83 | temp += len(layer_frame_dataset)
84 |
85 | class Ray_Dataset_View(torch.utils.data.Dataset):
86 |
87 | def __init__(self, cfg, transform):
88 |
89 | super(Ray_Dataset_View, self).__init__()
90 |
91 | # Save input
92 | self.dataset_path = cfg.DATASETS.TRAIN
93 | self.frame_num = cfg.DATASETS.FRAME_NUM
94 | self.layer_num = cfg.DATASETS.LAYER_NUM
95 | self.frame_offset = cfg.DATASETS.FRAME_OFFSET
96 |
97 | self.pose_refinement = cfg.MODEL.POSE_REFINEMENT
98 | self.use_deform_view = cfg.MODEL.USE_DEFORM_VIEW
99 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME
100 | self.use_space_time = cfg.MODEL.USE_SPACE_TIME
101 | remove_outliers = cfg.MODEL.REMOVE_OUTLIERS
102 |
103 | self.transform = transform
104 |
105 | self.layer_frame_datasets = []
106 | for layer_id in range(self.layer_num+1):
107 | datasets_layer = []
108 | for frame_id in range(1+self.frame_offset,self.frame_offset+self.frame_num+1):
109 | dataset_frame_layer = FrameLayerDataset(cfg, transform, frame_id, layer_id)
110 | datasets_layer.append(dataset_frame_layer)
111 | self.layer_frame_datasets.append(datasets_layer)
112 | self.camera_num = self.layer_frame_datasets[0][0].cam_num
113 |
114 | def __len__(self):
115 | return 1
116 |
117 | def get_fixed_image(self, index_view, index_frame):
118 |
119 | print(index_view)
120 | print(index_frame)
121 |
122 | bboxes = []
123 | K = None
124 | T = None
125 | label = None
126 | image = None
127 | for i in range(self.layer_num+1):
128 |             image_tmp, label_tmp, K_tmp, T_tmp, _, bbox, near_far, _ = self.layer_frame_datasets[i][index_frame].get_data(index_view)
129 | if K is None:
130 | K = K_tmp
131 | if T is None:
132 | T = T_tmp
133 | if label is None:
134 | label = label_tmp
135 | if image is None:
136 | image = image_tmp
137 | bboxes.append(bbox)
138 |
139 | rays, labels, rgbs, ray_mask, layered_bboxes = ray_sampling_label_bbox(image,label,K,T,bboxes=bboxes)
140 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) )
141 | if self.pose_refinement:
142 | rays_o, rays_d = rays[:, :3], rays[:, 3:6]
143 |             ids=torch.ones((rays_o.size(0),1))*index_view
144 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1)
145 |
146 | if self.use_deform_view:
147 |             camera_ids=torch.ones((rays.size(0),1)) * index_view
148 | rays=torch.cat([rays, camera_ids],dim=-1)
149 |
150 | if self.use_deform_time or self.use_space_time:
151 | frame_ids = torch.Tensor([index_frame+self.frame_offset+1]).reshape(1,1).repeat(rays.shape[0],1)
152 | rays=torch.cat([rays, frame_ids],dim=-1)
153 |
154 | return rays, rgbs, labels, image, label, ray_mask, layered_bboxes, near_far.repeat(rays.size(0),1)
155 |
156 | def __getitem__(self, index):
157 |
158 | index_frame = np.random.randint(0,self.frame_num)
159 | index_view = np.random.randint(0,self.camera_num)
160 | _, _, _, _, _, _, _, mask = self.layer_frame_datasets[0][index_frame].get_data(index_view)
161 | while (mask == 0):
162 |
163 | index_view = np.random.randint(0,self.camera_num)
164 | _, _, _, _, _, _, _, mask = self.layer_frame_datasets[0][index_frame].get_data(index_view)
165 |
166 | print(index_view)
167 | print(index_frame)
168 |
169 | bboxes = []
170 | K = None
171 | T = None
172 | label = None
173 | image = None
174 | for i in range(self.layer_num+1):
175 | image_tmp, label_tmp, K_tmp, T_tmp, _, bbox, near_far, _ = self.layer_frame_datasets[i][index_frame].get_data(index_view)
176 | if K is None:
177 | K = K_tmp
178 | if T is None:
179 | T = T_tmp
180 | if label is None:
181 | label = label_tmp
182 | if image is None:
183 | image = image_tmp
184 | bboxes.append(bbox)
185 |
186 | rays, labels, rgbs, ray_mask, layered_bboxes = ray_sampling_label_bbox(image,label,K,T,bboxes=bboxes)
187 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) )
188 | if self.pose_refinement:
189 | rays_o, rays_d = rays[:, :3], rays[:, 3:6]
190 | ids=torch.ones((rays_o.size(0),1))*index
191 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1)
192 |
193 | if self.use_deform_view:
194 | camera_ids=torch.ones((rays.size(0),1)) * index
195 | rays=torch.cat([rays, camera_ids],dim=-1)
196 |
197 | if self.use_deform_time or self.use_space_time:
198 | frame_ids = torch.Tensor([index_frame+self.frame_offset+1]).reshape(1,1).repeat(rays.shape[0],1)
199 | rays=torch.cat([rays, frame_ids],dim=-1)
200 |
201 | return rays, rgbs, labels, image, label, ray_mask, layered_bboxes, near_far.repeat(rays.size(0),1)
202 |
203 | class Ray_Dataset_Render(torch.utils.data.Dataset):
204 |
205 | def __init__(self, cfg, transform):
206 |
207 | super(Ray_Dataset_Render, self).__init__()
208 |
209 | # Save input
210 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME
211 | self.use_space_time = cfg.MODEL.USE_SPACE_TIME
212 |
213 | frame_offset = cfg.DATASETS.FRAME_OFFSET
214 | layer_num = cfg.DATASETS.LAYER_NUM
215 | frame_num = cfg.DATASETS.FRAME_NUM
216 |
217 | self.layer_num = layer_num
218 |
219 | self.datasets = []
220 |
221 | self.bboxes = torch.zeros(frame_num+frame_offset, layer_num, 8, 3)
222 |
223 |
224 | for layer_id in range(layer_num+1):
225 | datasets_layer = []
226 | for frame_id in range(1+frame_offset,frame_offset+frame_num+1):
227 | dataset_frame_layer = FrameLayerDataset(cfg, transform, frame_id, layer_id)
228 | datasets_layer.append(dataset_frame_layer)
229 | if layer_id != 0:
230 | self.bboxes[frame_id-1, layer_id-1] = dataset_frame_layer.bbox
231 | self.datasets.append(datasets_layer)
232 |
233 | self.camera_num = self.datasets[0][0].cam_num
234 | self.poses = self.datasets[0][0].Ts
235 |
236 | # Default layer size is original size
237 | self.Ks = self.datasets[0][0].Ks
238 | col, row = self.datasets[0][0].get_original_size()
239 | self.Ks[:,0,0] = self.Ks[:,0,0] * cfg.INPUT.SIZE_TEST[0] / col
240 | self.Ks[:,1,1] = self.Ks[:,1,1] * cfg.INPUT.SIZE_TEST[0] / col
241 | self.Ks[:,0,2] = self.Ks[:,0,2] * cfg.INPUT.SIZE_TEST[0] / col
242 | self.Ks[:,1,2] = self.Ks[:,1,2] * cfg.INPUT.SIZE_TEST[0] / col
243 |
244 | # Use original image size, intrinsic and bbox
245 | image, _, self.K, _, _, _, _, _ = self.datasets[0][0].get_data(0)
246 |
247 | # for i in range(len(self.datasets[0][0])):
248 | # _, _, K, _, _, _, _, _ = self.datasets[0][0].get_data(i)
249 | # self.Ks.append(K)
250 |
251 |
252 | _, self.height, self.width = image.shape
253 |
254 | self.near_far = torch.Tensor([cfg.DATASETS.FIXED_NEAR,cfg.DATASETS.FIXED_FAR]).reshape(1,2)
255 |
256 | def get_image_label(self, camera_id, frame_id):
257 | image, label, _, _, _, _, _, _ = self.datasets[frame_id][0].get_data(camera_id)
258 | return image, label
259 |
260 | def get_rays_by_pose_and_K(self, T, K, layer_frame_pair):
261 |
262 | T = torch.Tensor(T)
263 | rays, _ = generate_rays(K, T, None, self.height, self.width)
264 |
265 |         #TODO: bbox and near/far are currently unused here
266 | near_fars = self.near_far.repeat(rays.size(0),1)
267 | bboxes = torch.zeros(rays.size(0),8,3)
268 | labels = torch.zeros(rays.size(0))
269 |
270 | # bboxes = []
271 | # for layer_id, frame_id in layer_frame_pair:
272 | # bboxes.append(self.bboxes[frame_id-1,layer_id])
273 |
274 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) )
275 |
276 | if self.use_deform_time or self.use_space_time:
277 | frame_ids = torch.zeros(rays.size(0),self.layer_num+1)
278 | for layer_id, frame_id in layer_frame_pair:
279 | frame_ids[:,layer_id] = frame_id
280 |
281 | rays=torch.cat([rays, frame_ids],dim=-1)
282 |
283 | return rays, labels, bboxes, near_fars
284 | #Use the first K of the dataset by default
285 | def get_rays_by_pose(self, T, layer_frame_pair):
286 |
287 | T = torch.Tensor(T)
288 | rays, _ = generate_rays(self.K, T, None, self.height, self.width)
289 |
290 |         #TODO: bbox and near/far are currently unused here
291 | near_fars = self.near_far.repeat(rays.size(0),1)
292 | bboxes = torch.zeros(rays.size(0),8,3)
293 | labels = torch.zeros(rays.size(0))
294 |
295 | # bboxes = []
296 | # for layer_id, frame_id in layer_frame_pair:
297 | # bboxes.append(self.bboxes[frame_id-1,layer_id])
298 |
299 | # rays,rgbs = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), (image.size(1),image.size(2)), images = image.unsqueeze(0) )
300 |
301 | if self.use_deform_time:
302 | frame_ids = torch.zeros(rays.size(0),self.layer_num+1)
303 | for layer_id, frame_id in layer_frame_pair:
304 | frame_ids[:,layer_id] = frame_id
305 |
306 | rays=torch.cat([rays, frame_ids],dim=-1)
307 |
308 | return rays, labels, bboxes, near_fars
309 |
310 | def get_rays_by_lookat(self,eye,center,up, layer_frame_pair):
311 |
312 | T = torch.Tensor(lookat(eye,center,up))
313 | return self.get_rays_by_pose(T, layer_frame_pair)
314 |
315 | def get_rays_by_spherical(self, theta, phi, radius,offsets, up, layer_frame_pair):
316 | up = np.array(up)
317 | offsets = np.array(offsets)
318 |
319 | pos = getSphericalPosition(radius,theta,phi)
320 | pos += self.center
321 | pos += offsets
322 | T = torch.Tensor(lookat(pos,self.center,up))
323 |
324 | return self.get_rays_by_pose(T, layer_frame_pair)
325 |
326 | def get_pose_by_lookat(self, eye,center,up):
327 | return torch.Tensor(lookat(eye,center,up))
328 |
329 | def get_pose_by_spherical(self, theta, phi, radius, offsets, up):
330 | up = np.array(up)
331 | offsets = np.array(offsets)
332 |
333 | pos = getSphericalPosition(radius,theta,phi)
334 | pos += self.center
335 | pos += offsets
336 | T = torch.Tensor(lookat(pos,self.center,up))
337 | return T
338 |
339 | class Ray_Frame_Layer_Dataset(torch.utils.data.Dataset):
340 |
341 | def __init__(self, cfg, transform, frame_id, layer_id, use_label_map, sample_rate):
342 |
343 | super(Ray_Frame_Layer_Dataset, self).__init__()
344 |
345 |
346 | # Save input
347 | self.dataset_path = cfg.DATASETS.TRAIN
348 | self.tmp_rays = cfg.DATASETS.TMP_RAYS
349 | self.camera_stepsize = cfg.DATASETS.CAMERA_STEPSIZE
350 |
351 | self.pose_refinement = cfg.MODEL.POSE_REFINEMENT
352 | self.use_deform_view = cfg.MODEL.USE_DEFORM_VIEW
353 | self.use_deform_time = cfg.MODEL.USE_DEFORM_TIME
354 | self.use_space_time = cfg.MODEL.USE_SPACE_TIME
355 |
356 | self.transform = transform
357 | self.frame_id = frame_id
358 | self.layer_id = layer_id
359 |
360 | # Generate Frame Dataset
361 | self.frame_dataset = FrameLayerDataset(cfg, transform, frame_id, layer_id)
362 | self.camera_num = self.frame_dataset.cam_num
363 | # Save layered rays, rgbs, labels, bboxs, near_fars
364 | self.layer_rays = []
365 | self.layer_rgbs = []
366 | self.layer_labels = []
367 | if self.frame_dataset.bbox != None:
368 | self.layer_bbox = self.frame_dataset.bbox
369 | else:
370 | self.layer_bbox = torch.zeros(8,3)
371 | self.near_fars = []
372 |
373 | # Check if we already generate rays
374 | tmp_ray_path = os.path.join(self.dataset_path,self.tmp_rays,'frame'+str(frame_id))
375 | if not os.path.exists(tmp_ray_path):
376 |             print('No rays generated before, generating rays...')
377 | os.makedirs(tmp_ray_path)
378 |
379 |         # traverse every camera
380 | tmp_layer_ray_path = os.path.join(tmp_ray_path,'layer'+str(layer_id))
381 | if sample_rate == 0.0:
382 |             print('Skipping layer %d, frame %d rays: sample rate is zero...' % (layer_id, frame_id))
383 | self.layer_rays = torch.tensor([])
384 | self.layer_rgbs = torch.tensor([])
385 | self.layer_labels = torch.tensor([])
386 | self.near_fars = torch.tensor([])
387 | elif not os.path.exists(tmp_layer_ray_path) or cfg.clean_ray:
388 | rays_tmp = []
389 | rgbs_tmp = []
390 | labels_tmp = []
391 | near_fars_tmp = []
392 |             print('No rays generated for layer %d, frame %d before, generating rays...' % (layer_id, frame_id))
393 | for i in range(0,self.frame_dataset.cam_num,self.camera_stepsize):
394 | print('Generating Layer %d, Camera %d rays...'% (layer_id,i))
395 |
396 | image, label, K, T, ROI, bbox, near_far, mask = self.frame_dataset.get_data(i)
397 |
398 | if not mask:
399 |                     print('Skipping camera %d (masked out)' % (i))
400 | continue
401 |
402 | if not use_label_map:
403 | rays, labels, rgbs, _ = ray_sampling_label_bbox(image,label,K,T,bbox)
404 | else:
405 | rays, labels, rgbs, _ = ray_sampling_label_label(image,label,K,T,layer_id)
406 |
407 | if self.pose_refinement:
408 | rays_o, rays_d = rays[:, :3], rays[:, 3:6]
409 | ids=torch.ones((rays_o.size(0),1))*i
410 | rays=torch.cat([rays_o,ids,rays_d,ids],dim = 1)
411 |
412 | if self.use_deform_view:
413 | camera_ids=torch.ones((rays.size(0),1))*i
414 | rays=torch.cat([rays, camera_ids],dim=-1)
415 |
416 | if self.use_deform_time or self.use_space_time:
417 | frame_ids = torch.Tensor([frame_id]).reshape(1,1).repeat(rays.shape[0],1)
418 | rays=torch.cat([rays, frame_ids],dim=-1)
419 |
420 | near_fars_tmp.append(near_far.repeat(rays.size(0),1))
421 | rays_tmp.append(rays)
422 | rgbs_tmp.append(rgbs)
423 | labels_tmp.append(labels)
424 |
425 | self.layer_rays = torch.cat(rays_tmp,0)
426 | self.layer_rgbs = torch.cat(rgbs_tmp,0)
427 | self.layer_labels = torch.cat(labels_tmp,0)
428 | self.near_fars = torch.cat(near_fars_tmp,0)
429 | if sample_rate != 1:
430 | rand_idx = torch.randperm(self.layer_rays.size(0))
431 | self.layer_rays = self.layer_rays[rand_idx]
432 | self.layer_rgbs = self.layer_rgbs[rand_idx]
433 | self.layer_labels = self.layer_labels[rand_idx]
434 | self.near_fars = self.near_fars[rand_idx]
435 | end = int(self.layer_rays.size(0) * sample_rate)
436 | self.layer_rays = self.layer_rays[:end,:].clone().detach()
437 | self.layer_rgbs = self.layer_rgbs[:end,:].clone().detach()
438 | self.layer_labels = self.layer_labels[:end,:].clone().detach()
439 | self.near_fars = self.near_fars[:end,:].clone().detach()
440 | if not os.path.exists(tmp_layer_ray_path):
441 | os.mkdir(tmp_layer_ray_path)
442 | torch.save(self.layer_rays, os.path.join(tmp_layer_ray_path,'rays.pt'))
443 | torch.save(self.layer_rgbs, os.path.join(tmp_layer_ray_path,'rgbs.pt'))
444 | torch.save(self.layer_labels, os.path.join(tmp_layer_ray_path,'labels.pt'))
445 | torch.save(self.near_fars, os.path.join(tmp_layer_ray_path,'near_fars.pt'))
446 | else:
447 |             print('Rays already generated for layer %d, frame %d, loading rays...' % (layer_id, frame_id))
448 | self.layer_rays = torch.load(os.path.join(tmp_layer_ray_path,'rays.pt'),map_location='cpu')
449 | self.layer_rgbs = torch.load(os.path.join(tmp_layer_ray_path,'rgbs.pt'),map_location='cpu')
450 | self.layer_labels = torch.load(os.path.join(tmp_layer_ray_path,'labels.pt'),map_location='cpu')
451 | self.near_fars = torch.load(os.path.join(tmp_layer_ray_path,'near_fars.pt'),map_location='cpu')
452 |
453 | # Fix to the layer id
454 | self.layer_bbox_labels = self.layer_id * torch.ones_like(self.layer_labels)
455 | print('Generating %d rays' % self.layer_rays.shape[0])
456 | def __len__(self):
457 | return self.layer_rays.shape[0]
458 |
459 | def __getitem__(self, index):
460 | return self.layer_rays[index,:], self.layer_rgbs[index,:], self.layer_labels[index,:], self.layer_bbox_labels[index,:], self.layer_bbox[0], self.near_fars[index,:]
--------------------------------------------------------------------------------
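
A hypothetical render-side sketch of `Ray_Dataset_Render` above. The demo scripts are not included in this dump, so the call pattern and the pose, layer and frame values here are assumptions rather than the shipped demo code:

```python
from config import cfg
from data import make_ray_data_loader_render

cfg.merge_from_file("configs/config_taekwondo.yml")
_, render_dataset = make_ray_data_loader_render(cfg, is_train=False)

# Evaluate the background (layer 0) and the two performer layers at frame 10.
layer_frame_pair = [(0, 10), (1, 10), (2, 10)]
rays, labels, bboxes, near_fars = render_dataset.get_rays_by_lookat(
    eye=[0.0, 0.0, 3.0], center=[0.0, 0.0, 0.0], up=[0.0, 1.0, 0.0],
    layer_frame_pair=layer_frame_pair)
print(rays.shape)  # one ray per pixel of INPUT.SIZE_TEST
```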
/data/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 |
5 |
6 | def campose_to_extrinsic(camposes):
7 | if camposes.shape[1]!=12:
8 | raise Exception(" wrong campose data structure!")
9 |
10 | res = np.zeros((camposes.shape[0],4,4))
11 |
12 | res[:,0,:] = camposes[:,0:4]
13 | res[:,1,:] = camposes[:,4:8]
14 | res[:,2,:] = camposes[:,8:12]
15 | res[:,3,3] = 1.0
16 |
17 | return res
18 |
19 |
20 | def read_intrinsics(fn_instrinsic):
21 | fo = open(fn_instrinsic)
22 | data= fo.readlines()
23 | i = 0
24 | Ks = []
25 |     while i < len(data):
26 |         if len(data[i]) > 5:
27 |             tmp = data[i].split()
28 |             a = np.array([float(x) for x in tmp])
29 |             i = i + 1
30 |             tmp = data[i].split()
31 |             b = np.array([float(x) for x in tmp])
32 |             i = i + 1
33 |             tmp = data[i].split()
34 |             c = np.array([float(x) for x in tmp])
35 |             res = np.vstack([a, b, c])
36 |             Ks.append(res)
37 |         i = i + 1
38 | 
39 |     Ks = np.stack(Ks)
40 |     fo.close()
41 | 
42 |     return Ks
43 | 
44 | 
45 | def get_iteration_path(root_dir, fix_iter = -1):
46 |     if fix_iter != -1:
47 |         return os.path.join(root_dir,'frame','layered_rfnr_checkpoint_%d.pt' % fix_iter)
48 | 
49 |     if not os.path.exists(root_dir):
50 |         return None
51 |     file_names = glob.glob(os.path.join(root_dir,'layered_rfnr_checkpoint_*.pt'))
52 |     max_iter = -1
53 |     for file_name in file_names:
54 |         num_name = file_name.split('_')[-1]
55 |         temp = int(num_name.split('.')[0])
56 |         if temp > max_iter:
57 | max_iter = temp
58 | if not os.path.exists(os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter)):
59 | return None
60 | return os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter)
61 |
62 | def get_iteration_path_and_iter(root_dir, fix_iter = -1):
63 | if fix_iter != -1:
64 | return os.path.join(root_dir,'frame','layered_rfnr_checkpoint_%d.pt' % fix_iter)
65 |
66 | if not os.path.exists(root_dir):
67 | return None
68 | file_names = glob.glob(os.path.join(root_dir,'layered_rfnr_checkpoint_*.pt'))
69 | max_iter = -1
70 | for file_name in file_names:
71 | num_name = file_name.split('_')[-1]
72 | temp = int(num_name.split('.')[0])
73 | if temp > max_iter:
74 | max_iter = temp
75 | if not os.path.exists(os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter)):
76 | return None
77 | return os.path.join(root_dir,'layered_rfnr_checkpoint_%d.pt' % max_iter), max_iter
78 |
79 |
80 | def read_mask(path):
81 | fo = open(path)
82 | data= fo.readlines()
83 | mask = []
84 | for i in range(len(data)):
85 | tmp = int(data[i])
86 | mask.append(tmp)
87 | mask = np.array(mask)
88 | fo.close()
89 |
90 | return mask
--------------------------------------------------------------------------------
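
A quick check of `campose_to_extrinsic` above on a single pose; the 12 numbers per row are the top three rows of a 4x4 camera-to-world matrix, and the values here are made up:

```python
import numpy as np
from data.datasets.utils import campose_to_extrinsic

campose = np.array([[1., 0., 0., 0.,
                     0., 1., 0., 0.,
                     0., 0., 1., 0.]])    # (1,12): identity rotation, zero translation
T = campose_to_extrinsic(campose)         # (1,4,4); the bottom row becomes [0,0,0,1]
print(T[0])
```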
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: Minye Wu
4 | @GITHUB: wuminye
5 | """
6 |
7 | from .build import build_transforms,build_layered_transforms
8 |
--------------------------------------------------------------------------------
/data/transforms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/build.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/build.cpython-38.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/random_transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/random_transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/random_transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/data/transforms/__pycache__/random_transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torchvision.transforms as T
8 |
9 | from .random_transforms import Random_Transforms
10 |
11 |
12 | def build_transforms(cfg, is_train=True):
13 | normalize_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
14 |
15 | if is_train:
16 |
17 | transform = Random_Transforms((cfg.INPUT.SIZE_TRAIN[1], cfg.INPUT.SIZE_TRAIN[0]),cfg.DATASETS.SHIFT, cfg.DATASETS.MAXRATION,cfg.DATASETS.ROTATION)
18 | #transform = T.Compose([
19 | # T.Resize((cfg.INPUT.SIZE_TRAIN[1], cfg.INPUT.SIZE_TRAIN[0])),
20 | # T.ToTensor()
21 | #])
22 | else:
23 | transform = Random_Transforms((cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0]),0, isTrain = is_train)
24 | #transform = T.Compose([
25 | # T.Resize((cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0])),
26 | # T.ToTensor()
27 | #])
28 |
29 |
30 | return transform
31 |
32 | def build_layered_transforms(cfg, is_layer=True, is_train=True):
33 | normalize_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
34 |
35 | if is_train:
36 | if is_layer:
37 | transform = Random_Transforms((cfg.INPUT.SIZE_LAYER[1], cfg.INPUT.SIZE_LAYER[0]),cfg.DATASETS.SHIFT, cfg.DATASETS.MAXRATION,cfg.DATASETS.ROTATION)
38 | else:
39 | transform = Random_Transforms((cfg.INPUT.SIZE_TRAIN[1], cfg.INPUT.SIZE_TRAIN[0]),cfg.DATASETS.SHIFT, cfg.DATASETS.MAXRATION,cfg.DATASETS.ROTATION)
40 | else:
41 | transform = Random_Transforms((cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0]),0)
42 | #transform = T.Compose([
43 | # T.Resize((cfg.INPUT.SIZE_TEST[1], cfg.INPUT.SIZE_TEST[0])),
44 | # T.ToTensor()
45 | #])
46 |
47 |
48 | return transform
--------------------------------------------------------------------------------
/data/transforms/random_transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torchvision.transforms as T
8 |
9 | from torch.utils import data
10 | import torch
11 |
12 | import numpy as np
13 |
14 | import random
15 | import PIL
16 | from PIL import Image
17 | import collections
18 | import math
19 | '''
20 | INPUT: mask is a (h,w) numpy array;
21 | every pixel greater than 0 is counted.
22 | '''
23 | def calc_center(mask):
24 | grid = np.mgrid[0:mask.shape[0],0:mask.shape[1]]
25 | grid_mask = mask[grid[0],grid[1]].astype(bool)
26 | X = grid[0,grid_mask]
27 | Y = grid[1,grid_mask]
28 |
29 | return np.mean(X),np.mean(Y)
30 |
31 |
32 | def rodrigues_rotation_matrix(axis, theta):
33 | axis = np.asarray(axis)
34 | theta = np.asarray(theta)
35 | axis = axis/math.sqrt(np.dot(axis, axis))
36 | a = math.cos(theta/2.0)
37 | b, c, d = -axis*math.sin(theta/2.0)
38 | aa, bb, cc, dd = a*a, b*b, c*c, d*d
39 | bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d
40 | return np.array([[aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)],
41 | [2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)],
42 | [2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]])
43 |
44 |
45 | class Random_Transforms(object):
46 | def __init__(self, size, random_range = 0, random_ration = 0, random_rotation = 0,interpolation=Image.BICUBIC, isTrain = True, is_center = False):
47 | assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
48 | self.size = size
49 | self.interpolation = interpolation
50 | self.random_range = random_range
51 | self.random_scale = random_ration
52 | self.isTrain = isTrain
53 | self.random_rotation = random_rotation
54 | self.is_center = is_center
55 |
56 | def __call__(self, img, Ks = None, Ts = None, mask = None, label = None):
57 |
58 | K = Ks.clone()
59 | Tc = Ts.clone()
60 | img_np = np.asarray(img)
61 |
62 | offset = random.randint(-self.random_range,self.random_range)
63 | offset2 = random.randint(-self.random_range,self.random_range)
64 |
65 | rotation = (random.random()-0.5)*np.deg2rad(self.random_rotation)
66 | ration = random.random()*self.random_scale + 1.0
67 |
68 | width, height = img.size
69 |
70 | R = torch.Tensor(rodrigues_rotation_matrix(np.array([0,0,1]),rotation))
71 |
72 |
73 | Tc[0:3,0:3] = torch.matmul(Tc[0:3,0:3],R)
74 |
75 | m_scale = height/self.size[0]
76 |
77 | cx, cy = 0, 0
78 |
79 | if mask is not None and self.isTrain:
80 | mask_np = np.asarray(mask)
81 | if mask_np.ndim == 3:
82 | mask_np = mask_np[:,:,0]
83 | cy, cx = calc_center(mask_np)
84 |
85 | cx = cx - width /2
86 | cy = cy - height/2
87 |
88 |
89 |
90 | translation = (offset*m_scale-cx,offset2*m_scale-cy )
91 |
92 | if self.is_center:
93 | translation = [width /2-K[0,2],height/2-K[1,2]]
94 | translation = list(translation)
95 | ration = 1.05
96 |
97 | if (self.size[1]/2)/(self.size[0]*ration / height) - K[0,2] != translation[0] :
98 | ration = 1.2
99 | translation[1] = (self.size[0]/2)/(self.size[0]*ration / height) - K[1,2]
100 | translation[0] = (self.size[1]/2)/(self.size[0]*ration / height) - K[0,2]
101 | translation = tuple(translation)
102 |
103 | #translation = (width /2-K[0,2],height/2-K[1,2])
104 |
105 |
106 | img = T.functional.rotate(img, angle = np.rad2deg(rotation), resample = Image.BICUBIC, center =(K[0,2],K[1,2]))
107 | img = T.functional.affine(img, angle = 0, translate = translation, scale= 1,shear=0)
108 | img = T.functional.crop(img, 0, 0, int(height/ration),int(height*self.size[1]/ration/self.size[0]) )
109 | img = T.functional.resize(img, self.size, self.interpolation)
110 | img = T.functional.to_tensor(img)
111 |
112 |
113 | ROI = np.ones_like(img_np)*255.0
114 |
115 | ROI = Image.fromarray(np.uint8(ROI))
116 | ROI = T.functional.rotate(ROI, angle = np.rad2deg(rotation), resample = Image.BICUBIC, center =(K[0,2],K[1,2]))
117 | ROI = T.functional.affine(ROI, angle = 0, translate = translation, scale= 1,shear=0)
118 | ROI = T.functional.crop(ROI, 0,0, int(height/ration),int(height*self.size[1]/ration/self.size[0]) )
119 | ROI = T.functional.resize(ROI, self.size, self.interpolation)
120 | ROI = T.functional.to_tensor(ROI)
121 | ROI = ROI[0:1,:,:]
122 |
123 |
124 |
125 | if mask is not None:
126 | mask = T.functional.rotate(mask, angle = np.rad2deg(rotation), resample = Image.BICUBIC, center =(K[0,2],K[1,2]))
127 | mask = T.functional.affine(mask, angle = 0, translate = translation, scale= 1,shear=0)
128 | mask = T.functional.crop(mask, 0, 0, int(height/ration),int(height*self.size[1]/ration/self.size[0]) )
129 | mask = T.functional.resize(mask, self.size, self.interpolation)
130 | mask = T.functional.to_tensor(mask)
131 |
132 |
133 | if label is not None:
134 | label = Image.fromarray(np.uint8(label))
135 | label = T.functional.rotate(label, angle = np.rad2deg(rotation), resample = Image.BICUBIC, center =(K[0,2],K[1,2]))
136 | label = T.functional.affine(label, angle = 0, translate = translation, scale= 1,shear=0)
137 | label = T.functional.crop(label, 0,0, int(height/ration),int(height*self.size[1]/ration/self.size[0]) )
138 | label = T.functional.resize(label, self.size, self.interpolation)
139 | label = T.functional.to_tensor(label)
140 | label = label * 255.0
141 |
142 |
143 |
144 |
145 | #K = K / m_scale
146 | #K[2,2] = 1
147 |
148 |
149 | K[0,2] = K[0,2] + translation[0]
150 | K[1,2] = K[1,2] + translation[1]
151 |
152 | s = self.size[0] * ration / height
153 |
154 | K = K*s
155 |
156 | K[2,2] = 1
157 | #print(img.size(),mask.size(),ROI.size())
158 |
159 |
160 | if label is None:
161 | return img, K, Tc, mask, ROI
162 | else:
163 | return img, label, K, Tc, ROI
164 |
165 | def __repr__(self):
166 | return self.__class__.__name__ + '()'
167 |
168 |
169 |
170 |
171 |
172 |
--------------------------------------------------------------------------------
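Two quick sanity checks of the geometry helpers above; the values are illustrative, and the import path assumes the repository root is on `sys.path`.

```python
import numpy as np
from data.transforms.random_transforms import rodrigues_rotation_matrix, calc_center

# A +90 degree rotation about the z axis should map the x axis onto the y axis.
R = rodrigues_rotation_matrix(np.array([0.0, 0.0, 1.0]), np.pi / 2)
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]), atol=1e-6)

# calc_center returns the (row, column) centroid of the non-zero pixels of a mask.
mask = np.zeros((4, 4))
mask[1:3, 1:3] = 1
print(calc_center(mask))   # (1.5, 1.5)
```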
/demo/taekwondo_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
4 | import sys
5 | from os import mkdir
6 | import shutil
7 | import torch
8 | import torch.nn.functional as F
9 | import random
10 | from torchvision import utils as vutils
11 | import numpy as np
12 | import imageio
13 | import matplotlib.pyplot as plt
14 |
15 | sys.path.append('.')
16 | from config import cfg
17 | from engine.layered_trainer import do_train
18 | from solver import make_optimizer, WarmupMultiStepLR,build_scheduler
19 | from layers import make_loss
20 | from utils.logger import setup_logger
21 | from layers.RaySamplePoint import RaySamplePoint
22 | from utils import batchify_ray, vis_density
23 | from render import LayeredNeuralRenderer
24 |
25 | text = 'This is the program to render the nerf for specific frame and layer ids; use -h for help.'
26 | parser = argparse.ArgumentParser(description=text)
27 | parser.add_argument('-c', '--config', default='', help='set the config file path to render the network')
28 | parser.add_argument('-g','--gpu', type=int, default=0, help='set gpu id to render the network')
29 | args = parser.parse_args()
30 |
31 | torch.cuda.set_device(args.gpu)
32 | torch.autograd.set_detect_anomaly(True)
33 | torch.set_default_dtype(torch.float32)
34 |
35 |
36 | cfg.merge_from_file(args.config)
37 | cfg.freeze()
38 |
39 | neural_renderer = LayeredNeuralRenderer(cfg)
40 |
41 | key_frames_layer_1 = [21,49,74,87] # performer 1 time line
42 | key_frames_layer_2 = [13,42,80,90] # performer 2 time line
43 | key_frames = [20,50,74,85] # new time line
44 | density_threshold = 0 # Can be set higher to hide the glass
45 | inverse_y_axis = False # For models with an inverted y axis
46 | neural_renderer = LayeredNeuralRenderer(cfg)
47 | neural_renderer.set_save_dir('origin')
48 | neural_renderer.retime_by_key_frames(1, key_frames_layer_1, key_frames)
49 | neural_renderer.retime_by_key_frames(2, key_frames_layer_2, key_frames)
50 | neural_renderer.set_fps(25)
51 | neural_renderer.set_smooth_path_poses(101, around=False)
52 | neural_renderer.render_path(inverse_y_axis,density_threshold,auto_save=True)
53 | neural_renderer.save_video()
54 |
55 | neural_renderer = LayeredNeuralRenderer(cfg, shift=[[0,0,0],[0,2,0],[0,-2,0]])
56 | neural_renderer.set_save_dir('shift')
57 | neural_renderer.retime_by_key_frames(1, key_frames_layer_1, key_frames)
58 | neural_renderer.retime_by_key_frames(2, key_frames_layer_2, key_frames)
59 | neural_renderer.set_fps(25)
60 | neural_renderer.set_smooth_path_poses(101, around=False)
61 | neural_renderer.render_path(inverse_y_axis,density_threshold,auto_save=True)
62 | neural_renderer.save_video()
63 |
64 |
65 | neural_renderer = LayeredNeuralRenderer(cfg, scale=[1,0.75,1.5])
66 | neural_renderer.set_save_dir('scale')
67 | neural_renderer.retime_by_key_frames(1, key_frames_layer_1, key_frames)
68 | neural_renderer.retime_by_key_frames(2, key_frames_layer_2, key_frames)
69 | neural_renderer.set_fps(25)
70 | neural_renderer.set_smooth_path_poses(101, around=False)
71 | neural_renderer.render_path(inverse_y_axis,density_threshold,auto_save=True)
72 | neural_renderer.save_video()
73 |
74 |
--------------------------------------------------------------------------------
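The demo above is driven entirely by its config file. A typical invocation, assuming the taekwondo dataset and a trained checkpoint are in the locations the config expects (the `-c/--config` and `-g/--gpu` flags are the ones defined by the argparse block above), might look like:

```bash
python demo/taekwondo_demo.py --config configs/config_taekwondo.yml --gpu 0
```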
/demo/walking_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
4 | import sys
5 | from os import mkdir
6 | import shutil
7 | import torch
8 | import torch.nn.functional as F
9 | import random
10 | from torchvision import utils as vutils
11 | import numpy as np
12 | import imageio
13 | import matplotlib.pyplot as plt
14 |
15 | sys.path.append('..')
16 | from config import cfg
17 | from engine.layered_trainer import do_train
18 | from modeling import build_model
19 | from solver import make_optimizer, WarmupMultiStepLR,build_scheduler
20 | from layers import make_loss
21 | from utils.logger import setup_logger
22 | from layers.RaySamplePoint import RaySamplePoint
23 | from utils import batchify_ray, vis_density
24 | from render import LayeredNeuralRenderer
25 |
26 | text = 'This is the program to render the nerf for specific frame and layer ids; use -h for help.'
27 | parser = argparse.ArgumentParser(description=text)
28 | parser.add_argument('-c', '--config', default='', help='set the config file path to render the network')
29 | parser.add_argument('-g','--gpu', type=int, default=0, help='set gpu id to render the network')
30 | args = parser.parse_args()
31 |
32 | torch.cuda.set_device(args.gpu)
33 | torch.autograd.set_detect_anomaly(True)
34 | torch.set_default_dtype(torch.float32)
35 |
36 |
37 | cfg.merge_from_file(args.config)
38 | cfg.freeze()
39 |
40 | neural_renderer = LayeredNeuralRenderer(cfg)
41 |
42 |
43 | density_threshold = 20 # Can be set higher to hide the glass
44 | bkgd_density_threshold = 0.8
45 | inverse_y_axis = False # For models with an inverted y axis
46 |
47 | neural_renderer.set_fps(25)
48 | neural_renderer.set_pose_duration(1,14) # [ min , max )
49 | neural_renderer.set_smooth_path_poses(100, around=False)
50 | neural_renderer.set_near(4)
51 | neural_renderer.invert_poses()
52 |
53 |
54 | neural_renderer.set_save_dir("origin")
55 | neural_renderer.render_path(inverse_y_axis,density_threshold,bkgd_density_threshold,auto_save=True)
56 | neural_renderer.save_video()
57 |
58 |
59 | neural_renderer.hide_layer(1)
60 | neural_renderer.set_save_dir("hide_man_1")
61 | neural_renderer.render_path(inverse_y_axis,density_threshold,bkgd_density_threshold,auto_save=True)
62 | neural_renderer.save_video()
63 |
64 |
65 | neural_renderer.hide_layer(2)
66 | neural_renderer.set_save_dir("hide_both")
67 | neural_renderer.render_path(inverse_y_axis,density_threshold,bkgd_density_threshold,auto_save=True)
68 | neural_renderer.save_video()
69 |
--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .render import render
--------------------------------------------------------------------------------
/engine/layered_trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import imageio
9 | import torch
10 |
11 | from utils import layered_batchify_ray, vis_density, metrics
12 | from utils.metrics import *
13 | import numpy as np
14 | import os
15 | import time
16 |
17 | def evaluator(val_dataset, model, loss_fn, swriter, epoch):
18 |
19 | model.eval()
20 | rays, rgbs, labels, image, label, mask, bbox, near_far = val_dataset[0]
21 |
22 | rays = rays.cuda()
23 | rgbs = rgbs.cuda()
24 | bbox = bbox.cuda()
25 | labels = labels.cuda()
26 | color_gt = image.cuda()
27 | mask = mask.cuda()
28 | near_far = near_far.cuda()
29 |
30 | # uv_list = (mask).squeeze().nonzero()
31 | # u_list = uv_list[:,0]
32 | # v_list = uv_list[:,1]
33 |
34 | with torch.no_grad():
35 | # TODO: Use mask to gain less query of space
36 | stage2, stage1, stage2_layer, stage1_layer, _ = layered_batchify_ray(model, rays, labels, bbox, near_far=near_far)
37 | for i in range(len(stage2_layer)):
38 | color_1 = stage2_layer[i][0]
39 | depth_1 = stage2_layer[i][1]
40 | acc_map_1 = stage2_layer[i][2]
41 | #print(color_1.shape)
42 | #print(depth_1.shape)
43 | #print(acc_map_1.shape)
44 |
45 | color_0 = stage1_layer[i][0]
46 | depth_0 = stage1_layer[i][1]
47 | acc_map_00 = stage1_layer[i][2]
48 |
49 |
50 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1)
51 | depth_img = depth_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
52 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min())
53 | acc_map = acc_map_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
54 |
55 |
56 | color_img_0 = color_0.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1)
57 | depth_img_0 = depth_0.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
58 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min())
59 | acc_map_0 = acc_map_00.reshape((color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
60 |
61 |
62 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min())
63 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min())
64 |
65 | color_img = color_img*((mask).permute(2,0,1).repeat(3,1,1))
66 | color_gt = color_gt*((mask).permute(2,0,1).repeat(3,1,1))
67 |
68 | if i == 0:
69 | swriter.add_image('stage2_bkgd/rendered', color_img, epoch)
70 | swriter.add_image('stage2_bkgd/depth', depth_img, epoch)
71 | swriter.add_image('stage2_bkgd/alpha', acc_map, epoch)
72 |
73 | swriter.add_image('stage1_bkgd/rendered', color_img_0, epoch)
74 | swriter.add_image('stage1_bkgd/depth', depth_img_0, epoch)
75 | swriter.add_image('stage1_bkgd/alpha', acc_map_0, epoch)
76 |
77 | else:
78 | swriter.add_image('stage2_layer' +str(i)+ '/rendered', color_img, epoch)
79 | swriter.add_image('stage2_layer' +str(i)+ '/depth', depth_img, epoch)
80 | swriter.add_image('stage2_layer' +str(i)+ '/alpha', acc_map, epoch)
81 |
82 | swriter.add_image('stage1_layer' +str(i)+ '/rendered', color_img_0, epoch)
83 | swriter.add_image('stage1_layer' +str(i)+ '/depth', depth_img_0, epoch)
84 | swriter.add_image('stage1_layer' +str(i)+ '/alpha', acc_map_0, epoch)
85 |
86 |
87 | color_1 = stage2[0]
88 | depth_1 = stage2[1]
89 | acc_map_1 = stage2[2]
90 | #print(color_1.shape)
91 | #print(depth_1.shape)
92 | #print(acc_map_1.shape)
93 |
94 | color_0 = stage1[0]
95 | depth_0 = stage1[1]
96 | acc_map_00 = stage1[2]
97 |
98 |
99 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1)
100 | depth_img = depth_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
101 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min())
102 | acc_map = acc_map_1.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
103 |
104 |
105 | color_img_0 = color_0.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1)
106 | depth_img_0 = depth_0.reshape( (color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
107 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min())
108 | acc_map_0 = acc_map_00.reshape((color_gt.size(1), color_gt.size(2), 1) ).permute(2,0,1)
109 |
110 |
111 | depth_img = (depth_img-depth_img.min())/(depth_img.max()-depth_img.min())
112 | depth_img_0 = (depth_img_0-depth_img_0.min())/(depth_img_0.max()-depth_img_0.min())
113 |
114 | color_img = color_img*((mask).permute(2,0,1).repeat(3,1,1))
115 | color_gt = color_gt*((mask).permute(2,0,1).repeat(3,1,1))
116 |
117 |
118 | swriter.add_image('GT/Label', label * 50, epoch)
119 | swriter.add_image('GT/Image', color_gt, epoch)
120 |
121 | swriter.add_image('stage2/rendered', color_img, epoch)
122 | swriter.add_image('stage2/depth', depth_img, epoch)
123 | swriter.add_image('stage2/alpha', acc_map, epoch)
124 |
125 | swriter.add_image('stage1/rendered', color_img_0, epoch)
126 | swriter.add_image('stage1/depth', depth_img_0, epoch)
127 | swriter.add_image('stage1/alpha', acc_map_0, epoch)
128 |
129 |
130 | return loss_fn(color_img, color_gt).item()
131 |
132 |
133 | def do_train(
134 | cfg,
135 | model,
136 | train_loader,
137 | val_loader,
138 | optimizer,
139 | scheduler,
140 | loss_fn,
141 | swriter,
142 | resume_epoch = 0,
143 | psnr_thres = 100
144 | ):
145 | log_period = cfg.SOLVER.LOG_PERIOD
146 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
147 | output_dir = cfg.OUTPUT_DIR
148 | max_epochs = cfg.SOLVER.MAX_EPOCHS
149 | train_by_pointcloud = cfg.MODEL.TRAIN_BY_POINTCLOUD
150 | use_label = cfg.DATASETS.USE_LABEL
151 | coarse_stage = cfg.SOLVER.COARSE_STAGE
152 | remove_outliers = cfg.MODEL.REMOVE_OUTLIERS
153 |
154 |
155 | logger = logging.getLogger("LayeredRFRender.%s.train" % cfg.OUTPUT_DIR.split('/')[-1])
156 | logger.info("Start training")
157 | #global step
158 | global_step = 0
159 |
160 | torch.autograd.set_detect_anomaly(True)
161 |
162 |
163 | for epoch in range(1+resume_epoch,max_epochs):
164 | print('Training Epoch %d...' % epoch)
165 | model.cuda()
166 |
167 | #psnr monitor
168 | psnr_monitor = []
169 |
170 | # epoch time recording
171 | epoch_start = time.time()
172 | for batch_idx, batch in enumerate(train_loader):
173 |
174 | #iteration time recording
175 | iters_start = time.time()
176 | global_step = (epoch -1) * len(train_loader) + batch_idx
177 |
178 | model.train()
179 | optimizer.zero_grad()
180 |
181 | rays, rgbs, labels, bbox_labels, bboxes, near_far = batch
182 | bbox_labels = bbox_labels.cuda()
183 | labels = labels.cuda()
184 | rays = rays.cuda()
185 | rgbs = rgbs.cuda()
186 | bboxes = bboxes.cuda()
187 | near_far = near_far.cuda()
188 |
189 | loss = 0
190 |
[... lines 191-198 lost in extraction: an epoch/coarse_stage conditional around the layered model forward pass that fills stage2, stage1, stage2_layer and stage1_layer, followed by commented-out inlier/outlier debug prints ...]
199 | # print(torch.sum(outliers))
200 | # print(torch.sum(inliers))
201 | # out_threshold = 0.5
202 |
203 | predict_rgb_0 = stage1[0]
204 | predict_rgb_1 = stage2[0]
205 |
206 | # predict_rgb_0 = stage1_layer[0][labels.repeat(1,3) != 0]
207 | # predict_rgb_1 = stage2_layer[0][labels.repeat(1,3) != 0]
208 | # rgbs = rgbs[labels.repeat(1,3) != 0]
209 |
210 | # print('ray number is %d' % torch.sum(labels != 0))
211 |
212 |
213 | # print('layer ray number is %d, bbox layer ray number is %d, outlier number is %d, total number is %d' % (torch.sum(labels != 0), torch.sum(bbox_labels != 0), outliers_1.shape[0], rays.size(0)))
214 |
215 |
216 | loss1 = loss_fn(predict_rgb_0, rgbs)
217 | loss2 = loss_fn(predict_rgb_1, rgbs)
218 | if epoch < 3 and remove_outliers:
219 | outliers_1 = []
220 | outliers_2 = []
221 | inliers_1 =[]
222 | inliers_2 = []
223 | for i in range(len(stage1_layer)):
224 |
225 | if i != 0: #i!=3 for spiderman basket
226 | outliers_1.append(stage1_layer[i][2][labels == 0])
227 | outliers_2.append(stage2_layer[i][2][labels == 0])
228 | # else:
229 | # outliers_1.append(stage1_layer[i][2][labels == 0])
230 | # outliers_2.append(stage2_layer[i][2][labels == 0])
231 | inliers_1.append(stage1_layer[i][2][labels == i])
232 | inliers_2.append(stage2_layer[i][2][labels == i])
233 |
234 | if outliers_1 != []:
235 | outliers_1 = torch.cat(outliers_1,0)
236 | outliers_2 = torch.cat(outliers_2,0)
237 | inliers_1 = torch.cat(inliers_1,0)
238 | inliers_2 = torch.cat(inliers_2,0)
239 | # print('total ray number is ', stage2[1].shape, ', the inliers number is ',predict_rgb_1.shape)
240 | # loss1 = loss_fn(predict_rgb_0, rgbs)
241 | # loss2 = loss_fn(predict_rgb_1, rgbs)
242 |
243 | #TODO: 100000 should be adapted
244 | scalar_max = 100000
245 | scalar = scalar_max
246 | # a penalty of 100 makes the mask smaller; 20 works better, try 10
247 | penalty = 1
248 | if outliers_1 != []:
249 | loss_mask_0 = torch.sum(torch.abs(outliers_1)) * penalty + torch.sum(torch.abs(1-inliers_1))
250 | loss_mask_1 = torch.sum(torch.abs(outliers_2)) * penalty + torch.sum(torch.abs(1-inliers_2))
251 | else:
252 | loss_mask_0 = torch.sum(torch.abs(1-inliers_1))
253 | loss_mask_1 = torch.sum(torch.abs(1-inliers_2))
254 |
255 | # while loss_mask_1 / scalar < rays.shape[0]/(scalar_max * 2) and loss_mask_1 > 1:
256 | # scalar /= 2
257 | # if scalar <= 1:
258 | # scalar = 1.0
259 | # break
260 |
261 | # num_ray_mask = torch.sum(ray_mask.view(1,-1)).item()
262 | # print('This batch has %d rays in bbox' % num_ray_mask)
263 |
264 | if loss_mask_0 > rays.shape[0] * 0.0005 and remove_outliers:
265 | loss_mask_0 = loss_mask_0 / scalar
266 | else:
267 | loss_mask_0 = torch.Tensor([0]).cuda()
268 |
269 | if loss_mask_1 > rays.shape[0] * 0.0005 and remove_outliers:
270 | loss_mask_1 = loss_mask_1 / scalar
271 | else:
272 | loss_mask_1 = torch.Tensor([0]).cuda()
273 | else:
274 | loss_mask_0 = torch.Tensor([0]).cuda()
275 | loss_mask_1 = torch.Tensor([0]).cuda()
276 |
277 |
[... lines 278-326 lost in extraction: the total loss assembly and backward/optimizer/scheduler steps, per-iteration logging, PSNR monitoring and periodic checkpointing ...]
327 | if psnr_monitor > psnr_thres:
328 | logger.info("The Mean Psnr of Epoch: {:.3f}, greater than threshold: {:.3f}, Training Stopped".format(psnr_monitor, psnr_thres))
329 | break
330 | else:
331 | logger.info("The Mean Psnr of Epoch: {:.3f}, less than threshold: {:.3f}, Continue to Training".format(psnr_monitor, psnr_thres))
332 |
333 | def val_vis(val_loader,model ,loss_fn, swriter, logger, epoch):
334 |
335 |
336 | avg_loss = evaluator(val_loader, model, loss_fn, swriter,epoch)
337 | logger.info("Validation Results - Epoch: {} Avg Loss: {:.3f}"
338 | .format(epoch, avg_loss)
339 | )
340 | swriter.add_scalar('Loss/val_loss',avg_loss, epoch)
341 |
342 | def ModelCheckpoint(model, optimizer, scheduler, output_dir, epoch, global_step = 0):
343 | # model,optimizer,scheduler saving
344 | if not os.path.exists(output_dir):
345 | os.makedirs(output_dir)
346 | if global_step == 0:
347 | torch.save({'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict()},
348 | os.path.join(output_dir,'layered_rfnr_checkpoint_%d.pt' % epoch))
349 | else:
350 | torch.save({'model':model.state_dict(),'optimizer':optimizer.state_dict(),'scheduler':scheduler.state_dict()},
351 | os.path.join(output_dir,'layered_rfnr_checkpoint_%d_%d.pt' % (epoch,global_step)))
352 | # torch.save(model.state_dict(), os.path.join(output_dir, 'spacenet_epoch_%d.pth'%epoch))
353 | # torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer_epoch_%d.pth'%epoch))
354 | # torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler_epoch_%d.pth'%epoch))
355 |
356 |
357 | def do_evaluate( model,val_dataset):
358 | mae_list = []
359 | psnr_list = []
360 | ssim_list = []
361 |
362 |
363 | model.eval()
364 | with torch.no_grad():
365 | for i in range(2):
366 | for j in range(50):
367 | rays, rgbs, labels, image, label, mask, bbox, near_far = val_dataset.get_fixed_image(i,j+1)
368 |
369 | rays = rays.cuda()
370 | rgbs = rgbs.cuda()
371 | bbox = bbox.cuda()
372 | labels = labels.cuda()
373 | color_gt = image.cuda()
374 | mask = mask.cuda()
375 | near_far = near_far.cuda()
376 |
377 | # uv_list = (mask).squeeze().nonzero()
378 | # u_list = uv_list[:,0]
379 | # v_list = uv_list[:,1]
380 |
381 |
382 | # TODO: Use mask to gain less query of space
383 | stage2, _, _, _, _ = layered_batchify_ray(model, rays, labels, bbox, near_far=near_far)
384 |
385 | color_1 = stage2[0]
386 | depth_1 = stage2[1]
387 | acc_map_1 = stage2[2]
388 | #print(color_1.shape)
389 | #print(depth_1.shape)
390 | #print(acc_map_1.shape)
391 |
392 |
393 | color_img = color_1.reshape( (color_gt.size(1), color_gt.size(2), 3) ).permute(2,0,1)
394 |
395 | mae = metrics.mae(color_img,color_gt)
396 | psnr = metrics.psnr(color_img,color_gt)
397 | ssim = metrics.ssim(color_img,color_gt)
398 | print(color_img.shape)
399 | print(color_gt.shape)
400 |
401 | #imageio.imwrite("/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/walking/"+str(j+1)+".png", color_img.transpose(0,2).transpose(0,1).cpu())
402 |
403 |
404 | print("mae:",mae)
405 | print("psnr:",psnr)
406 | print("ssim:",ssim)
407 | mae_list.append(mae)
408 | psnr_list.append(psnr)
409 | ssim_list.append(ssim)
410 | mae_list = np.array(mae_list)
411 | psnr_list = np.array(psnr_list)
412 | ssim_list = np.array(ssim_list)
413 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/mae.out',mae_list)
414 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/psnr.out',psnr_list)
415 | np.savetxt('/new_disk/zhangjk/NeuralVolumeRender-dynamic/evaluation/complete/ssim.out',ssim_list)
416 | avg_mae = np.mean(np.array(mae_list))
417 | avg_psnr = np.mean(np.array(psnr_list))
418 | avg_ssim = np.mean(np.array(ssim_list))
419 | print("avg_mae:",avg_mae)
420 | print("avg_psnr:",avg_psnr)
421 | print("avg_ssim:",avg_ssim)
422 | #print(color_1.shape)
423 | #print(color_gt.shape)
424 | #print(metrics.psnr(color_img, color_gt))
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
--------------------------------------------------------------------------------
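As the tail of do_train above shows, training stops early once the mean PSNR over an epoch exceeds psnr_thres. For orientation only, the textbook PSNR of an MSE value for images normalised to [0, 1] is -10·log10(MSE); this is a minimal sketch of that relationship, not necessarily exactly what utils.metrics implements.

```python
import torch

def mse_to_psnr(mse: torch.Tensor) -> torch.Tensor:
    # Textbook PSNR for signals normalised to [0, 1].
    return -10.0 * torch.log10(mse)

print(mse_to_psnr(torch.tensor(1e-3)))   # roughly 30 dB
```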
/engine/render.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import batchify_ray, vis_density, ray_sampling
4 | import numpy as np
5 | import os
6 | import torch
7 |
8 |
9 | '''
10 | Sample rays from views (and images) with/without masks
11 |
12 | --------------------------
13 | INPUT Tensors
14 | K: intrinsics of camera (3,3)
15 | T: extrinsic of camera (4,4)
16 | image_size: the size of image [H,W]
17 |
18 | ROI: 2D ROI bboxes (4) left up corner(x,y) followed the height and width (h,w)
19 |
20 | masks:(M,H,W)
21 | -------------------
22 | OUTPUT:
23 | list of rays: (N,6) dirs(3) + pos(3)
24 | RGB: (N,C)
25 | '''
26 |
27 |
28 |
29 |
30 | def render(model, K,T,img_size,ROI = None, bboxes = None,only_coarse = False,near_far=None):
31 | model.eval()
32 | assert not (bboxes is None and near_far is None), 'either bboxes or near_far must be provided.'
33 | mask = torch.ones(img_size[0],img_size[1])
34 | if ROI is not None:
35 | mask = torch.zeros(img_size[0],img_size[1])
36 | mask[ROI[0]:ROI[0]+ROI[2], ROI[1]:ROI[1]+ROI[3]] = 1.0
37 | rays,_ = ray_sampling(K.unsqueeze(0), T.unsqueeze(0), img_size, masks=mask.unsqueeze(0))
38 |
39 | if bboxes is not None:
40 | bboxes = bboxes.unsqueeze(0).repeat(rays.size(0),1,1)
41 |
42 | with torch.no_grad():
43 | stage2, stage1,_ = batchify_ray(model, rays, bboxes,near_far = near_far)
44 |
45 |
46 | rgb = torch.zeros(img_size[0],img_size[1], 3, device = stage2[0].device)
47 | rgb[mask>0.5,:] = stage2[0]
48 |
49 | depth = torch.zeros(img_size[0],img_size[1],1, device = stage2[1].device)
50 | depth[mask>0.5,:] = stage2[1]
51 |
52 | alpha = torch.zeros(img_size[0],img_size[1],1, device = stage2[2].device)
53 | alpha[mask>0.5,:] = stage2[2]
54 |
55 | stage2_final = [None]*3
56 | stage2_final[0] = rgb.reshape(img_size[0],img_size[1], 3)
57 | stage2_final[1] = depth.reshape(img_size[0],img_size[1])
58 | stage2_final[2] = alpha.reshape(img_size[0],img_size[1])
59 |
60 |
61 | rgb = torch.zeros(img_size[0],img_size[1], 3, device = stage1[0].device)
62 | rgb[mask>0.5,:] = stage1[0]
63 |
64 | depth = torch.zeros(img_size[0],img_size[1],1, device = stage1[1].device)
65 | depth[mask>0.5,:] = stage1[1]
66 |
67 | alpha = torch.zeros(img_size[0],img_size[1],1, device = stage1[2].device)
68 | alpha[mask>0.5,:] = stage1[2]
69 |
70 | stage1_final = [None]*3
71 | stage1_final[0] = rgb.reshape(img_size[0],img_size[1], 3)
72 | stage1_final[1] = depth.reshape(img_size[0],img_size[1])
73 | stage1_final[2] = alpha.reshape(img_size[0],img_size[1])
74 |
75 |
76 |
77 | return stage2_final, stage1_final
78 |
--------------------------------------------------------------------------------
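A hedged sketch of calling the render helper above for a single view. The intrinsics, extrinsic and near/far values are placeholders (not taken from any dataset), and `model` is assumed to be a trained module compatible with batchify_ray; the output shapes follow the reshapes in the listing above.

```python
import torch

H, W = 8, 8                              # tiny image purely for illustration
K = torch.eye(3)                         # (3, 3) intrinsics, placeholder values
T = torch.eye(4)                         # (4, 4) extrinsic, placeholder values
near_far = torch.tensor([0.1, 5.0])      # placeholder range; the exact shape batchify_ray expects is not shown here

# `model` is assumed to be a trained radiance-field module already on the GPU.
stage2, stage1 = render(model, K, T, (H, W), near_far=near_far)
rgb, depth, alpha = stage2               # (H, W, 3), (H, W), (H, W)
```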
/images/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/images/teaser.jpg
--------------------------------------------------------------------------------
/layers/RaySamplePoint-1.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from torch import nn
5 | import torch
6 | from layers.render_layer import gen_weight
7 | import pdb
8 |
9 | def intersection(rays, bbox):
10 | n = rays.shape[0]
11 | left_face = bbox[:, 0, 0]
12 | right_face = bbox[:, 6, 0]
13 | front_face = bbox[:, 0, 1]
14 | back_face = bbox[:, 6, 1]
15 | bottom_face = bbox[:, 0, 2]
16 | up_face = bbox[:, 6, 2]
17 | # if a ray is parallel to a face, t becomes (numerically) infinite
18 | left_t = ((left_face - rays[:, 0]) / (rays[:, 3] + np.finfo(float).eps.item())).reshape((n, 1))
19 | right_t = ((right_face - rays[:, 0]) / (rays[:, 3] + np.finfo(float).eps.item())).reshape((n, 1))
20 | front_t = ((front_face - rays[:, 1]) / (rays[:, 4] + np.finfo(float).eps.item())).reshape((n, 1))
21 | back_t = ((back_face - rays[:, 1]) / (rays[:, 4] + np.finfo(float).eps.item())).reshape((n, 1))
22 | bottom_t = ((bottom_face - rays[:, 2]) / (rays[:, 5] + np.finfo(float).eps.item())).reshape((n, 1))
23 | up_t = ((up_face - rays[:, 2]) / (rays[:, 5] + np.finfo(float).eps)).reshape((n, 1))
24 |
25 |
26 | rays_o = rays[:, :3]
27 | rays_d = rays[:, 3:6]
28 | left_point = left_t * rays_d + rays_o
29 | right_point = right_t * rays_d + rays_o
30 | front_point = front_t * rays_d + rays_o
31 | back_point = back_t * rays_d + rays_o
32 | bottom_point = bottom_t * rays_d + rays_o
33 | up_point = up_t * rays_d + rays_o
34 |
35 | left_mask = (left_point[:, 1] >= bbox[:, 0, 1]) & (left_point[:, 1] <= bbox[:, 7, 1]) \
36 | & (left_point[:, 2] >= bbox[:, 0, 2]) & (left_point[:, 2] <= bbox[:, 7, 2])
37 | right_mask = (right_point[:, 1] >= bbox[:, 1, 1]) & (right_point[:, 1] <= bbox[:, 6, 1]) \
38 | & (right_point[:, 2] >= bbox[:, 1, 2]) & (right_point[:, 2] <= bbox[:, 6, 2])
39 |
40 | # compare x, z
41 | front_mask = (front_point[:, 0] >= bbox[:, 0, 0]) & (front_point[:, 0] <= bbox[:, 5, 0]) \
42 | & (front_point[:, 2] >= bbox[:, 0, 2]) & (front_point[:, 2] <= bbox[:, 5, 2])
43 |
44 | back_mask = (back_point[:, 0] >= bbox[:, 3, 0]) & (back_point[:, 0] <= bbox[:, 6, 0]) \
45 | & (back_point[:, 2] >= bbox[:, 3, 2]) & (back_point[:, 2] <= bbox[:, 6, 2])
46 |
47 | # compare x,y
48 | bottom_mask = (bottom_point[:, 0] >= bbox[:, 0, 0]) & (bottom_point[:, 0] <= bbox[:, 2, 0]) \
49 | & (bottom_point[:, 1] >= bbox[:, 0, 1]) & (bottom_point[:, 1] <= bbox[:, 2, 1])
50 |
51 | up_mask = (up_point[:, 0] >= bbox[:, 4, 0]) & (up_point[:, 0] <= bbox[:, 6, 0]) \
52 | & (up_point[:, 1] >= bbox[:, 4, 1]) & (up_point[:, 1] <= bbox[:, 6, 1])
53 |
54 | tlist = -torch.ones_like(rays, device=rays.device)*1e3
55 | tlist[left_mask, 0] = left_t[left_mask].reshape((-1,))
56 | tlist[right_mask, 1] = right_t[right_mask].reshape((-1,))
57 | tlist[front_mask, 2] = front_t[front_mask].reshape((-1,))
58 | tlist[back_mask, 3] = back_t[back_mask].reshape((-1,))
59 | tlist[bottom_mask, 4] = bottom_t[bottom_mask].reshape((-1,))
60 | tlist[up_mask, 5] = up_t[up_mask].reshape((-1,))
61 | tlist = tlist.topk(k=2, dim=-1)
62 |
63 | return tlist[0]
64 |
65 | class RaySamplePoint(nn.Module):
66 | def __init__(self, coarse_num=64):
67 | super(RaySamplePoint, self).__init__()
68 | self.coarse_num = coarse_num
69 |
70 |
71 | def forward(self, rays, bbox, pdf=None, method='coarse'):
72 | '''
73 | :param rays: N*6
74 | :param bbox: N*8*3 0,1,2,3 bottom 4,5,6,7 up
75 | pdf: n*coarse_num, the per-sample weights
76 | :param method:
77 | :return: N*C*1 , N*C*3, N
78 | '''
79 | n = rays.shape[0]
80 | #if method=='coarse':
81 | sample_num = self.coarse_num
82 | bin_range = torch.arange(0, sample_num, device=rays.device).reshape((1, sample_num)).float()
83 |
84 | bin_num = sample_num
85 | n = rays.shape[0]
86 | tlist = intersection(rays, bbox)
87 | start = (tlist[:,1]).reshape((n,1))
88 | end = (tlist[:, 0]).reshape((n, 1))
89 |
90 | bin_sample = torch.rand((n, sample_num), device=rays.device)
91 | bin_width = (end - start)/bin_num
92 | sample_t = (bin_range + bin_sample)* bin_width + start
93 | sample_point = sample_t.unsqueeze(-1)*rays[:,3:6].unsqueeze(1) + rays[:,:3].unsqueeze(1)
94 | mask = (torch.abs(bin_width)> 1e-5).squeeze()
95 | return sample_t.unsqueeze(-1), sample_point, mask
96 |
97 |
98 | class RayDistributedSamplePoint(nn.Module):
99 | def __init__(self, fine_num=10):
100 | super(RayDistributedSamplePoint, self).__init__()
101 | self.fine_num = fine_num
102 |
103 | def forward(self, rays, depth, density, noise=0.0):
104 | '''
105 | :param rays: N*L*6
106 | :param depth: N*L*1
107 | :param density: N*L*1
108 | :param noise:0
109 | :return:
110 | '''
111 |
112 | sample_num = self.fine_num
113 | n = density.shape[0]
114 |
115 | weights = gen_weight(depth, density, noise=noise) # N*L
116 | weights += 1e-5
117 | bin = depth.squeeze()
118 |
119 | weights = weights[:, 1:].squeeze() #N*(L-1)
120 | pdf = weights/torch.sum(weights, dim=1, keepdim=True)
121 | cdf = torch.cumsum(pdf, dim=1)
122 | cdf_s = torch.cat((torch.zeros((n, 1)).type(cdf.dtype), cdf), dim=1)
123 | fine_bin = torch.linspace(0, 1, sample_num, device=density.device).reshape((1, sample_num)).repeat((n, 1))
124 | above_index = torch.ones_like(fine_bin, device=density.device).type(torch.LongTensor)
125 | for i in range(cdf.shape[1]):
126 | mask = (fine_bin > (cdf_s[:, i]).reshape((n, 1))) & (fine_bin <= (cdf[:, i]).reshape((n, 1)))
127 | above_index[mask] = i+1
128 | below_index = above_index-1
129 | below_index[below_index==-1]=0
130 | sn_below = torch.gather(bin, dim=1, index=below_index)
131 | sn_above = torch.gather(bin, dim=1, index=above_index)
132 | cdf_below = torch.gather(cdf_s, dim=1, index=below_index)
133 | cdf_above = torch.gather(cdf_s, dim=1, index=above_index)
134 | dnorm = cdf_above - cdf_below
135 | dnorm = torch.where(dnorm<1e-5, torch.ones_like(dnorm, device=density.device), dnorm)
136 | d = (fine_bin - cdf_below)/dnorm
137 | fine_t = (sn_above - sn_below) * d + sn_below
138 | fine_sample_point = fine_t.unsqueeze(-1) * rays[:, 3:6].unsqueeze(1) + rays[:, :3].unsqueeze(1)
139 | return fine_t, fine_sample_point
140 |
141 |
142 |
143 | class RaySamplePoint_Near_Far(nn.Module):
144 | def __init__(self, sample_num=75):
145 | super(RaySamplePoint_Near_Far, self).__init__()
146 | self.sample_num = sample_num
147 |
148 |
149 | def forward(self, rays,near_far):
150 | '''
151 | :param rays: N*6
152 | :param bbox: N*8*3 0,1,2,3 bottom 4,5,6,7 up
153 | pdf: n*coarse_num, the per-sample weights
154 | :param method:
155 | :return: N*C*3
156 | '''
157 | n = rays.size(0)
158 |
159 |
160 | ray_o = rays[:,:3]
161 | ray_d = rays[:,3:6]
162 |
163 | # near = 0.1
164 | # far = 5
165 |
166 |
167 | t_vals = torch.linspace(0., 1., steps=self.sample_num,device =rays.device)
168 | #print(near_far[:,0:1].repeat(1, self.sample_num).size(), t_vals.unsqueeze(0).repeat(n,1).size())
169 | # print(near_far[:,0].unsqueeze(1).expand([n,self.sample_num]).shape)
170 | # print(near_far[:,1].unsqueeze(1).expand([n,self.sample_num]).shape)
171 | # print('------------------------')
172 | #z_vals = near_far[:,0].unsqueeze(1).expand([n,self.sample_num]) * (1.-t_vals).unsqueeze(0).repeat(n,1) + near_far[:,1].unsqueeze(1).expand([n,self.sample_num]) * (t_vals.unsqueeze(0).repeat(n,1))
173 | z_vals = near_far[:,0:1].repeat(1, self.sample_num) * (1.-t_vals).unsqueeze(0).repeat(n,1) + near_far[:,1:2].repeat(1, self.sample_num) * (t_vals.unsqueeze(0).repeat(n,1))
174 | # z_vals = near * (1.-t_vals) + far * (t_vals)
175 | # z_vals = z_vals.expand([n, self.sample_num])
176 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
177 | upper = torch.cat([mids, z_vals[...,-1:]], -1)
178 | lower = torch.cat([z_vals[...,:1], mids], -1)
179 |
180 |
181 | t_rand = torch.rand(z_vals.size(), device = rays.device)
182 |
183 | z_vals = lower + (upper - lower) * t_rand
184 |
185 |
186 | pts = ray_o[...,None,:] + ray_d[...,None,:] * z_vals[...,:,None]
187 |
188 | return z_vals.unsqueeze(-1), pts
189 |
190 |
191 |
192 |
--------------------------------------------------------------------------------
/layers/RaySamplePoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import torch
5 | from layers.render_layer import gen_weight
6 | import pdb
7 |
8 | def intersection(rays, bbox):
9 | n = rays.shape[0]
10 | left_face = bbox[:, 0, 0]
11 | right_face = bbox[:, 6, 0]
12 | front_face = bbox[:, 0, 1]
13 | back_face = bbox[:, 6, 1]
14 | bottom_face = bbox[:, 0, 2]
15 | up_face = bbox[:, 6, 2]
16 | # if a ray is parallel to a face, t becomes (numerically) infinite
17 | left_t = ((left_face - rays[:, 0]) / (rays[:, 3] + np.finfo(float).eps.item())).reshape((n, 1))
18 | right_t = ((right_face - rays[:, 0]) / (rays[:, 3] + np.finfo(float).eps.item())).reshape((n, 1))
19 | front_t = ((front_face - rays[:, 1]) / (rays[:, 4] + np.finfo(float).eps.item())).reshape((n, 1))
20 | back_t = ((back_face - rays[:, 1]) / (rays[:, 4] + np.finfo(float).eps.item())).reshape((n, 1))
21 | bottom_t = ((bottom_face - rays[:, 2]) / (rays[:, 5] + np.finfo(float).eps.item())).reshape((n, 1))
22 | up_t = ((up_face - rays[:, 2]) / (rays[:, 5] + np.finfo(float).eps)).reshape((n, 1))
23 |
24 |
25 | rays_o = rays[:, :3]
26 | rays_d = rays[:, 3:6]
27 | left_point = left_t * rays_d + rays_o
28 | right_point = right_t * rays_d + rays_o
29 | front_point = front_t * rays_d + rays_o
30 | back_point = back_t * rays_d + rays_o
31 | bottom_point = bottom_t * rays_d + rays_o
32 | up_point = up_t * rays_d + rays_o
33 |
34 | left_mask = (left_point[:, 1] >= bbox[:, 0, 1]) & (left_point[:, 1] <= bbox[:, 7, 1]) \
35 | & (left_point[:, 2] >= bbox[:, 0, 2]) & (left_point[:, 2] <= bbox[:, 7, 2])
36 | right_mask = (right_point[:, 1] >= bbox[:, 1, 1]) & (right_point[:, 1] <= bbox[:, 6, 1]) \
37 | & (right_point[:, 2] >= bbox[:, 1, 2]) & (right_point[:, 2] <= bbox[:, 6, 2])
38 |
39 | # compare x, z
40 | front_mask = (front_point[:, 0] >= bbox[:, 0, 0]) & (front_point[:, 0] <= bbox[:, 5, 0]) \
41 | & (front_point[:, 2] >= bbox[:, 0, 2]) & (front_point[:, 2] <= bbox[:, 5, 2])
42 |
43 | back_mask = (back_point[:, 0] >= bbox[:, 3, 0]) & (back_point[:, 0] <= bbox[:, 6, 0]) \
44 | & (back_point[:, 2] >= bbox[:, 3, 2]) & (back_point[:, 2] <= bbox[:, 6, 2])
45 |
46 | # compare x,y
47 | bottom_mask = (bottom_point[:, 0] >= bbox[:, 0, 0]) & (bottom_point[:, 0] <= bbox[:, 2, 0]) \
48 | & (bottom_point[:, 1] >= bbox[:, 0, 1]) & (bottom_point[:, 1] <= bbox[:, 2, 1])
49 |
50 | up_mask = (up_point[:, 0] >= bbox[:, 4, 0]) & (up_point[:, 0] <= bbox[:, 6, 0]) \
51 | & (up_point[:, 1] >= bbox[:, 4, 1]) & (up_point[:, 1] <= bbox[:, 6, 1])
52 |
53 | tlist = -torch.ones_like(rays, device=rays.device)*1e3
54 | tlist[left_mask, 0] = left_t[left_mask].reshape((-1,))
55 | tlist[right_mask, 1] = right_t[right_mask].reshape((-1,))
56 | tlist[front_mask, 2] = front_t[front_mask].reshape((-1,))
57 | tlist[back_mask, 3] = back_t[back_mask].reshape((-1,))
58 | tlist[bottom_mask, 4] = bottom_t[bottom_mask].reshape((-1,))
59 | tlist[up_mask, 5] = up_t[up_mask].reshape((-1,))
60 | tlist = tlist.topk(k=2, dim=-1)
61 |
62 | return tlist[0]
63 |
64 | class RaySamplePoint(nn.Module):
65 | def __init__(self, coarse_num=64):
66 | super(RaySamplePoint, self).__init__()
67 | self.coarse_num = coarse_num
68 |
69 |
70 | def forward(self, rays, bbox, pdf=None, method='coarse'):
71 | '''
72 | :param rays: N*6
73 | :param bbox: N*L*8*3 0,1,2,3 bottom 4,5,6,7 up
74 | pdf: n*coarse_num, the per-sample weights
75 | :param method:
76 | :return: L*N*C*1 , L*N*C*3, L*N
77 | '''
78 | n = rays.shape[0]
79 | l = bbox.shape[1]
80 | #if method=='coarse':
81 | sample_num = self.coarse_num
82 | sample_t = []
83 | sample_point = []
84 | mask = []
85 | for i in range(l):
86 |
87 | bin_range = torch.arange(0, sample_num, device=rays.device).reshape((1, sample_num)).float()
88 |
89 | bin_num = sample_num
90 | n = rays.shape[0]
91 | tlist = intersection(rays, bbox[:,i,:,:])
92 | start = (tlist[:,1]).reshape((n,1))
93 | if i == 0:
94 | idx = start <= 0
95 | start[idx] = 0
96 | end = (tlist[:, 0]).reshape((n,1))
97 |
98 | bin_sample = torch.rand((n, sample_num), device=rays.device)
99 |
100 | bin_width = (end - start)/bin_num
101 |
102 | sample_t.append(((bin_range + bin_sample)* bin_width + start ).unsqueeze(-1))
103 | sample_point.append(sample_t[i]*rays[:,3:6].unsqueeze(1) + rays[:,:3].unsqueeze(1))
104 |
105 | mask.append((torch.abs(bin_width)> 1e-5).squeeze())
106 |
107 | return sample_t, sample_point, mask
108 |
109 |
110 | class RayDistributedSamplePoint(nn.Module):
111 | def __init__(self, fine_num=10):
112 | super(RayDistributedSamplePoint, self).__init__()
113 | self.fine_num = fine_num
114 |
115 | def forward(self, rays, depth, density, noise=0.0):
116 | '''
117 | :param rays: N*L*6
118 | :param depth: N*L*1
119 | :param density: N*L*1
120 | :param noise:0
121 | :return:
122 | '''
123 |
124 | sample_num = self.fine_num
125 | n = density.shape[0]
126 |
127 | weights = gen_weight(depth, density, noise=noise) # N*L
128 | weights += 1e-5
129 | bin = depth.squeeze()
130 |
131 | weights = weights[:, 1:].squeeze() #N*(L-1)
132 | pdf = weights/torch.sum(weights, dim=1, keepdim=True)
133 | cdf = torch.cumsum(pdf, dim=1)
134 | cdf_s = torch.cat((torch.zeros((n, 1)).type(cdf.dtype), cdf), dim=1)
135 | fine_bin = torch.linspace(0, 1, sample_num, device=density.device).reshape((1, sample_num)).repeat((n, 1))
136 | above_index = torch.ones_like(fine_bin, device=density.device).type(torch.LongTensor)
137 | for i in range(cdf.shape[1]):
138 | mask = (fine_bin > (cdf_s[:, i]).reshape((n, 1))) & (fine_bin <= (cdf[:, i]).reshape((n, 1)))
139 | above_index[mask] = i+1
140 | below_index = above_index-1
141 | below_index[below_index==-1]=0
142 | sn_below = torch.gather(bin, dim=1, index=below_index)
143 | sn_above = torch.gather(bin, dim=1, index=above_index)
144 | cdf_below = torch.gather(cdf_s, dim=1, index=below_index)
145 | cdf_above = torch.gather(cdf_s, dim=1, index=above_index)
146 | dnorm = cdf_above - cdf_below
147 | dnorm = torch.where(dnorm<1e-5, torch.ones_like(dnorm, device=density.device), dnorm)
148 | d = (fine_bin - cdf_below)/dnorm
149 | fine_t = (sn_above - sn_below) * d + sn_below
150 | fine_sample_point = fine_t.unsqueeze(-1) * rays[:, 3:6].unsqueeze(1) + rays[:, :3].unsqueeze(1)
151 | return fine_t, fine_sample_point
152 |
153 |
154 |
155 | class RaySamplePoint_Near_Far(nn.Module):
156 | def __init__(self, sample_num=75):
157 | super(RaySamplePoint_Near_Far, self).__init__()
158 | self.sample_num = sample_num
159 |
160 |
161 | def forward(self, rays,near_far):
162 | '''
163 | :param rays: N*6
164 | :param bbox: N*8*3 0,1,2,3 bottom 4,5,6,7 up
165 | pdf: n*coarse_num, the per-sample weights
166 | :param method:
167 | :return: N*C*3
168 | '''
169 | n = rays.size(0)
170 |
171 |
172 | ray_o = rays[:,:3]
173 | ray_d = rays[:,3:6]
174 |
175 | # near = 0.1
176 | # far = 5
177 |
178 |
179 | t_vals = torch.linspace(0., 1., steps=self.sample_num,device =rays.device)
180 | #print(near_far[:,0:1].repeat(1, self.sample_num).size(), t_vals.unsqueeze(0).repeat(n,1).size())
181 | # print(near_far[:,0].unsqueeze(1).expand([n,self.sample_num]).shape)
182 | # print(near_far[:,1].unsqueeze(1).expand([n,self.sample_num]).shape)
183 | # print('------------------------')
184 | #z_vals = near_far[:,0].unsqueeze(1).expand([n,self.sample_num]) * (1.-t_vals).unsqueeze(0).repeat(n,1) + near_far[:,1].unsqueeze(1).expand([n,self.sample_num]) * (t_vals.unsqueeze(0).repeat(n,1))
185 | z_vals = near_far[:,0:1].repeat(1, self.sample_num) * (1.-t_vals).unsqueeze(0).repeat(n,1) + near_far[:,1:2].repeat(1, self.sample_num) * (t_vals.unsqueeze(0).repeat(n,1))
186 | # z_vals = near * (1.-t_vals) + far * (t_vals)
187 | # z_vals = z_vals.expand([n, self.sample_num])
188 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
189 | upper = torch.cat([mids, z_vals[...,-1:]], -1)
190 | lower = torch.cat([z_vals[...,:1], mids], -1)
191 |
192 |
193 | t_rand = torch.rand(z_vals.size(), device = rays.device)
194 |
195 | z_vals = lower + (upper - lower) * t_rand
196 |
197 |
198 | pts = ray_o[...,None,:] + ray_d[...,None,:] * z_vals[...,:,None]
199 |
200 | return z_vals.unsqueeze(-1), pts
201 |
--------------------------------------------------------------------------------
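A minimal sketch of the near/far sampler above, using placeholder rays: origin in the first three columns, direction in the last three, matching how the class slices its input.

```python
import torch

# Stratified sampling of 75 points per ray between per-ray near/far bounds.
sampler = RaySamplePoint_Near_Far(sample_num=75)

rays = torch.zeros(4, 6)                             # (N, 6): origin (3) + direction (3)
rays[:, 5] = 1.0                                     # placeholder geometry: all rays point along +z
near_far = torch.tensor([[0.5, 4.0]]).repeat(4, 1)   # (N, 2)

z_vals, pts = sampler(rays, near_far)
print(z_vals.shape, pts.shape)                       # torch.Size([4, 75, 1]) torch.Size([4, 75, 3])
```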
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .RaySamplePoint import RaySamplePoint,RaySamplePoint_Near_Far
2 | from .render_layer import VolumeRenderer
3 | from .loss import make_loss
--------------------------------------------------------------------------------
/layers/__pycache__/RaySamplePoint.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/RaySamplePoint.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/RaySamplePoint1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/RaySamplePoint1.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/camera_transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/camera_transform.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/render_layer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/render_layer.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/render_layer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/layers/__pycache__/render_layer.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/camera_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | def corrupt_cameras(cam_poses, offset=(-0.1, 0.1), rotation=(-5, 5)):
7 | rand_t = np.random.rand(cam_poses.shape[0], 3)
8 | perturb_t = (1 - rand_t) * offset[0] + rand_t * offset[1]
9 | tr = cam_poses[:, :3, 3] + perturb_t
10 | tr = tr[..., None] # [N, 3, 1]
11 |
12 | rand_r = np.random.rand(cam_poses.shape[0], 3)
13 | rand_r = (1 - rand_r) * rotation[0] + rand_r * rotation[1]
14 | rand_r = np.deg2rad(rand_r)
15 |
16 | # Pre-compute rotation matrices
17 | Rx = np.stack((
18 | np.ones_like(rand_r[:, 0]), np.zeros_like(rand_r[:, 0]), np.zeros_like(rand_r[:, 0]),
19 | np.zeros_like(rand_r[:, 0]), np.cos(rand_r[:, 0]), -np.sin(rand_r[:, 0]),
20 | np.zeros_like(rand_r[:, 0]), np.sin(rand_r[:, 0]), np.cos(rand_r[:, 0])
21 | ), axis=1).reshape(-1, 3, 3)
22 |
23 | Ry = np.stack((
24 | np.cos(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]), np.sin(rand_r[:, 1]),
25 | np.zeros_like(rand_r[:, 1]), np.ones_like(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]),
26 | -np.sin(rand_r[:, 1]), np.zeros_like(rand_r[:, 1]), np.cos(rand_r[:, 1])
27 | ), axis=1).reshape(-1, 3, 3)
28 |
29 | Rz = np.stack((
30 | np.cos(rand_r[:, 2]), -np.sin(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]),
31 | np.sin(rand_r[:, 2]), np.cos(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]),
32 | np.zeros_like(rand_r[:, 2]), np.zeros_like(rand_r[:, 2]), np.ones_like(rand_r[:, 2])
33 | ), axis=1).reshape(-1, 3, 3)
34 |
35 | # Apply rotation sequentially
36 | rot = cam_poses[:, :3, :3] # [N, 3, 3]
37 | for perturb_r in [Rz, Ry, Rx]:
38 | rot = np.matmul(perturb_r, rot)
39 |
40 | return np.concatenate([rot, tr], axis=-1)
41 |
42 | # Camera Transformation Layer
43 | class CameraTransformer(nn.Module):
44 |
45 | def __init__(self, num_cams, trainable=False):
46 | """ Init layered sampling
47 | num_cams: number of training cameras
48 | trainable: Whether planes can be trained by optimizer
49 | """
50 | super(CameraTransformer, self).__init__()
51 |
52 | self.trainable = trainable
53 |
54 | identity_quat = torch.Tensor([0, 0, 0, 1]).repeat((num_cams, 1))
55 | identity_off = torch.Tensor([0, 0, 0]).repeat((num_cams, 1))
56 | if self.trainable:
57 | self.rvec = nn.Parameter(torch.Tensor(identity_quat)) # [N_cameras, 4]
58 | self.tvec = nn.Parameter(torch.Tensor(identity_off)) # [N_cameras, 3]
59 | else:
60 | self.register_buffer('rvec', torch.Tensor(identity_quat)) # [N_cameras, 4]
61 | self.register_buffer('tvec', torch.Tensor(identity_off)) # [N_cameras, 3]
62 |
63 | print("Create %d %s camera transformer" % (num_cams, 'trainable' if self.rvec.requires_grad else 'non-trainable'))
64 |
65 | def rot_mats(self):
66 | theta = torch.sqrt(1e-5 + torch.sum(self.rvec ** 2, dim=1))
67 | rvec = self.rvec / theta[:, None]
68 | return torch.stack((
69 | 1. - 2. * rvec[:, 1] ** 2 - 2. * rvec[:, 2] ** 2,
70 | 2. * (rvec[:, 0] * rvec[:, 1] - rvec[:, 2] * rvec[:, 3]),
71 | 2. * (rvec[:, 0] * rvec[:, 2] + rvec[:, 1] * rvec[:, 3]),
72 |
73 | 2. * (rvec[:, 0] * rvec[:, 1] + rvec[:, 2] * rvec[:, 3]),
74 | 1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 2] ** 2,
75 | 2. * (rvec[:, 1] * rvec[:, 2] - rvec[:, 0] * rvec[:, 3]),
76 |
77 | 2. * (rvec[:, 0] * rvec[:, 2] - rvec[:, 1] * rvec[:, 3]),
78 | 2. * (rvec[:, 0] * rvec[:, 3] + rvec[:, 1] * rvec[:, 2]),
79 | 1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 1] ** 2
80 | ), dim=1).view(-1, 3, 3)
81 |
82 | def forward(self, rays_o, rays_d, **render_kwargs):
83 | """ Generate sample points
84 | Args:
85 | rays_o: [N_rays, 3+1] origin points of rays with camera id
86 | rays_d: [N_rays, 3+1] directions of rays with camera id
87 |
88 | render_kwargs: other render parameters
89 |
90 | Return:
91 | rays_o: [N_rays, 3] Transformed origin points
92 | rays_d: [N_rays, 3] Transformed directions of rays
93 | """
94 | assert rays_o.shape[-1] == 4
95 | assert (rays_o[:, 3] == rays_d[:, 3]).all()
96 | indx = rays_o[:, 3].type(torch.LongTensor)
97 |
98 | # Rotate ray directions w.r.t. rvec
99 | c2w = self.rot_mats()[indx]
100 | rays_d = torch.sum(rays_d[..., None, :3] * c2w[:, :3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
101 |
102 | # Translate camera w.r.t. tvec
103 | rays_o = rays_o[..., :3] + self.tvec[indx]
104 |
105 | return rays_o, rays_d
--------------------------------------------------------------------------------
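A minimal sketch of the camera transformer above: with trainable=False it starts at the identity correction, so rays come back unchanged apart from dropping the trailing camera-id column. All values here are placeholders.

```python
import torch

cam_xform = CameraTransformer(num_cams=2, trainable=False)

rays_o = torch.tensor([[0., 0., 0., 0.],     # (N, 4): xyz origin + camera id
                       [1., 0., 0., 1.]])
rays_d = torch.tensor([[0., 0., 1., 0.],     # (N, 4): xyz direction + camera id
                       [0., 1., 0., 1.]])

new_o, new_d = cam_xform(rays_o, rays_d)
print(new_o.shape, new_d.shape)              # torch.Size([2, 3]) torch.Size([2, 3])
```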
/layers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def make_loss(cfg):
5 | return nn.MSELoss()
6 |
--------------------------------------------------------------------------------
/layers/render_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 |
8 | def gen_weight(sigma, delta, act_fn=F.relu):
9 | """Generate transmittance from predicted density
10 | """
11 | alpha = 1.-torch.exp(-act_fn(sigma.squeeze(-1))*delta)
12 | weight = 1.-alpha + 1e-10
13 | #weight = alpha * torch.cumprod(weight, dim=-1) / weight # exclusive cum_prod
14 |
15 | weight = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1),device = alpha.device), weight], -1), -1)[:, :-1]
16 |
17 | return weight
18 |
19 | class VolumeRenderer(nn.Module):
20 | def __init__(self, use_mask= False, boarder_weight = 1e10):
21 | super(VolumeRenderer, self).__init__()
22 | self.boarder_weight = boarder_weight
23 | self.use_mask = use_mask
24 |
25 | def forward(self, depth, rgb, sigma, noise=0):
26 | """
27 | N - num rays; L - num samples;
28 | :param depth: torch.tensor, depth for each sample along the ray. [N, L, 1]
29 | :param rgb: torch.tensor, raw rgb output from the network. [N, L, 3]
30 | :param sigma: torch.tensor, raw density (without activation). [N, L, 1]
31 |
32 | :return:
33 | color: torch.tensor [N, 3]
34 | depth: torch.tensor [N, 1]
35 | """
36 |
37 | delta = (depth[:, 1:] - depth[:, :-1]).squeeze() # [N, L-1]
38 | #pad = torch.Tensor([1e10],device=delta.device).expand_as(delta[...,:1])
39 | pad = self.boarder_weight*torch.ones(delta[...,:1].size(),device = delta.device)
40 | delta = torch.cat([delta, pad], dim=-1) # [N, L]
41 |
42 | if noise > 0.:
43 | sigma += (torch.randn(size=sigma.size(),device = delta.device) * noise)
44 |
45 | weights = gen_weight(sigma, delta).unsqueeze(-1) #[N, L, 1]
46 |
47 | color = torch.sum(torch.sigmoid(rgb) * weights, dim=1) #[N, 3]
48 | depth = torch.sum(weights * depth, dim=1) # [N, 1]
49 | acc_map = torch.sum(weights, dim = 1) #
50 | # TODO: This scaling can make the program crash because of NaNs when acc_map is close to 0.
51 | # if acc_map.max() > 0.0001:
52 | # acc_map = acc_map / acc_map.max()
53 |
54 | if self.use_mask:
55 | # TODO: there may be a bug in how color is blended here at the end
56 | color = color + (1.-acc_map[...,None]) * color
57 |
58 | return color, depth, acc_map, weights
59 |
60 |
61 | if __name__ == "__main__":
62 | N_rays = 1024
63 | N_samples = 64
64 |
65 | depth = torch.randn(N_rays, N_samples, 1)
66 | raw = torch.randn(N_rays, N_samples, 3)
67 | sigma = torch.randn(N_rays, N_samples, 1)
68 |
69 | renderer = VolumeRenderer()
70 |
71 | color, dpt, acc, weights = renderer(depth, raw, sigma)
72 | print('Predicted [CPU]: ', color.shape, dpt.shape, weights.shape)
73 |
74 | if torch.cuda.is_available():
75 | depth = depth.cuda()
76 | raw = raw.cuda()
77 | sigma = sigma.cuda()
78 | renderer = renderer.cuda()
79 |
80 | color, dpt, acc, weights = renderer(depth, raw, sigma)
81 | print('Predicted [GPU]: ', color.shape, dpt.shape, weights.shape)
82 |
83 | print('Test load data')
84 | tf_depth = np.load('layers/test_output/depth_map.npy')
85 | tf_color = np.load('layers/test_output/rgb_map.npy')
86 | tf_weights = np.load('layers/test_output/weights.npy')
87 | print('TF output = ', tf_depth.shape, tf_color.shape, tf_weights.shape)
88 |
89 | raws = torch.from_numpy(np.load('layers/test_output/raws.npy'))
90 | ray_d = torch.from_numpy(np.load('layers/test_output/ray_d.npy'))
91 | z_val = torch.from_numpy(np.load('layers/test_output/z_vals.npy'))
92 |
93 | print('TF input = ', raws.shape, ray_d.shape, z_val.shape)
94 |
95 | in_depth = z_val
96 | print('in_depth = ', in_depth.shape)
97 | in_raw = raws[:, :, :3]
98 | print('in_raw = ', in_raw.shape)
99 | in_sigma = raws[:, :, 3:]
100 | print('in_sigma = ', in_sigma.shape)
101 |
102 | color, dpt, acc, weights = renderer(in_depth.unsqueeze(-1).cuda(), in_raw.cuda(), in_sigma.cuda())
103 | print('Predicted-TF [GPU]: ', color.shape, dpt.shape, weights.shape)
104 |
105 | print('ERROR [GPU]: ',
106 | np.mean(tf_color - color.detach().cpu().numpy()),
107 | np.mean(tf_depth - dpt.squeeze(-1).detach().cpu().numpy()),
108 | np.mean(tf_weights - weights.squeeze(-1).detach().cpu().numpy()))
--------------------------------------------------------------------------------
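
For reference, `gen_weight` and `VolumeRenderer.forward` above implement the standard volume-rendering quadrature over the L samples of each ray:

$$\alpha_i = 1 - e^{-\operatorname{relu}(\sigma_i)\,\delta_i},\qquad
w_i = \alpha_i \prod_{j<i}\bigl(1-\alpha_j + 10^{-10}\bigr),\qquad
C = \sum_i w_i\,\operatorname{sigmoid}(c_i),\qquad
D = \sum_i w_i\,t_i$$

where $\delta_i$ is the spacing between adjacent samples (padded with `boarder_weight` at the far end), $c_i$ is the raw RGB output, $t_i$ the sample depth, and the accumulated opacity `acc_map` is $\sum_i w_i$.
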
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | from .layered_rfrender import LayeredRFRender
4 |
5 | def build_layered_model(cfg,camera_num=0,scale=None,shift=None):
6 | model = LayeredRFRender(cfg, camera_num=camera_num, scale=scale,shift=shift)
7 | return model
8 |
--------------------------------------------------------------------------------
/modeling/motion_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils import Trigonometric_kernel
5 | class MotionNet(nn.Module):
6 | # (x,y,z,t)
7 | def __init__(self, c_input=5, include_input = True, input_time = False):
8 | """ Init layered sampling
9 | """
10 | super(MotionNet, self).__init__()
11 | self.c_input = c_input
12 | self.input_time = input_time
13 | #Positional Encoding
14 | self.tri_kernel_pos = Trigonometric_kernel(L=10,input_dim = c_input, include_input = include_input)
15 |
16 | self.pos_dim = self.tri_kernel_pos.calc_dim(c_input)
17 | backbone_dim = 128
18 | head_dim = 128
19 |
20 | self.motion_net = nn.Sequential(
21 | nn.Linear(self.pos_dim, head_dim),
22 | nn.ReLU(inplace=False),
23 | nn.Linear(head_dim,backbone_dim),
24 | nn.ReLU(inplace=True),
25 | nn.Linear(backbone_dim,backbone_dim),
26 | nn.ReLU(inplace=True),
27 | nn.Linear(backbone_dim,backbone_dim),
28 | nn.ReLU(inplace=True),
29 | nn.Linear(backbone_dim ,head_dim),
30 | nn.ReLU(inplace=True),
31 | nn.Linear(head_dim,3)
32 | )
33 |
34 | def forward(self, input_0):
35 | """ Generate sample points
36 | Input:
37 | pos: [N,3] points in real world coordinates
38 |
39 | Output:
40 | flow: [N,3] Scene Flow in real world coordinates
41 | """
42 |
43 | bins_mode = False
44 | if len(input_0.size()) > 2:
45 | bins_mode = True
46 | L = input_0.size(1)
47 | input_0 = input_0.reshape((-1, self.c_input)) # (N,input)
48 |
49 | if self.input_time:
50 | xyz = input_0[:,:-1]
51 | time = input_0[:,-1:]
52 | lower = torch.floor(time)
53 | if not torch.all(torch.eq(lower, time)):
54 | upper = lower + 1
55 | weight = time - lower
56 | i_lower = torch.cat([xyz,lower],-1)
57 | i_upper = torch.cat([xyz,upper],-1)
58 | i_lower = self.tri_kernel_pos(i_lower)
59 | i_upper = self.tri_kernel_pos(i_upper)
60 | input_0 = (1-weight) * i_lower + weight * i_upper
61 | else:
62 | input_0 = self.tri_kernel_pos(input_0)
63 | else:
64 | input_0 = self.tri_kernel_pos(input_0)
65 |
66 | flow = self.motion_net(input_0)
67 |
68 | if bins_mode:
69 | flow = flow.reshape(-1, L, 3)
70 |
71 | return flow
72 |
--------------------------------------------------------------------------------
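
A minimal sketch of querying `MotionNet` for scene flow, assuming the repo root is on `PYTHONPATH`; with `input_time=True` the positional encoding is linearly interpolated between the two neighbouring integer frames.

```python
import torch
from modeling.motion_net import MotionNet

# 4-D input (x, y, z, t); c_input must match the last dimension of the query points.
net = MotionNet(c_input=4, input_time=True)

pts = torch.rand(1024, 4)   # N points, last column is a (possibly fractional) frame index
flow = net(pts)             # [N, 3] scene-flow offsets in world coordinates
print(flow.shape)
```
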
/modeling/spacenet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch import nn
6 | import time
7 |
8 | from utils import Trigonometric_kernel
9 |
10 |
11 |
12 |
13 | class SpaceNet(nn.Module):
14 |
15 |
16 | def __init__(self, c_pos=3, include_input = True, use_dir = True, use_time = False, deep_rgb = False):
17 | super(SpaceNet, self).__init__()
18 |
19 |
20 | self.tri_kernel_pos = Trigonometric_kernel(L=10,include_input = include_input)
21 | if use_dir:
22 | self.tri_kernel_dir = Trigonometric_kernel(L=4, include_input = include_input)
23 | if use_time:
24 | self.tri_kernel_time = Trigonometric_kernel(L=10, input_dim=1, include_input = include_input)
25 |
26 | self.c_pos = c_pos
27 |
28 | self.pos_dim = self.tri_kernel_pos.calc_dim(c_pos)
29 | if use_dir:
30 | self.dir_dim = self.tri_kernel_dir.calc_dim(3)
31 | else:
32 | self.dir_dim = 0
33 |
34 | if use_time:
35 | self.time_dim = self.tri_kernel_time.calc_dim(1)
36 | else:
37 | self.time_dim = 0
38 |
39 | self.use_dir = use_dir
40 | self.use_time = use_time
41 | backbone_dim = 256
42 | head_dim = 128
43 |
44 |
45 | self.stage1 = nn.Sequential(
46 | nn.Linear(self.pos_dim, backbone_dim),
47 | nn.ReLU(inplace=True),
48 | nn.Linear(backbone_dim,backbone_dim),
49 | nn.ReLU(inplace=True),
50 | nn.Linear(backbone_dim,backbone_dim),
51 | nn.ReLU(inplace=True),
52 | nn.Linear(backbone_dim,backbone_dim),
53 | nn.ReLU(inplace=True),
54 | )
55 |
56 | self.stage2 = nn.Sequential(
57 | nn.Linear(backbone_dim+self.pos_dim, backbone_dim),
58 | nn.ReLU(inplace=True),
59 | nn.Linear(backbone_dim,backbone_dim),
60 | nn.ReLU(inplace=True),
61 | nn.Linear(backbone_dim,backbone_dim),
62 | nn.ReLU(inplace=True),
63 | )
64 |
65 | self.density_net = nn.Sequential(
66 | nn.Linear(backbone_dim, 1)
67 | )
68 | if deep_rgb:
69 | print("deep")
70 | self.rgb_net = nn.Sequential(
71 | nn.ReLU(inplace=True),
72 | nn.Linear(backbone_dim+self.dir_dim+self.time_dim, head_dim),
73 | nn.ReLU(inplace=True),
74 | nn.Linear(head_dim, head_dim),
75 | nn.ReLU(inplace=True),
76 | nn.Linear(head_dim, head_dim),
77 | nn.ReLU(inplace=True),
78 | nn.Linear(head_dim,3)
79 | )
80 | else:
81 | self.rgb_net = nn.Sequential(
82 | nn.ReLU(inplace=True),
83 | nn.Linear(backbone_dim+self.dir_dim+self.time_dim, head_dim),
84 | nn.ReLU(inplace=True),
85 | nn.Linear(head_dim,3)
86 | )
87 |
88 |
89 | '''
90 | INPUT
91 | pos: 3D positions (N,L,c_pos) or (N,c_pos)
92 | rays: corresponding rays (N,6)
93 | times: corresponding time (N,1)
94 |
95 | OUTPUT
96 |
97 | rgb: color (N,L,3) or (N,3)
98 | density: (N,L,1) or (N,1)
99 |
100 | '''
101 | def forward(self, pos, rays, times=None, maxs=None, mins=None):
102 |
103 | #beg = time.time()
104 | rgbs = None
105 | if rays is not None and self.use_dir:
106 |
107 | dirs = rays[...,3:6]
108 |
109 | bins_mode = False
110 | if len(pos.size())>2:
111 | bins_mode = True
112 | L = pos.size(1)
113 | pos = pos.reshape((-1,self.c_pos)) #(N,c_pos)
114 | if rays is not None and self.use_dir:
115 | dirs = dirs.unsqueeze(1).repeat(1,L,1)
116 | dirs = dirs.reshape((-1,self.c_pos)) #(N,3)
117 | if rays is not None and self.use_time:
118 | times = times.unsqueeze(1).repeat(1,L,1)
119 | times = times.reshape((-1,1)) #(N,1)
120 |
121 |
122 |
123 |
124 | if maxs is not None:
125 | pos = ((pos - mins)/(maxs-mins) - 0.5) * 2
126 |
127 | pos = self.tri_kernel_pos(pos)
128 | if rays is not None and self.use_dir:
129 | dirs = self.tri_kernel_dir(dirs)
130 | if self.use_time:
131 | times = self.tri_kernel_time(times)
132 | #torch.cuda.synchronize()
133 | #print('transform :',time.time()-beg)
134 |
135 | #beg = time.time()
136 | x = self.stage1(pos)
137 | x = self.stage2(torch.cat([x,pos],dim =1))
138 |
139 | density = self.density_net(x)
140 |
141 | x1 = 0
142 | if rays is not None and self.use_dir:
143 | x1 = torch.cat([x,dirs],dim =1)
144 | else:
145 | x1 = x.clone()
146 |
147 | rgbs = None
148 | if self.use_time:
149 | x2 = torch.cat([x1,times],dim =1)
150 | rgbs = self.rgb_net(x2)
151 | else:
152 | rgbs = self.rgb_net(x1)
153 | #torch.cuda.synchronize()
154 | #print('fc:',time.time()-beg)
155 |
156 | if bins_mode:
157 | density = density.reshape((-1,L,1))
158 | rgbs = rgbs.reshape((-1,L,3))
159 |
160 | return rgbs, density
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
--------------------------------------------------------------------------------
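
A minimal sketch of querying `SpaceNet` for colour and density at per-ray sample points, assuming the repo root is on `PYTHONPATH`:

```python
import torch
from modeling.spacenet import SpaceNet

net = SpaceNet(use_dir=True)

pos  = torch.rand(1024, 64, 3)                      # (N, L, 3) sample points per ray
dirs = torch.nn.functional.normalize(torch.rand(1024, 3), dim=-1)
rays = torch.cat([torch.zeros(1024, 3), dirs], -1)  # (N, 6) origin + direction

rgbs, density = net(pos, rays)                      # (N, L, 3) and (N, L, 1)
```
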
/outputs/taekwondo/layered_rfnr_checkpoint_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/outputs/taekwondo/layered_rfnr_checkpoint_1.pt
--------------------------------------------------------------------------------
/outputs/walking/layered_rfnr_checkpoint_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/outputs/walking/layered_rfnr_checkpoint_1.pt
--------------------------------------------------------------------------------
/render/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | from .render_functions import *
4 | from .neural_renderer import NeuralRenderer
5 | from .layered_neural_renderer import LayeredNeuralRenderer
--------------------------------------------------------------------------------
/render/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/render/__pycache__/bkgd_renderer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/bkgd_renderer.cpython-38.pyc
--------------------------------------------------------------------------------
/render/__pycache__/layered_neural_renderer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/layered_neural_renderer.cpython-38.pyc
--------------------------------------------------------------------------------
/render/__pycache__/neural_renderer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/neural_renderer.cpython-38.pyc
--------------------------------------------------------------------------------
/render/__pycache__/render_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarlingHang/st-nerf/e0c1c32b09d90101218410443193ddabc1f66d2f/render/__pycache__/render_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/render/bkgd_renderer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple, Optional
3 | from collections import namedtuple
4 |
5 | import numpy as np
6 | from PIL import Image
7 | import pyrender as pr
8 |
9 | # skew is generally not supported
10 | Pinhole = namedtuple('Pinhole', ['fx', 'fy', 'cx', 'cy'])
11 |
12 |
13 | class MeshRender(ABC):
14 | @abstractmethod
15 | def load_mesh(self, fn: str) -> None:
16 | pass
17 |
18 | @abstractmethod
19 | def render(self, pinhole: Optional[Pinhole] = None,
20 | pose: Optional[np.ndarray] = None) -> Image.Image:
21 | pass
22 |
23 |
24 | class PrRender(MeshRender):
25 | _gl_cv = np.array([
26 | [1, 0, 0, 0],
27 | [0, -1, 0, 0],
28 | [0, 0, -1, 0],
29 | [0, 0, 0, 1],
30 | ])
31 |
32 | def __init__(self, resolution: Tuple[int, int]):
33 | self._scene = pr.Scene(ambient_light=np.ones(3))
34 | self._mesh = None
35 | self._cam = None
36 |
37 | self._width, self._height = resolution
38 | self._render = pr.OffscreenRenderer(self._width, self._height)
39 |
40 | def load_mesh(self, fn: str) -> None:
41 | tm = pr.mesh.trimesh.load_mesh(fn)
42 | mesh = pr.Mesh.from_trimesh(tm)
43 | if mesh:
44 | if self._mesh is not None:
45 | self._scene.remove_node(self._mesh)
46 | self._scene.add(mesh)
47 | self._mesh = mesh
48 |
49 | def render(self, pinhole: Optional[Pinhole] = None,
50 | pose: Optional[np.ndarray] = None,
51 | znear=pr.constants.DEFAULT_Z_NEAR,
52 | zfar=pr.constants.DEFAULT_Z_FAR) -> Image.Image:
53 | if pinhole is not None:
54 | # update intrinsics
55 | if self._cam is None:
56 | cam = pr.IntrinsicsCamera(*pinhole, znear, zfar)
57 | self._cam = self._scene.add(cam)
58 | else:
59 | cam = self._cam.camera
60 | cam.fx, cam.fy, cam.cx, cam.cy = pinhole
61 | cam.znear, cam.zfar = znear, zfar
62 | if self._cam is None:
63 | raise ValueError('Empty intrinsics while previous camera not set')
64 |
65 | # camera is not None
66 | if pose is not None:
67 | # update camera pose
68 | # gl_to_world = cv_to_world @ gl_to_cv
69 | self._scene.set_pose(self._cam, pose @ self._gl_cv)
70 |
71 | color, _ = self._render.render(self._scene)
72 | return color
--------------------------------------------------------------------------------
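
A minimal sketch of the offscreen background-mesh renderer above, assuming `pyrender` and a working OpenGL context are available and that importing the `render` package succeeds (it also pulls in the neural renderers and their dependencies); the mesh filename is a hypothetical placeholder.

```python
import numpy as np
from render.bkgd_renderer import PrRender, Pinhole

renderer = PrRender(resolution=(1280, 720))
renderer.load_mesh('background.obj')       # hypothetical mesh file

pinhole = Pinhole(fx=1000.0, fy=1000.0, cx=640.0, cy=360.0)
pose = np.eye(4)                           # camera-to-world pose; the GL flip is applied inside render()
color = renderer.render(pinhole, pose)     # (H, W, 3) rendered image
```
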
/render/layered_neural_renderer.py:
--------------------------------------------------------------------------------
1 | from config import cfg
2 | import imageio
3 | import os
4 | import numpy as np
5 | import torch
6 | from data import make_ray_data_loader_render, get_iteration_path
7 | from modeling import build_layered_model
8 | from utils import layered_batchify_ray, add_two_dim_dict
9 | from .render_functions import *
10 | from robopy import *
11 |
12 | from scipy.spatial.transform import Rotation as R
13 | from scipy.spatial.transform import Slerp
14 | from scipy.interpolate import splprep, splev
15 | import pdb
16 |
17 | class LayeredNeuralRenderer:
18 |
19 | def __init__(self,cfg,scale=None,shift=None,rotation = None,s_shift = None,s_scale=None,s_alpha=None):
20 | self.alpha = None
21 | self.cfg = cfg
22 | self.scale = scale
23 | self.shift = shift
24 | self.rotation = rotation
25 | self.s_shift = s_shift
26 | self.s_scale = s_scale
27 | self.s_alpha = s_alpha
28 |
29 |
30 | if s_shift != None:
31 | self.shift = self.s_shift[0]
32 |
33 | if s_scale != None:
34 | self.scale = self.s_scale[0]
35 |
36 | if s_alpha != None:
37 | self.alpha = self.s_alpha[0]
38 |
39 |
40 | # Directories where all rendered images and videos are saved
41 | self.dataset_dir = self.cfg.OUTPUT_DIR
42 | self.output_dir = os.path.join(self.cfg.OUTPUT_DIR,'rendered')
43 |
44 | self.dataset, self.model = self.load_dataset_model()
45 |
46 | # {0,1} dictionary, 1 means display, 0 means hide, key is [layer_id]
47 | self.display_layers = {}
48 |
49 | # (0,1,2,...,LAYER_NUM)
50 | for layer_id in range(cfg.DATASETS.LAYER_NUM+1):
51 | self.display_layers[layer_id] = 1
52 |
53 | # Intrinsics for all rendered images, updated when the dataset is first loaded
54 | self.gt_poses = self.dataset.poses
55 | self.gt_Ks = self.dataset.Ks
56 |
57 | # self.near = 0.1
58 | self.far = 20.0
59 |
60 | # Each layer will have a min-max frame range
61 | self.min_frame = [1+cfg.DATASETS.FRAME_OFFSET for i in range(cfg.DATASETS.LAYER_NUM+1)]
62 | self.max_frame = [cfg.DATASETS.FRAME_NUM+cfg.DATASETS.FRAME_OFFSET for i in range(cfg.DATASETS.LAYER_NUM+1)]
63 |
64 | self.images = []
65 | self.depths = []
66 |
67 | # Total number of images rendered and saved by the renderer
68 | self.image_num = 0
69 | # Total frame and layer numbers; use with care, since not all of them may be loaded into the model
70 | self.frame_num = cfg.DATASETS.FRAME_NUM
71 | self.layer_num = cfg.DATASETS.LAYER_NUM
72 | self.camera_num = self.dataset.camera_num
73 | self.min_camera_id = 0
74 | self.max_camera_id = self.camera_num-1
75 |
76 | self.fps = 25
77 | self.height = cfg.INPUT.SIZE_TEST[1]
78 | self.width = cfg.INPUT.SIZE_TEST[0]
79 |
80 | # Counter for saving multiple videos
81 | self.save_count = 0
82 |
83 | # All rendered poses and intrinsics aligned with images
84 | self.poses = []
85 | self.Ks = []
86 | # Corresponding to each pose, we will have multiple (layer_id, frame_id) pairs to identify the visible layers and frames.
87 | # Example [[(0,1),(1,1)],[(0,1),(1,2)],...] represent [(layer_0, frame_1), (layer_1,frame_1)] for poses[0] and so on
88 | self.layer_frame_pairs = []
89 |
90 | # Trace one layer (look at its center); -1 means no layer is traced
91 | self.trace_layer = -1
92 |
93 | # auto saving dir
94 | self.dir_name = ''
95 |
96 | def load_dataset_model(self):
97 | para_file = get_iteration_path(self.dataset_dir)
98 | print(para_file)
99 |
100 | if para_file is None:
101 | assert False, 'training model does not exist'
102 |
103 | _, dataset = make_ray_data_loader_render(cfg)
104 |
105 | model = build_layered_model(cfg, dataset.camera_num, scale = self.scale, shift=self.shift)
106 |
107 | model.set_bkgd_bbox(dataset.datasets[0][0].bbox)
108 | model.set_bboxes(dataset.bboxes)
109 | model_dict = model.state_dict()
110 | dict_0 = torch.load(os.path.join(para_file),map_location='cuda')
111 |
112 | model_dict = dict_0['model']
113 | model_new_dict = model.state_dict()
114 | offset = {k: v for k, v in model_new_dict.items() if k not in model_dict}
115 | for k,v in offset.items():
116 | model_dict[k] = v
117 | model.load_state_dict(model_dict)
118 |
119 | model.cuda()
120 |
121 | return dataset, model
122 |
123 |
124 | def check_label(self):
125 | output = os.path.join(self.output_dir,'masked_images')
126 | if not os.path.exists(output):
127 | os.makedirs(output)
128 | for i in range(self.frame_num):
129 | output_f = os.path.join(output, 'frame%d' % i)
130 | if not os.path.exists(output_f):
131 | os.makedirs(output_f)
132 | for j in range(self.camera_num):
133 | image, label = self.dataset.get_image_label(j, i)
134 | image = image.permute(1,2,0)
135 | image[label[0,...]==0] = 0
136 | imageio.imwrite(os.path.join(output_f,'%d.jpg'% j), image)
137 |
138 | return
139 |
140 |
141 |
142 |
143 | # Set the camera path; before using it, set the correct frame range for each layer
144 | def set_path_lookat(self, start,end,step_num,center,up):
145 |
146 | # Generate poses
147 | if self.trace_layer == -1:
148 | poses = generate_poses_by_path(start,end,step_num,center,up)
149 | else:
150 | centers = []
151 | temp = center
152 | for idx in range(step_num):
153 | frame_id = int((self.max_frame-self.min_frame)/step_num*(idx+1)) + self.min_frame
154 | frame_dic = self.datasets[frame_id]
155 | for layer_id in frame_dic:
156 | if layer_id == self.trace_layer:
157 | temp = frame_dic[layer_id].center
158 | centers.append(temp)
159 | poses = generate_poses_by_path_center(start,end,step_num,centers,up)
160 |
161 | self.poses = self.poses + poses
162 |
163 | for idx in range(len(poses)+1):
164 | layer_frame_pair = []
165 | for layer_id in range(self.layer_num+1):
166 | if self.is_shown_layer(layer_id):
167 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id]
168 | layer_frame_pair.append((layer_id,frame_id))
169 | self.layer_frame_pairs.append(layer_frame_pair)
170 |
171 | def set_path_gt_poses(self):
172 | poses = []
173 | for i in range(self.dataset.poses.shape[0]):
174 | poses.append(self.dataset.poses[i])
175 |
176 | self.poses = self.poses + poses
177 | self.Ks = self.Ks + self.gt_Ks
178 |
179 | for idx in range(len(poses)+1):
180 | layer_frame_pair = []
181 | for layer_id in range(self.layer_num+1):
182 | if self.is_shown_layer(layer_id):
183 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id]
184 | layer_frame_pair.append((layer_id,frame_id))
185 | self.layer_frame_pairs.append(layer_frame_pair)
186 |
187 |
188 | def set_path_fixed_gt_poses(self,id,num=None):
189 | poses = []
190 | Ks = []
191 | if self.s_shift != None:
192 | s_shift_start = np.array(self.s_shift[0])
193 | s_shift_end = np.array(self.s_shift[1])
194 | s_shift_step = (s_shift_end-s_shift_start)/(num-1)
195 | self.s_shift_frame = []
196 |
197 | if self.s_scale != None:
198 | s_scale_start = np.array(self.s_scale[0])
199 | s_scale_end = np.array(self.s_scale[1])
200 | s_scale_step = (s_scale_end-s_scale_start)/(num-1)
201 | self.s_scale_frame = []
202 |
203 |
204 | for i in range(num):
205 | poses.append(self.dataset.poses[id])
206 | K = self.dataset.Ks[id]
207 |
208 | # EXPEDIENCY
209 | if K is None:
210 | K = self.dataset.Ks[id+1]
211 | Ks.append(K)
212 | if self.s_shift != None:
213 | self.s_shift_frame.append((s_shift_start+i*s_shift_step).tolist())
214 |
215 | if self.s_scale != None:
216 | self.s_scale_frame.append((s_scale_start+i*s_scale_step).tolist())
217 |
218 | self.poses = self.poses + poses
219 | self.Ks = self.Ks + Ks
220 |
221 | for idx in range(len(poses)+1):
222 | layer_frame_pair = []
223 | for layer_id in range(self.layer_num+1):
224 | if self.is_shown_layer(layer_id):
225 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id]
226 | layer_frame_pair.append((layer_id,frame_id))
227 | self.layer_frame_pairs.append(layer_frame_pair)
228 |
229 |
230 | def set_smooth_path_poses(self,step_num, around=False, smooth_time = False):
231 |
232 | if self.s_shift != None:
233 | s_shift_start = np.array(self.s_shift[0])
234 | s_shift_end = np.array(self.s_shift[1])
235 | s_shift_step = (s_shift_end-s_shift_start)/(step_num-1)
236 | self.s_shift_frame = []
237 |
238 | if self.s_alpha != None:
239 | s_alpha_start = self.s_alpha[0]
240 | s_alpha_end = self.s_alpha[1]
241 | s_alpha_step = (s_alpha_end-s_alpha_start)/(step_num-1)
242 | self.s_alpha_frame = []
243 |
244 | poses = []
245 | Rs = self.gt_poses[self.min_camera_id:self.max_camera_id+1,:3,:3].cpu().numpy()
246 | Ts = self.gt_poses[self.min_camera_id:self.max_camera_id+1,:3,3].cpu().numpy()
247 | #print(Ts)
248 |
249 | key_frames = [i for i in range(self.min_camera_id,self.max_camera_id+1)]
250 | # Only use the first and the last
251 | if not around:
252 | temp = [Rs[0],Rs[-1]]
253 | Rs = np.array(temp)
254 | key_frames = [self.min_camera_id,self.max_camera_id]
255 |
256 | # key_frames = [i for i in range(self.min_camera_id,self.max_camera_id)]
257 | # key_frames = [self.min_camera_id,self.max_camera_id-1]
258 |
259 | interp_frames = [(i * (self.max_camera_id-self.min_camera_id) / (step_num-1) + self.min_camera_id) for i in range(step_num)]
260 | #print(interp_frames)
261 | # print(interp_frames)
262 | Rs = R.from_matrix(Rs)
263 | slerp = Slerp(key_frames, Rs)
264 | interp_Rs = slerp(interp_frames).as_matrix()
265 | #print(interp_Rs)
266 |
267 | x = Ts[:,0]
268 | y = Ts[:,1]
269 | z = Ts[:,2]
270 |
271 | tck, u0 = splprep([x,y,z])
272 | u_new = [i / (step_num-1) for i in range(step_num)]
273 | new_points = splev(u_new,tck)
274 |
275 | new_points = np.stack(new_points, axis=1)
276 |
277 | K0 = self.gt_Ks[self.min_camera_id]
278 | K1 = self.gt_Ks[self.max_camera_id]
279 |
280 | if self.s_scale != None:
281 | s_scale_start = np.array(self.s_scale[0])
282 | s_scale_end = np.array(self.s_scale[1])
283 | s_scale_step = (s_scale_end-s_scale_start)/(step_num-1)
284 | self.s_scale_frame = []
285 | for i in range(step_num):
286 | pose = np.zeros((4,4))
287 | pose[:3,:3] = interp_Rs[i]
288 | pose[:3,3] = new_points[i]
289 | pose[3,3] = 1
290 | poses.append(pose)
291 |
292 | K = (K1 - K0) * i / (step_num - 1) + K0
293 |
294 | # print(K)
295 |
296 | self.Ks.append(K)
297 | if self.s_scale != None:
298 | self.s_scale_frame.append((s_scale_start+i*s_scale_step).tolist())
299 |
300 | if self.s_shift != None:
301 | self.s_shift_frame.append((s_shift_start+i*s_shift_step).tolist())
302 |
303 | if self.s_alpha != None:
304 | self.s_alpha_frame.append((s_alpha_start+i*s_alpha_step))
305 |
306 | self.poses = self.poses + poses
307 |
308 | # Generate corresponding layer id and frame id for poses
309 | for idx in range(len(poses)+1):
310 | layer_frame_pair = []
311 | for layer_id in range(self.layer_num+1):
312 | if self.is_shown_layer(layer_id):
313 |
314 | if not smooth_time:
315 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id]
316 | else:
317 | frame_id = (self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx) + self.min_frame[layer_id]
318 | layer_frame_pair.append((layer_id,frame_id))
319 | self.layer_frame_pairs.append(layer_frame_pair)
320 |
321 | def load_path_poses(self,poses):
322 | self.poses = poses
323 | step_num = len(poses)
324 | K0 = self.gt_Ks[self.min_camera_id]
325 | K1 = self.gt_Ks[self.max_camera_id-1]
326 | for i in range(step_num):
327 | K = (K1 - K0) * i / (step_num - 1) + K0
328 | self.Ks.append(K)
329 |
330 | for idx in range(len(poses)+1):
331 | layer_frame_pair = []
332 | for layer_id in range(self.layer_num+1):
333 | if self.is_shown_layer(layer_id):
334 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(poses)*(idx)) + self.min_frame[layer_id]
335 | layer_frame_pair.append((layer_id,frame_id))
336 | self.layer_frame_pairs.append(layer_frame_pair)
337 |
338 |
339 | def load_cams_from_path(self, path):
340 |
341 | campose = np.load(os.path.join(path, 'RT_c2w.npy'))
342 | Ts = np.zeros((campose.shape[0],4,4))
343 | Ts[:,:3,:] = campose.reshape(-1, 3, 4)
344 | Ts[:,3,3] = 1.
345 |
346 | #scale
347 | Ts[:,:3,3] = self.cfg.DATASETS.SCALE * Ts[:,:3,3]
348 |
349 | Ks = np.load(os.path.join(path, 'K.npy'))
350 | Ks = Ks.reshape(-1, 3, 3)
351 |
352 | self.poses = Ts
353 | self.Ks = torch.from_numpy(Ks.astype(np.float32))
354 |
355 | for idx in range(len(self.poses)+1):
356 | layer_frame_pair = []
357 | for layer_id in range(self.layer_num+1):
358 | if self.is_shown_layer(layer_id):
359 | frame_id = int((self.max_frame[layer_id]-self.min_frame[layer_id])/len(self.poses)*(idx)) + self.min_frame[layer_id]
360 | layer_frame_pair.append((layer_id,frame_id))
361 | self.layer_frame_pairs.append(layer_frame_pair)
362 |
363 |
364 | def render_pose(self, pose, K, layer_frame_pair, density_threshold=0,bkgd_density_threshold=0):
365 | print(K)
366 | print(pose)
367 | #print(K)
368 | H = self.dataset.height
369 | W = self.dataset.width
370 | rays, labels, bbox, near_far = self.dataset.get_rays_by_pose_and_K(pose, K, layer_frame_pair)
371 |
372 | rays = rays.cuda()
373 | bbox = bbox.cuda()
374 | labels = labels.cuda()
375 | near_far = near_far.cuda()
376 |
377 | with torch.no_grad():
378 | stage2, stage1, stage2_layer, stage1_layer, _ = layered_batchify_ray(self.model, rays, labels, bbox, near_far=near_far, density_threshold=density_threshold,bkgd_density_threshold=bkgd_density_threshold)
379 |
380 | color = stage2[0].reshape(H,W,3)
381 | depth = stage2[1].reshape(H,W,1)
382 | depth[depth < 0] = 0
383 | depth = depth / self.far
384 | color_layer = [i[0].reshape(H,W,3) for i in stage2_layer]
385 | depth_layer = []
386 | for temp in stage2_layer:
387 | depth_1 = temp[1].reshape(H,W,1)
388 | depth_1[depth < 0] = 0
389 | depth_1 = depth_1 / self.far
390 | depth_layer.append(depth_1)
391 |
392 | return color,depth,color_layer,depth_layer
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 | def render_path(self,inverse_y_axis=False,density_threshold=0,bkgd_density_threshold=0, auto_save=True):
402 |
403 | if self.dir_name == '':
404 | save_dir = os.path.join(self.output_dir,'video_%d' % self.save_count,'mixed')
405 | else:
406 | save_dir = os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,'mixed')
407 | if not os.path.exists(save_dir):
408 | os.makedirs(save_dir)
409 | os.mkdir(os.path.join(save_dir,'color'))
410 | os.mkdir(os.path.join(save_dir,'depth'))
411 |
412 | file_poses = open(os.path.join(save_dir,'poses'),mode='w')
413 | for pose in self.poses:
414 | file_poses.write(str(pose)+"\n")
415 | file_poses.close()
416 |
417 | file_Ks = open(os.path.join(save_dir,'Ks'),mode='w')
418 | for K in self.Ks:
419 | file_Ks.write(str(K)+"\n")
420 | file_Ks.close()
421 |
422 |
423 | self.images = []
424 | self.depths = []
425 | self.images_layer = [[] for i in range(self.layer_num+1)]
426 | self.depths_layer = [[] for i in range(self.layer_num+1)]
427 |
428 | self.image_num = 0
429 |
430 | for idx in range(len(self.poses)):
431 | print('Rendering image %d' % idx)
432 | K = self.Ks[idx]
433 | pose = self.poses[idx]
434 | layer_frame_pair = self.layer_frame_pairs[idx]
435 | if self.s_shift != None:
436 | self.model.shift = self.s_shift_frame[idx]
437 | if self.s_scale != None:
438 | self.model.scale = self.s_scale_frame[idx]
439 | if self.s_alpha != None:
440 | self.model.alpha = self.s_alpha_frame[idx]
441 |
442 | color,depth,color_layer,depth_layer = self.render_pose(pose, K, layer_frame_pair, density_threshold,bkgd_density_threshold)
443 |
444 |
445 | if inverse_y_axis:
446 | color = torch.flip(color,[0])
447 | depth = torch.flip(depth,[0])
448 | color_layer = [torch.flip(i,[0]) for i in color_layer]
449 | depth_layer = [torch.flip(i,[0]) for i in depth_layer]
450 |
451 | color = color.cpu()
452 | depth = depth.cpu()
453 | color_layer = [i.cpu() for i in color_layer]
454 | depth_layer = [i.cpu() for i in depth_layer]
455 |
456 | if auto_save:
457 | if self.dir_name == '':
458 | save_dir = os.path.join(self.output_dir,'video_%d' % self.save_count,'mixed')
459 | else:
460 | save_dir = os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,'mixed')
461 | if not os.path.exists(save_dir):
462 | os.makedirs(save_dir)
463 | os.mkdir(os.path.join(save_dir,'color'))
464 | os.mkdir(os.path.join(save_dir,'depth'))
465 |
466 | #print(rgb.shape)
467 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color)
468 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth)
469 | self.images.append(color)
470 | self.depths.append(depth)
471 | for layer_id in range(self.layer_num+1):
472 | if self.is_shown_layer(layer_id):
473 | if self.dir_name == '':
474 | save_dir = os.path.join(self.output_dir,'video_%d' % self.save_count,str(layer_id))
475 | else:
476 | save_dir = os.path.join(self.output_dir,self.dir_name,'video_%d' % self.save_count,str(layer_id))
477 | if not os.path.exists(save_dir):
478 | os.makedirs(save_dir)
479 | os.mkdir(os.path.join(save_dir,'color'))
480 | os.mkdir(os.path.join(save_dir,'depth'))
481 |
482 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color_layer[layer_id])
483 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth_layer[layer_id])
484 | self.images_layer[layer_id].append(color)
485 | self.depths_layer[layer_id].append(depth)
486 |
487 |
488 | self.image_num += 1
489 |
490 |
491 |
492 |
493 |
494 |
495 | def retime_by_key_frames(self, layer_id, key_frames_layer, key_frames):
496 |
497 | assert (len(key_frames_layer) == len(key_frames))
498 |
499 | for i in range(len(self.layer_frame_pairs)):
500 | for j in range(len(self.layer_frame_pairs[i])):
501 | layer, frame = self.layer_frame_pairs[i][j]
502 | #Retiming the corresponding layer
503 | if layer == layer_id:
504 | idx_start = -1
505 | idx_end = -1
506 | weight = 0
507 | for idx in range(len(key_frames)):
508 | if frame <= key_frames[idx]:
509 | idx_end = idx
510 | idx_start = idx_end-1
511 | end = key_frames[idx]
512 | start = 0
513 | if idx == 0:
514 | start = self.min_frame[layer]
515 | else:
516 | start = key_frames[idx-1]
517 | weight = (frame-start) / (end-start)
518 | # print('frame %d, start %d, end %d' % (frame,start,end))
519 | # print('idx_end %d, idx_start %d' % (idx_end, idx_start))
520 | break
521 |
522 | new_end = 0
523 | new_start = 0
524 | # print('123')
525 | # print('idx_end %d, idx_start %d' % (idx_end, idx_start))
526 | if idx_start == -1 and idx_end == 0:
527 | weight = (frame-self.min_frame[layer]) / (key_frames[0] - self.min_frame[layer])
528 | new_start = self.min_frame[layer]
529 | new_end = key_frames_layer[0]
530 | elif idx_start >= -1 and idx_end != -1:
531 | new_start = key_frames_layer[idx_start]
532 | new_end = key_frames_layer[idx_start+1]
533 | elif idx_start == -1 and idx_end == -1:
534 | weight = (frame-key_frames[-1]) / (self.max_frame[layer] - key_frames[-1])
535 | new_start = key_frames_layer[-1]
536 | new_end = self.max_frame[layer]
537 | else:
538 | print('Undefined branch', 'start idx is %d, end idx is %d' % (idx_start,idx_end))
539 | exit(-1)
540 |
541 | new_frame = round(weight * (new_end - new_start) + new_start)
542 | # print('new end is %d, new start is %d' % (new_end,new_start))
543 | # print('layer %d: old frame is %d, new is %d, weight %f' % (layer,frame,new_frame,weight))
544 | self.layer_frame_pairs[i][j] = (layer, new_frame)
545 |
546 | # exit(0)
547 |
548 |
549 |
550 | def render_path_walking(self,inverse_y_axis=False,density_threshold=0,bkgd_density_threshold=0, auto_save=True):
551 |
552 | self.images = []
553 | self.depths = []
554 | self.images_layer = [[] for i in range(self.layer_num+1)]
555 | self.depths_layer = [[] for i in range(self.layer_num+1)]
556 |
557 | self.image_num = 0
558 |
559 | for idx in range(len(self.poses)):
560 | print('Rendering image %d' % idx)
561 | K = self.Ks[idx]
562 | pose = self.poses[idx]
563 | layer_frame_pair = self.layer_frame_pairs[idx]
564 |
565 | color,depth,color_layer,depth_layer = self.render_pose(pose, K, layer_frame_pair, density_threshold,bkgd_density_threshold)
566 |
567 | if inverse_y_axis:
568 | color = torch.flip(color,[0])
569 | depth = torch.flip(depth,[0])
570 | color_layer = [torch.flip(i,[0]) for i in color_layer]
571 | depth_layer = [torch.flip(i,[0]) for i in depth_layer]
572 |
573 | color = color.cpu()
574 | depth = depth.cpu()
575 | color_layer = [i.cpu() for i in color_layer]
576 | depth_layer = [i.cpu() for i in depth_layer]
577 |
578 | if auto_save:
579 | save_dir = os.path.join(self.output_dir,'mixed')
580 | if not os.path.exists(save_dir):
581 | os.makedirs(save_dir)
582 | os.mkdir(os.path.join(save_dir,'color'))
583 | os.mkdir(os.path.join(save_dir,'depth'))
584 |
585 | #print(rgb.shape)
586 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color)
587 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth)
588 | self.images.append(color)
589 | self.depths.append(depth)
590 | for layer_id in range(self.layer_num+1):
591 | save_dir = os.path.join(self.output_dir,str(layer_id))
592 | if not os.path.exists(save_dir):
593 | os.makedirs(save_dir)
594 | os.mkdir(os.path.join(save_dir,'color'))
595 | os.mkdir(os.path.join(save_dir,'depth'))
596 |
597 | imageio.imwrite(os.path.join(save_dir,'color','%d.jpg'% self.image_num), color_layer[layer_id])
598 | imageio.imwrite(os.path.join(save_dir,'depth','%d.png'% self.image_num), depth_layer[layer_id])
599 | self.images_layer[layer_id].append(color)
600 | self.depths_layer[layer_id].append(depth)
601 |
602 | color_hide = color_layer[0].clone()
603 | index = depth_layer[2]=start_epoches:
66 | return (1.0-scale)*math.exp(-(epoch0-start_epoches)/(end_epoches-start_epoches)) + scale
67 |
68 |
69 | return 1.0
70 | return torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=[scheduler]*len(optimizer.param_groups))
--------------------------------------------------------------------------------
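
A minimal end-to-end sketch of driving `LayeredNeuralRenderer` along a smooth interpolated camera path, assuming the walking dataset and pretrained checkpoint are in place and the yacs config `configs/config_walking.yml` is used:

```python
from config import cfg
from render import LayeredNeuralRenderer

cfg.merge_from_file('configs/config_walking.yml')  # yacs config, as used throughout the repo

renderer = LayeredNeuralRenderer(cfg)
renderer.set_smooth_path_poses(step_num=60, around=False)  # slerp rotations + spline translations between views
renderer.render_path(inverse_y_axis=False)                 # writes color/ and depth/ frames under OUTPUT_DIR/rendered
```
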
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .dimension_kernel import Trigonometric_kernel
8 | from .ray_sampling import ray_sampling,ray_sampling_label_bbox,ray_sampling_label_label
9 | from .batchify_rays import batchify_ray, layered_batchify_ray,layered_batchify_ray_big
10 | from .vis_density import vis_density
11 | from .sample_pdf import sample_pdf
12 | from .high_dim_dics import add_two_dim_dict, add_three_dim_dict
13 | from .render_helpers import *
14 |
--------------------------------------------------------------------------------
/utils/batchify_rays.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def batchify_ray(model, rays, bboxes, chuncks = 1024*7, near_far=None, near_far_points = [], density_threshold=0,bkgd_density_threshold=0):
5 | N = rays.size(0)
6 | if N 0:
46 | ray_masks = torch.cat(ray_masks, dim=0)
47 |
48 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]), ray_masks
49 |
50 |
51 | def layered_batchify_ray(model, rays, labels, bboxes, chuncks = 512*7, near_far=None, near_far_points = [], density_threshold=0,bkgd_density_threshold=0):
52 | N = rays.size(0)
53 | if N 0:
130 | for i in range(len(stage2_layer)):
131 | ray_masks[i] = torch.cat(ray_masks[i], dim=0)
132 |
133 | stage1_layer_final = []
134 | stage2_layer_final = []
135 |
136 | for i in range(len(stage2_layer)):
137 | stage1_layer_final.append((colors[2+i*2], depths[2+i*2], acc_maps[2+i*2]))
138 | stage2_layer_final.append((colors[3+i*2], depths[3+i*2], acc_maps[3+i*2]))
139 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]),\
140 | stage2_layer_final, stage1_layer_final, ray_masks
141 |
142 |
143 |
144 | def layered_batchify_ray_big(layer_big,scale,model, rays, labels, bboxes, chuncks = 512*7, near_far=None, near_far_points = [], density_threshold=0):
145 | N = rays.size(0)
146 | if N 0:
223 | for i in range(len(stage2_layer)):
224 | ray_masks[i] = torch.cat(ray_masks[i], dim=0)
225 |
226 | stage1_layer_final = []
227 | stage2_layer_final = []
228 |
229 | for i in range(len(stage2_layer)):
230 | stage1_layer_final.append((colors[2+i*2], depths[2+i*2], acc_maps[2+i*2]))
231 | stage2_layer_final.append((colors[3+i*2], depths[3+i*2], acc_maps[3+i*2]))
232 | return (colors[1], depths[1], acc_maps[1]), (colors[0], depths[0], acc_maps[0]),\
233 | stage2_layer_final, stage1_layer_final, ray_masks
--------------------------------------------------------------------------------
/utils/dimension_kernel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class Embedder:
4 | def __init__(self, **kwargs):
5 | self.kwargs = kwargs
6 | self.create_embedding_fn()
7 |
8 | def create_embedding_fn(self):
9 | embed_fns = []
10 | d = self.kwargs['input_dims']
11 | out_dim = 0
12 | if self.kwargs['include_input']:
13 | embed_fns.append(lambda x : x)
14 | out_dim += d
15 |
16 | max_freq = self.kwargs['max_freq_log2']
17 | N_freqs = self.kwargs['num_freqs']
18 |
19 | if self.kwargs['log_sampling']:
20 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
21 | else:
22 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
23 |
24 | for freq in freq_bands:
25 | for p_fn in self.kwargs['periodic_fns']:
26 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
27 | out_dim += d
28 |
29 | self.embed_fns = embed_fns
30 | self.out_dim = out_dim
31 |
32 | def embed(self, inputs):
33 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
34 |
35 |
36 | def get_embedder(multires, i=0,include_input = True,input_dim=3):
37 | if i == -1:
38 | return nn.Identity(), 3
39 |
40 | embed_kwargs = {
41 | 'include_input' : include_input,
42 | 'input_dims' : input_dim,
43 | 'max_freq_log2' : multires-1,
44 | 'num_freqs' : multires,
45 | 'log_sampling' : True,
46 | 'periodic_fns' : [torch.sin, torch.cos],
47 | }
48 |
49 | embedder_obj = Embedder(**embed_kwargs)
50 | embed = lambda x, eo=embedder_obj : eo.embed(x)
51 | return embed, embedder_obj.out_dim
52 |
53 | # Positional encoding
54 | class Trigonometric_kernel:
55 | def __init__(self, L = 10, input_dim = 3, include_input=True):
56 |
57 | self.L = L
58 |
59 | self.embed_fn, self.out_ch= get_embedder(L,include_input = include_input, input_dim=input_dim)
60 |
61 | '''
62 | INPUT
63 | x: input vectors (N,C)
64 |
65 | OUTPUT
66 |
67 | pos_kernel: (N, calc_dim(C) )
68 | '''
69 | def __call__(self, x):
70 | return self.embed_fn(x)
71 |
72 | def calc_dim(self, dims=0):
73 | return self.out_ch
--------------------------------------------------------------------------------
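
A minimal sketch of the positional-encoding kernel defined above; with `L=10`, `include_input=True` and 3-D input the output dimension is 3 + 3·2·10 = 63.

```python
import torch
from utils.dimension_kernel import Trigonometric_kernel

kernel = Trigonometric_kernel(L=10, input_dim=3, include_input=True)

x = torch.rand(4096, 3)
y = kernel(x)                        # sin/cos features at 10 log-spaced frequencies, plus the raw input
print(kernel.calc_dim(3), y.shape)   # 63, torch.Size([4096, 63])
```
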
/utils/high_dim_dics.py:
--------------------------------------------------------------------------------
1 |
2 | def add_two_dim_dict(adic, key_a, key_b, val):
3 | if key_a in adic:
4 | adic[key_a].update({key_b: val})
5 | else:
6 | adic.update({key_a:{key_b: val}})
7 |
8 | def add_three_dim_dict(adic, key_a, key_b, key_c, val):
9 | if key_a in adic:
10 | if key_b in adic[key_a]:
11 | adic[key_a][key_b].update({key_c: val})
12 | else:
13 | adic[key_a].update({key_b:{key_c: val}})
14 | else:
15 | adic.update({key_a: {key_b: {key_c: val}}})
--------------------------------------------------------------------------------
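
Usage of the nested-dictionary helpers above, e.g. for indexing rendered outputs by layer and frame:

```python
from utils.high_dim_dics import add_two_dim_dict, add_three_dim_dict

colors, depths = {}, {}
add_two_dim_dict(colors, 1, 3, 'color.jpg')             # colors == {1: {3: 'color.jpg'}}
add_three_dim_dict(depths, 1, 3, 'fine', 'depth.png')   # depths == {1: {3: {'fine': 'depth.png'}}}
```
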
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 |
12 | def setup_logger(name, save_dir, distributed_rank):
13 | logger = logging.getLogger(name)
14 | logger.setLevel(logging.DEBUG)
15 | # don't log results for the non-master process
16 | if distributed_rank > 0:
17 | return logger
18 | ch = logging.StreamHandler(stream=sys.stdout)
19 | ch.setLevel(logging.DEBUG)
20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 | ch.setFormatter(formatter)
22 | logger.addHandler(ch)
23 |
24 | if save_dir:
25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim as dssim
3 |
4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
5 | value=(image_pred-image_gt)**2
6 | if valid_mask is not None:
7 | value = value[valid_mask]
8 | if reduction == 'mean':
9 | return torch.mean(value)
10 | return value
11 |
12 | def mae(image_pred, image_gt):
13 | value=torch.abs(image_pred-image_gt)
14 | return torch.mean(value)
15 |
16 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'):
17 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction))
18 |
19 | def ssim(image_pred, image_gt, reduction='mean'):
20 | """
21 | image_pred and image_gt: (3, H, W)
22 | """
23 | dssim_ = dssim(image_pred.unsqueeze(0), image_gt.unsqueeze(0), 3, reduction) # dissimilarity in [0, 1]
24 | return 1-2*dssim_ # in [-1, 1]
--------------------------------------------------------------------------------
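
A quick sketch of the metrics above on random tensors (PSNR is derived from MSE, so it is finite here but low); shapes follow the docstrings.

```python
import torch
from utils.metrics import mse, psnr

pred = torch.rand(3, 256, 256)
gt   = torch.rand(3, 256, 256)
print(mse(pred, gt).item(), psnr(pred, gt).item())
```
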
/utils/ray_sampling.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 |
5 | '''
6 | Sample rays from views (and images) with/without masks
7 |
8 | --------------------------
9 | INPUT Tensors
10 | Ks: intrinsics of cameras (M,3,3)
11 | Ts: extrinsic of cameras (M,4,4)
12 | image_size: the size of image [H,W]
13 | images: (M,C,H,W)
14 | mask_threshold: a float threshold to mask rays
15 | masks:(M,H,W)
16 | -------------------
17 | OUTPUT:
18 | list of rays: (N,6) pos(3) + dirs(3)
19 | RGB: (N,C)
20 | '''
21 |
22 | def ray_sampling(Ks, Ts, image_size, masks=None, mask_threshold = 0.5, images=None, outlier_map=None):
23 | h = image_size[0]
24 | w = image_size[1]
25 | M = Ks.size(0)
26 |
27 |
28 | x = torch.linspace(0,h-1,steps=h,device = Ks.device )
29 | y = torch.linspace(0,w-1,steps=w,device = Ks.device )
30 |
31 | grid_x, grid_y = torch.meshgrid(x,y)
32 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M,1,1,1) #(M,2,H,W)
33 | coordinates = torch.cat([coordinates,torch.ones(coordinates.size(0),1,coordinates.size(2),
34 | coordinates.size(3),device = Ks.device) ],dim=1).permute(0,2,3,1).unsqueeze(-1)
35 |
36 |
37 | inv_Ks = torch.inverse(Ks)
38 |
39 | dirs = torch.matmul(inv_Ks,coordinates) #(M,H,W,3,1)
40 | dirs = dirs/torch.norm(dirs,dim=3,keepdim = True)
41 | dirs = torch.cat([dirs,torch.zeros(dirs.size(0),coordinates.size(1),
42 | coordinates.size(2),1,1,device = Ks.device) ],dim=3) #(M,H,W,4,1)
43 |
44 |
45 | dirs = torch.matmul(Ts,dirs) #(M,H,W,4,1)
46 | dirs = dirs[:,:,:,0:3,0] #(M,H,W,3)
47 |
48 | pos = Ts[:,0:3,3] #(M,3)
49 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1,h,w,1)
50 |
51 | if outlier_map is not None:
52 | ids = outlier_map.reshape([M,h,w,1])
53 | rays = torch.cat([pos,dirs,ids],dim = 3) #(M,H,W,7)
54 | else:
55 | rays = torch.cat([pos,dirs],dim = 3) #(M,H,W,6)
56 |
57 | if images is not None:
58 | rgbs = images.permute(0,2,3,1) #(M,H,W,C)
59 | else:
60 | rgbs = None
61 |
62 | if masks is not None:
63 | rays = rays[masks>mask_threshold,:]
64 | if rgbs is not None:
65 | rgbs = rgbs[masks>mask_threshold,:]
66 |
67 | else:
68 | rays = rays.reshape((-1,rays.size(3)))
69 | if rgbs is not None:
70 | rgbs = rgbs.reshape((-1, rgbs.size(3)))
71 |
72 | return rays,rgbs
73 |
74 | # Sample rays and labels with K,T and bbox
75 | def ray_sampling_label_bbox(image,label,K,T,bbox=None, bboxes=None):
76 |
77 | _,H,W = image.shape
78 |
79 | if bbox != None:
80 | bbox = bbox.reshape(8,3)
81 | bbox = torch.transpose(bbox,0,1) #(3,8)
82 | bbox = torch.cat([bbox,torch.ones(1,bbox.shape[1])],0)
83 | inv_T = torch.inverse(T)
84 |
85 | pts = torch.mm(inv_T,bbox)
86 |
87 | pts = pts[:3,:]
88 | pixels = torch.mm(K,pts)
89 | pixels = pixels / pixels[2,:]
90 | pixels = pixels[:2,:]
91 | temp = torch.zeros_like(pixels)
92 | temp[1,:] = pixels[0,:]
93 | temp[0,:] = pixels[1,:]
94 | pixels = temp
95 |
96 |
97 | min_pixel = torch.min(pixels, dim=1)[0]
98 | max_pixel = torch.max(pixels, dim=1)[0]
99 |
100 | # print(pixels)
101 | # print(min_pixel)
102 | # print(max_pixel)
103 |
104 | min_pixel[min_pixel < 0.0] = 0
105 | if min_pixel[0] >= H-1:
106 | min_pixel[0] = H-1
107 | if min_pixel[1] >= W-1:
108 | min_pixel[1] = W-1
109 |
110 | max_pixel[max_pixel < 0.0] = 0
111 | if max_pixel[0] >= H-1:
112 | max_pixel[0] = H-1
113 | if max_pixel[1] >= W-1:
114 | max_pixel[1] = W-1
115 |
116 | minh = int(min_pixel[0])
117 | minw = int(min_pixel[1])
118 | maxh = int(max_pixel[0])+1
119 | maxw = int(max_pixel[1])+1
120 | else:
121 | minh = 0
122 | minw = 0
123 | maxh = H
124 | maxw = W
125 |
126 | # print(max_pixel,min_pixel)
127 | # print(minh,maxh,minw,maxw)
128 |
129 | if minh == maxh or minw == maxw:
130 | print('Warning: could not find a valid bbox for a point cloud')
131 |
132 | # minh = 0
133 | # minw = 0
134 | # maxh = H
135 | # maxw = W
136 | # image_cutted = image[:,minh:maxh,minw:maxw]
137 | # label_cutted = label[:,minh:maxh,minw:maxw]
138 |
139 | K = K.unsqueeze(0)
140 | T = T.unsqueeze(0)
141 | M = 1
142 |
143 |
144 | x = torch.linspace(0,H-1,steps=H,device = K.device )
145 | y = torch.linspace(0,W-1,steps=W,device = K.device )
146 |
147 | grid_x, grid_y = torch.meshgrid(x,y)
148 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M,1,1,1) #(M,2,H,W)
149 | coordinates = torch.cat([coordinates,torch.ones(coordinates.size(0),1,coordinates.size(2),
150 | coordinates.size(3),device = K.device) ],dim=1).permute(0,2,3,1).unsqueeze(-1)
151 |
152 |
153 | inv_Ks = torch.inverse(K)
154 |
155 | dirs = torch.matmul(inv_Ks,coordinates) #(M,H,W,3,1)
156 | dirs = dirs/torch.norm(dirs,dim=3,keepdim = True)
157 | dirs = torch.cat([dirs,torch.zeros(dirs.size(0),coordinates.size(1),
158 | coordinates.size(2),1,1,device = K.device) ],dim=3) #(M,H,W,4,1)
159 |
160 |
161 | dirs = torch.matmul(T,dirs) #(M,H,W,4,1)
162 | dirs = dirs[:,:,:,0:3,0] #(M,H,W,3)
163 |
164 | pos = T[:,0:3,3] #(M,3)
165 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1,H,W,1)
166 | rays = torch.cat([pos,dirs],dim = 3)
167 |
168 | rays = rays[:,minh:maxh,minw:maxw,:] #(H',W',6)
169 | rays = rays.reshape((-1,rays.size(3)))
170 |
171 | ray_mask = torch.zeros_like(label)
172 | ray_mask[:,minh:maxh,minw:maxw] = 1.0
173 | ray_mask = ray_mask.permute(1,2,0)
174 |
175 | label = label[:,minh:maxh,minw:maxw].permute(1,2,0) #(H',W',1)
176 | image = image[:,minh:maxh,minw:maxw].permute(1,2,0) #(H',W',3)
177 |
178 |
179 | rays = rays.reshape(-1,6)
180 | label = label.reshape(-1,1) #(N,1)
181 | image = image.reshape(-1,3)
182 |
183 | if bboxes is not None:
184 | layered_bboxes = torch.zeros(rays.size(0),8,3)
185 | for i in range(len(bboxes)):
186 | idx = (label == i).squeeze() #(N,)
187 | layered_bboxes[idx] = bboxes[i]
188 |
189 | if bboxes is None:
190 | return rays, label, image, ray_mask
191 | else:
192 | return rays, label, image, ray_mask,layered_bboxes
193 |
194 | def ray_sampling_label_label(image,label,K,T,label0):
195 |
196 | _,H,W = image.shape
197 |
198 | K = K.unsqueeze(0)
199 | T = T.unsqueeze(0)
200 | M = 1
201 |
202 |
203 | x = torch.linspace(0,H-1,steps=H,device = K.device )
204 | y = torch.linspace(0,W-1,steps=W,device = K.device )
205 |
206 | grid_x, grid_y = torch.meshgrid(x,y)
207 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M,1,1,1) #(M,2,H,W)
208 | coordinates = torch.cat([coordinates,torch.ones(coordinates.size(0),1,coordinates.size(2),
209 | coordinates.size(3),device = K.device) ],dim=1).permute(0,2,3,1).unsqueeze(-1)
210 |
211 |
212 | inv_Ks = torch.inverse(K)
213 |
214 | dirs = torch.matmul(inv_Ks,coordinates) #(M,H,W,3,1)
215 | dirs = dirs/torch.norm(dirs,dim=3,keepdim = True)
216 | dirs = torch.cat([dirs,torch.zeros(dirs.size(0),coordinates.size(1),
217 | coordinates.size(2),1,1,device = K.device) ],dim=3) #(M,H,W,4,1)
218 |
219 |
220 | dirs = torch.matmul(T,dirs) #(M,H,W,4,1)
221 | dirs = dirs[:,:,:,0:3,0] #(M,H,W,3)
222 |
223 | pos = T[:,0:3,3] #(M,3)
224 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1,H,W,1)
225 | rays = torch.cat([pos,dirs],dim = 3)
226 |
227 |
228 | ray_mask = torch.zeros_like(label)
229 | idx = (label == label0)
230 | ray_mask[idx] = 1.0
231 | ray_mask = ray_mask.permute(1,2,0)
232 |
233 | rays = rays[idx,:] #(N,6)
234 |
235 | label = label[idx] #(N)
236 | label = label.reshape(-1,1)
237 | image = image[:,idx.squeeze()].permute(1,0) #(N,3)
238 |
239 |
240 | return rays, label, image, ray_mask
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
--------------------------------------------------------------------------------
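
A minimal sketch of `ray_sampling` for a single view (the helper is typically driven one camera at a time); with identity intrinsics and extrinsics it simply casts one ray per pixel.

```python
import torch
from utils.ray_sampling import ray_sampling

H, W = 8, 10
Ks = torch.eye(3).unsqueeze(0)        # (1, 3, 3) toy intrinsics
Ts = torch.eye(4).unsqueeze(0)        # (1, 4, 4) camera-to-world extrinsics
images = torch.rand(1, 3, H, W)

rays, rgbs = ray_sampling(Ks, Ts, (H, W), images=images)
print(rays.shape, rgbs.shape)         # (H*W, 6) pos+dir, (H*W, 3)
```
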
/utils/render_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from math import pi, sin, cos
4 |
5 | def lookat(eye,center,up):
6 | z = eye - center
7 | z /= np.sqrt(z.dot(z))
8 |
9 | y = up
10 | x = np.cross(y,z)
11 | y = np.cross(z,x)
12 |
13 | x /= np.sqrt(x.dot(x))
14 | y /= np.sqrt(y.dot(y))
15 |
16 | T = np.identity(4)
17 | T[0,:3] = x
18 | T[1,:3] = y
19 | T[2,:3] = z
20 | T[0,3] = -x.dot(eye)
21 | T[1,3] = -y.dot(eye)
22 | T[2,3] = -z.dot(eye)
23 | T[3,:] = np.array([0,0,0,1])
24 |
25 | # What we need is camera pose
26 | T = np.linalg.inv(T)
27 | T[:3,1] = -T[:3,1]
28 | T[:3,2] = -T[:3,2]
29 |
30 | return T
31 |
32 | # degree=True means angles are given in degrees, otherwise in radians
33 | def getSphericalPosition(r,theta,phi,degree=True):
34 | if degree:
35 | theta = theta / 180 * pi
36 | phi = phi / 180 * pi
37 | x = r * cos(theta) * sin(phi)
38 | z = r * cos(theta) * cos(phi)
39 | y = r * sin(theta)
40 | return np.array([x,y,z])
41 |
42 | def generate_rays(K, T, bbox, h, w):
43 |
44 | if bbox is not None:
45 | bbox = bbox.reshape(8,3)
46 | bbox = torch.transpose(bbox,0,1) #(3,8)
47 | bbox = torch.cat([bbox,torch.ones(1,bbox.shape[1])],0)
48 | inv_T = torch.inverse(T)
49 |
50 | pts = torch.mm(inv_T,bbox)
51 |
52 | pts = pts[:3,:]
53 | pixels = torch.mm(K,pts)
54 | pixels = pixels / pixels[2,:]
55 | pixels = pixels[:2,:]
56 | temp = torch.zeros_like(pixels)
57 | temp[1,:] = pixels[0,:]
58 | temp[0,:] = pixels[1,:]
59 | pixels = temp
60 |
61 | min_pixel = torch.min(pixels, dim=1)[0]
62 | max_pixel = torch.max(pixels, dim=1)[0]
63 |
64 | min_pixel[min_pixel < 0.0] = 0
65 | if min_pixel[0] >= h-1:
66 | min_pixel[0] = h-1
67 | if min_pixel[1] >= w-1:
68 | min_pixel[1] = w-1
69 |
70 | max_pixel[max_pixel < 0.0] = 0
71 | if max_pixel[0] >= h-1:
72 | max_pixel[0] = h-1
73 | if max_pixel[1] >= w-1:
74 | max_pixel[1] = w-1
75 |
76 | minh = int(min_pixel[0])
77 | minw = int(min_pixel[1])
78 | maxh = int(max_pixel[0])+1
79 | maxw = int(max_pixel[1])+1
80 | else:
81 | minh = 0
82 | minw = 0
83 | maxh = h
84 | maxw = w
85 |
86 | # print(max_pixel,min_pixel)
87 | # print(minh,maxh,minw,maxw)
88 |
89 | if minh == maxh or minw == maxw:
90 | print('Warning: could not find a valid bbox for a point cloud')
91 |
92 | K = K.unsqueeze(0)
93 | T = T.unsqueeze(0)
94 | M = 1
95 |
96 | x = torch.linspace(0,h-1,steps=h,device = K.device )
97 | y = torch.linspace(0,w-1,steps=w,device = K.device )
98 |
99 | grid_x, grid_y = torch.meshgrid(x,y)
100 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M,1,1,1) #(M,2,H,W)
101 | coordinates = torch.cat([coordinates,torch.ones(coordinates.size(0),1,coordinates.size(2),
102 | coordinates.size(3),device = K.device) ],dim=1).permute(0,2,3,1).unsqueeze(-1)
103 |
104 |
105 | inv_K = torch.inverse(K)
106 |
107 | dirs = torch.matmul(inv_K,coordinates) #(M,H,W,3,1)
108 | dirs = dirs/torch.norm(dirs,dim=3,keepdim = True)
109 | dirs = torch.cat([dirs,torch.zeros(dirs.size(0),coordinates.size(1),
110 | coordinates.size(2),1,1,device = K.device) ],dim=3) #(M,H,W,4,1)
111 |
112 |
113 | dirs = torch.matmul(T,dirs) #(M,H,W,4,1)
114 | dirs = dirs[:,:,:,0:3,0] #(M,H,W,3)
115 |
116 | pos = T[:,0:3,3] #(M,3)
117 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1,h,w,1)
118 |
119 | rays = torch.cat([pos,dirs],dim = 3) #(M,H,W,6)
120 |
121 | rays = rays[:,minh:maxh,minw:maxw,:] #(M,H',W',6)
122 |
123 | rays = rays.reshape((-1,rays.size(3)))
124 |
125 | ray_mask = torch.zeros(h,w,1)
126 | ray_mask[minh:maxh,minw:maxw,:] = 1.0
127 |
128 | return rays, ray_mask
--------------------------------------------------------------------------------
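
A minimal sketch of `lookat` above, which returns the 4x4 camera pose (camera-to-world, with the y/z axes flipped after inversion) used when generating novel view paths:

```python
import numpy as np
from utils.render_helpers import lookat

eye    = np.array([0.0, 0.0, 3.0])
center = np.array([0.0, 0.0, 0.0])
up     = np.array([0.0, 1.0, 0.0])

pose = lookat(eye, center, up)   # 4x4 camera pose looking at the origin
print(pose)
```
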
/utils/sample_pdf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | '''
4 | INPUT:
5 |
6 | z_vals: (N,L)
7 | weights: (N,L)
8 |
10 | OUTPUT:
10 |
11 | samples_z: (N,L)
12 |
13 |
14 | '''
15 | torch.autograd.set_detect_anomaly(True)
16 |
17 |
18 | def sample_pdf(z_vals, weights, N_samples, det=False, pytest=False):
19 | # Get pdf
20 | bins = .5 * (z_vals[...,1:] + z_vals[...,:-1])
21 | weights = weights + 1e-5 # prevent nans
22 | pdf = weights / torch.sum(weights, -1, keepdim=True)
23 | cdf = torch.cumsum(pdf, -1)
24 | cdf = torch.cat([torch.zeros_like(cdf[...,:1], device = z_vals.device), cdf], -1) # (batch, len(bins))
25 |
26 | # Take uniform samples
27 | if det:
28 | u = torch.linspace(0., 1., steps=N_samples, device = z_vals.device)
29 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
30 | else:
31 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device = z_vals.device)
32 |
33 | # Pytest, overwrite u with numpy's fixed random numbers
34 | if pytest:
35 | np.random.seed(0)
36 | new_shape = list(cdf.shape[:-1]) + [N_samples]
37 | if det:
38 | u = np.linspace(0., 1., N_samples)
39 | u = np.broadcast_to(u, new_shape)
40 | else:
41 | u = np.random.rand(*new_shape)
42 | u = torch.Tensor(u)
43 |
44 | # Invert CDF
45 | u = u.contiguous()
46 |
47 | inds = torch.searchsorted(cdf, u, right = True)
48 | below = torch.max(torch.zeros_like(inds-1, device = inds.device), inds-1)
49 | above = torch.min(cdf.shape[-1]-1 * torch.ones_like(inds, device = inds.device), inds)
50 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
51 |
52 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
53 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
54 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
55 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
56 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
57 |
58 | denom = (cdf_g[...,1]-cdf_g[...,0])
59 | denom = torch.where(denom<1e-5, torch.ones_like(denom, device=denom.device), denom)
60 | t = (u-cdf_g[...,0])/denom
61 | samples_z = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
62 |
63 | return samples_z
64 |
--------------------------------------------------------------------------------
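
A minimal sketch of hierarchical (importance) sampling with `sample_pdf`; as in NeRF-style pipelines, the weights of the interior coarse samples place the fine samples, so `weights` is given two fewer entries than `z_vals` to match the gather/expand above.

```python
import torch
from utils.sample_pdf import sample_pdf

N_rays, L = 1024, 64
z_vals  = torch.sort(torch.rand(N_rays, L), dim=-1).values   # coarse sample depths along each ray
weights = torch.rand(N_rays, L)                               # coarse volume-rendering weights

z_fine = sample_pdf(z_vals, weights[..., 1:-1], N_samples=64, det=True)
print(z_fine.shape)                                           # torch.Size([1024, 64])
```
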
/utils/vis_density.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def vis_density(model,bbox, L= 32):
4 |
5 | maxs = torch.max(bbox, dim=0).values
6 | mins = torch.min(bbox, dim=0).values
7 |
8 |
9 | x = torch.linspace(mins[0],maxs[0],steps=L).cuda()
10 | y = torch.linspace(mins[1],maxs[1],steps=L).cuda()
11 | z = torch.linspace(mins[2],maxs[2],steps=L).cuda()
12 | grid_x ,grid_y,grid_z = torch.meshgrid(x, y,z)
13 | xyz = torch.stack([grid_x ,grid_y,grid_z], dim = -1) #(L,L,L,3)
14 |
15 | xyz = xyz.reshape((-1,3)) #(L*L*L,3)
16 |
17 |
18 | xyzs = xyz.split(5000, dim=0)
19 |
20 | sigmas = []
21 | for i in xyzs:
22 | with torch.no_grad():
23 | _,density = model.spacenet_fine(i, None, model.maxs, model.mins) #(L*L*L,1)
24 | density = torch.nn.functional.relu(density)
25 | sigmas.append(density.detach().cpu())
26 |
27 | sigmas = torch.cat(sigmas, dim=0)
28 |
29 | return sigmas
--------------------------------------------------------------------------------