├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── stage1.yaml
│   ├── stage2.yaml
│   └── stereo_human_config.py
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── extractor.py
│   ├── raft_stereo_human.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── frame_utils.py
│       └── utils.py
├── environment.yml
├── gaussian_renderer
│   └── __init__.py
├── lib
│   ├── GaussianRender.py
│   ├── TaichiRender.py
│   ├── graphics_utils.py
│   ├── gs_parm_network.py
│   ├── human_loader.py
│   ├── loss.py
│   ├── network.py
│   ├── train_recoder.py
│   └── utils.py
├── prepare_data
│   ├── MAKE_DATA.md
│   ├── render_data.py
│   └── taichi_three
│       ├── __init__.py
│       ├── common.py
│       ├── geometry.py
│       ├── light.py
│       ├── loader.py
│       ├── meshgen.py
│       ├── model.py
│       ├── raycast.py
│       ├── scatter.py
│       ├── scene.py
│       ├── shading.py
│       ├── transform.py
│       └── version.py
├── test_real_data.py
├── test_view_interp.py
├── train_stage1.py
└── train_stage2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | /experiments
3 | /test_out
4 | /interp_out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Shunyuan Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # GPS-Gaussian: Generalizable Pixel-wise 3D Gaussian Splatting for Real-time Human Novel View Synthesis
4 |
5 | [Shunyuan Zheng](https://shunyuanzheng.github.io)†,1, [Boyao Zhou](https://yaourtb.github.io)2, [Ruizhi Shao](https://dsaurus.github.io/saurus)2, [Boning Liu](https://liuboning2.github.io)2, [Shengping Zhang](http://homepage.hit.edu.cn/zhangshengping)*,1,3, [Liqiang Nie](https://liqiangnie.github.io)1, [Yebin Liu](https://www.liuyebin.com)2
6 |
7 | 1Harbin Institute of Technology 2Tsinghua University 3Peng Cheng Laboratory
8 | *Corresponding author †Work done during an internship at Tsinghua University
9 |
10 | ### [Project Page](https://shunyuanzheng.github.io/GPS-Gaussian) · [Video](https://youtu.be/HjnBAqjGIAo) · [Paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Zheng_GPS-Gaussian_Generalizable_Pixel-wise_3D_Gaussian_Splatting_for_Real-time_Human_Novel_CVPR_2024_paper.pdf) · [Supp.](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Zheng_GPS-Gaussian_Generalizable_Pixel-wise_CVPR_2024_supplemental.pdf)
11 |
12 |
13 |
14 | ## Introduction
15 |
16 | We propose GPS-Gaussian, a generalizable pixel-wise 3D Gaussian representation for synthesizing novel views of any unseen characters instantly without any fine-tuning or optimization.
17 |
18 | https://github.com/ShunyuanZheng/GPS-Gaussian/assets/33752042/54a253ad-012a-448f-8303-168d80d3f594
19 |
20 | ## Installation
21 |
22 | To deploy and run GPS-Gaussian, set up the environment with the following commands:
23 | ```
24 | conda env create --file environment.yml
25 | conda activate gps_gaussian
26 | ```
27 | Then, compile ```diff-gaussian-rasterization``` from the [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) repository:
28 | ```
29 | git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
30 | cd gaussian-splatting/
31 | pip install -e submodules/diff-gaussian-rasterization
32 | cd ..
33 | ```
34 | (optional) [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo) provides a faster CUDA implementation of the correlation sampler, which speeds up the model without impacting performance:
35 | ```
36 | git clone https://github.com/princeton-vl/RAFT-Stereo.git
37 | cd RAFT-Stereo/sampler && python setup.py install && cd ../..
38 | ```
39 | If you compiled this CUDA implementation, set ```corr_implementation='reg_cuda'``` in [config/stereo_human_config.py](config/stereo_human_config.py#L33); otherwise, keep ```corr_implementation='reg'```.
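As a quick check, the following minimal sketch (not part of the repository) verifies whether the compiled sampler is importable before you edit the config; the module name ```corr_sampler``` matches the optional import at the top of [core/corr.py](core/corr.py):
```python
# Sketch: the CUDA correlation sampler from RAFT-Stereo installs as `corr_sampler`,
# the same module core/corr.py tries to import.
try:
    import corr_sampler  # noqa: F401
    print("CUDA sampler found: set corr_implementation='reg_cuda'")
except ImportError:
    print("CUDA sampler not found: keep corr_implementation='reg'")
```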
40 |
41 | ## Run on synthetic human dataset
42 |
43 | ### Dataset Preparation
44 | - We provide the rendered THuman2.0 dataset for GPS-Gaussian training in a 16-camera setting. Download ```render_data``` from [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) or [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd) and unzip it. Since we recommend rectifying the source images and computing the disparity offline, the saved files together with the downloaded data require around 50GB of free storage space.
45 | - To train a more robust model, we recommend collecting more human scans for training (e.g. [Twindom](https://web.twindom.com), [Render People](https://renderpeople.com/), [2K2K](https://sanghunhan92.github.io/conference/2K2K/)). Then, render the training data to match the target scenario, including the number of cameras and the radius of the scene. We provide rendering code to generate training data from human scans; see the [data documentation](prepare_data/MAKE_DATA.md) for more details.
46 |
47 | ### Training
48 | Note: The first time training is run, we perform stereo rectification and compute the disparity offline; the processed data is saved to ```render_data/rectified_local```. This process takes several hours but greatly speeds up subsequent training. If you want to skip this pre-processing, set ```use_processed_data=False``` in [stage1.yaml](config/stage1.yaml#L11) and [stage2.yaml](config/stage2.yaml#L15).
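For reference, the snippet below is a minimal sketch (assuming it is run from the repository root) that loads ```stage1.yaml``` with the repository's own config loader and prints the dataset fields mentioned above:
```python
# Sketch: inspect the resolved stage-1 settings before launching training.
from config.stereo_human_config import ConfigStereoHuman

config = ConfigStereoHuman()
config.load('config/stage1.yaml')        # merge the YAML on top of the defaults
cfg = config.get_cfg()
print(cfg.dataset.use_processed_data)    # False skips the offline rectification cache
print(cfg.dataset.data_root)             # should point to the unzipped render_data folder
```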
49 |
50 | - Stage1: pretrain the depth prediction model. Set ```data_root``` in [stage1.yaml](config/stage1.yaml#L12) to the path of the unzipped ```render_data``` folder.
51 | ```
52 | python train_stage1.py
53 | ```
54 |
55 | - Stage2: train the full model. Set ```data_root``` in [stage2.yaml](config/stage2.yaml#L16) to the path of the unzipped ```render_data``` folder, and set ```stage1_ckpt``` in [stage2.yaml](config/stage2.yaml#L3) to the path of the pretrained stage1 model.
56 | ```
57 | python train_stage2.py
58 | ```
59 | - We provide the pretrained model ```GPS-GS_stage2_final.pth``` in [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) and [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd) for fast evaluation and testing.
60 |
61 | ### Testing
62 |
63 | - Real-world data: download the test data ```real_data``` from [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) or [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd). Then, run the following command to synthesize a fixed novel view between ```src_view``` 0 and 1; the position of the novel viewpoint between the source views is controlled by ```ratio```, which ranges from 0 to 1.
64 | ```
65 | python test_real_data.py \
66 | --test_data_root 'PATH/TO/REAL_DATA' \
67 | --ckpt_path 'PATH/TO/GPS-GS_stage2_final.pth' \
68 | --src_view 0 1 \
69 | --ratio=0.5
70 | ```
71 |
72 | - Freeview rendering: run the following command to interpolate novel views between the source views; set ```novel_view_nums``` to the desired number of novel viewpoints.
73 | ```
74 | python test_view_interp.py \
75 | --test_data_root 'PATH/TO/RENDER_DATA/val' \
76 | --ckpt_path 'PATH/TO/GPS-GS_stage2_final.pth' \
77 | --novel_view_nums 5
78 | ```
79 |
80 | ## Citation
81 |
82 | If you find this code useful for your research, please consider citing:
83 | ```bibtex
84 | @inproceedings{zheng2024gpsgaussian,
85 | title={GPS-Gaussian: Generalizable Pixel-wise 3D Gaussian Splatting for Real-time Human Novel View Synthesis},
86 | author={Zheng, Shunyuan and Zhou, Boyao and Shao, Ruizhi and Liu, Boning and Zhang, Shengping and Nie, Liqiang and Liu, Yebin},
87 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
88 | year={2024}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/config/stage1.yaml:
--------------------------------------------------------------------------------
1 | name: 'GPS-GS_stage1'
2 |
3 | restore_ckpt: None
4 | lr: 0.0002
5 | wdecay: 1e-5
6 | batch_size: 6
7 | num_steps: 40000
8 |
9 | dataset:
10 | source_id: [0, 1]
11 | src_res: 1024
12 | use_processed_data: True
13 | data_root: 'PATH/TO/RENDER_DATA'
14 |
15 | raft:
16 | mixed_precision: False
17 | train_iters: 3
18 | val_iters: 3
19 | encoder_dims: [32, 48, 96]
20 | hidden_dims: [96, 96, 96]
21 |
22 | record:
23 | loss_freq: 2000
24 | eval_freq: 2000
25 |
--------------------------------------------------------------------------------
/config/stage2.yaml:
--------------------------------------------------------------------------------
1 | name: 'GPS-GS_stage2'
2 |
3 | stage1_ckpt: 'PATH/TO/GPS-GS_stage1_final.pth'
4 | restore_ckpt: None
5 | lr: 0.0002
6 | wdecay: 1e-5
7 | batch_size: 2
8 | num_steps: 100000
9 |
10 | dataset:
11 | source_id: [0, 1]
12 | train_novel_id: [2, 3, 4]
13 | val_novel_id: [3]
14 | src_res: 1024
15 | use_hr_img: True
16 | use_processed_data: True
17 | data_root: 'PATH/TO/RENDER_DATA'
18 |
19 | raft:
20 | mixed_precision: True
21 | train_iters: 3
22 | val_iters: 3
23 | encoder_dims: [32, 48, 96]
24 | hidden_dims: [96, 96, 96]
25 |
26 | gsnet:
27 | encoder_dims: [32, 48, 96]
28 | decoder_dims: [48, 64, 96]
29 | parm_head_dim: 32
30 |
31 | record:
32 | loss_freq: 5000
33 | eval_freq: 5000
34 |
--------------------------------------------------------------------------------
/config/stereo_human_config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 |
4 | class ConfigStereoHuman:
5 | def __init__(self):
6 | self.cfg = CN()
7 | self.cfg.name = ''
8 | self.cfg.stage1_ckpt = None
9 | self.cfg.restore_ckpt = None
10 | self.cfg.lr = 0.0
11 | self.cfg.wdecay = 0.0
12 | self.cfg.batch_size = 0
13 | self.cfg.num_steps = 0
14 |
15 | self.cfg.dataset = CN()
16 | self.cfg.dataset.source_id = None
17 | self.cfg.dataset.train_novel_id = None
18 | self.cfg.dataset.val_novel_id = None
19 | self.cfg.dataset.src_res = None
20 | self.cfg.dataset.use_hr_img = None
21 | self.cfg.dataset.use_processed_data = None
22 | self.cfg.dataset.data_root = ''
23 | # gaussian render settings
24 | self.cfg.dataset.bg_color = [0, 0, 0]
25 | self.cfg.dataset.zfar = 100.0
26 | self.cfg.dataset.znear = 0.01
27 | self.cfg.dataset.trans = [0.0, 0.0, 0.0]
28 | self.cfg.dataset.scale = 1.0
29 |
30 | self.cfg.raft = CN()
31 | self.cfg.raft.mixed_precision = None
32 | self.cfg.raft.train_iters = 0
33 | self.cfg.raft.val_iters = 0
34 | self.cfg.raft.corr_implementation = 'reg_cuda' # or 'reg'
35 | self.cfg.raft.corr_levels = 4
36 | self.cfg.raft.corr_radius = 4
37 | self.cfg.raft.n_downsample = 3
38 | self.cfg.raft.n_gru_layers = 1
39 | self.cfg.raft.slow_fast_gru = None
40 | self.cfg.raft.encoder_dims = [64, 96, 128]
41 | self.cfg.raft.hidden_dims = [128]*3
42 |
43 | self.cfg.gsnet = CN()
44 | self.cfg.gsnet.encoder_dims = None
45 | self.cfg.gsnet.decoder_dims = None
46 | self.cfg.gsnet.parm_head_dim = None
47 |
48 | self.cfg.record = CN()
49 | self.cfg.record.ckpt_path = None
50 | self.cfg.record.show_path = None
51 | self.cfg.record.logs_path = None
52 | self.cfg.record.file_path = None
53 | self.cfg.record.loss_freq = 0
54 | self.cfg.record.eval_freq = 0
55 |
56 | def get_cfg(self):
57 | return self.cfg.clone()
58 |
59 | def load(self, config_file):
60 | self.cfg.defrost()
61 | self.cfg.merge_from_file(config_file)
62 | self.cfg.freeze()
63 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aipixel/GPS-Gaussian/0024776deee4824f270d4bb534a17ffd85f63cf2/core/__init__.py
--------------------------------------------------------------------------------
/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from core.utils.utils import bilinear_sampler
4 |
5 | try:
6 | import corr_sampler
7 | except:
8 | pass
9 |
10 | try:
11 | import alt_cuda_corr
12 | except:
13 | # alt_cuda_corr is not compiled
14 | pass
15 |
16 |
17 | class CorrSampler(torch.autograd.Function):
18 | @staticmethod
19 | def forward(ctx, volume, coords, radius):
20 | ctx.save_for_backward(volume,coords)
21 | ctx.radius = radius
22 | corr, = corr_sampler.forward(volume, coords, radius)
23 | return corr
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | volume, coords = ctx.saved_tensors
27 | grad_output = grad_output.contiguous()
28 | grad_volume, = corr_sampler.backward(volume, coords, grad_output, ctx.radius)
29 | return grad_volume, None, None
30 |
31 | class CorrBlockFast1D:
32 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
33 | self.num_levels = num_levels
34 | self.radius = radius
35 | self.corr_pyramid = []
36 | # all pairs correlation
37 | corr = CorrBlockFast1D.corr(fmap1, fmap2)
38 | batch, h1, w1, dim, w2 = corr.shape
39 | corr = corr.reshape(batch*h1*w1, dim, 1, w2)
40 | for i in range(self.num_levels):
41 | self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2//2**i))
42 | corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
43 |
44 | def __call__(self, coords):
45 | out_pyramid = []
46 | bz, _, ht, wd = coords.shape
47 | coords = coords[:, [0]]
48 | for i in range(self.num_levels):
49 | corr = CorrSampler.apply(self.corr_pyramid[i].squeeze(3), coords/2**i, self.radius)
50 | out_pyramid.append(corr.view(bz, -1, ht, wd))
51 | return torch.cat(out_pyramid, dim=1)
52 |
53 | @staticmethod
54 | def corr(fmap1, fmap2):
55 | B, D, H, W1 = fmap1.shape
56 | _, _, _, W2 = fmap2.shape
57 | fmap1 = fmap1.view(B, D, H, W1)
58 | fmap2 = fmap2.view(B, D, H, W2)
59 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
60 | corr = corr.reshape(B, H, W1, 1, W2).contiguous()
61 | return corr / torch.sqrt(torch.tensor(D).float())
62 |
63 |
64 | class PytorchAlternateCorrBlock1D:
65 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
66 | self.num_levels = num_levels
67 | self.radius = radius
68 | self.corr_pyramid = []
69 | self.fmap1 = fmap1
70 | self.fmap2 = fmap2
71 |
72 | def corr(self, fmap1, fmap2, coords):
73 | B, D, H, W = fmap2.shape
74 | # map grid coordinates to [-1,1]
75 | xgrid, ygrid = coords.split([1,1], dim=-1)
76 | xgrid = 2*xgrid/(W-1) - 1
77 | ygrid = 2*ygrid/(H-1) - 1
78 |
79 | grid = torch.cat([xgrid, ygrid], dim=-1)
80 | output_corr = []
81 | for grid_slice in grid.unbind(3):
82 | fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True)
83 | corr = torch.sum(fmapw_mini * fmap1, dim=1)
84 | output_corr.append(corr)
85 | corr = torch.stack(output_corr, dim=1).permute(0,2,3,1)
86 |
87 | return corr / torch.sqrt(torch.tensor(D).float())
88 |
89 | def __call__(self, coords):
90 | r = self.radius
91 | coords = coords.permute(0, 2, 3, 1)
92 | batch, h1, w1, _ = coords.shape
93 | fmap1 = self.fmap1
94 | fmap2 = self.fmap2
95 | out_pyramid = []
96 | for i in range(self.num_levels):
97 | dx = torch.zeros(1)
98 | dy = torch.linspace(-r, r, 2*r+1)
99 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
100 | centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
101 | centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i
102 | coords_lvl = centroid_lvl + delta.view(-1, 2)
103 | corr = self.corr(fmap1, fmap2, coords_lvl)
104 | fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
105 | out_pyramid.append(corr)
106 | out = torch.cat(out_pyramid, dim=-1)
107 | return out.permute(0, 3, 1, 2).contiguous().float()
108 |
109 |
110 | class CorrBlock1D:
111 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
112 | self.num_levels = num_levels
113 | self.radius = radius
114 | self.corr_pyramid = []
115 |
116 | # all pairs correlation
117 | corr = CorrBlock1D.corr(fmap1, fmap2)
118 |
119 | batch, h1, w1, _, w2 = corr.shape
120 | corr = corr.reshape(batch*h1*w1, 1, 1, w2)
121 |
122 | self.corr_pyramid.append(corr)
123 | for i in range(self.num_levels):
124 | corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
125 | self.corr_pyramid.append(corr)
126 |
127 | def __call__(self, coords):
128 | r = self.radius
129 | coords = coords[:, :1].permute(0, 2, 3, 1)
130 | batch, h1, w1, _ = coords.shape
131 |
132 | out_pyramid = []
133 | for i in range(self.num_levels):
134 | corr = self.corr_pyramid[i]
135 | dx = torch.linspace(-r, r, 2*r+1)
136 | dx = dx.view(2*r+1, 1).to(coords.device)
137 | x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i
138 | y0 = torch.zeros_like(x0)
139 |
140 | coords_lvl = torch.cat([x0,y0], dim=-1)
141 | corr = bilinear_sampler(corr, coords_lvl)
142 | corr = corr.view(batch, h1, w1, -1)
143 | out_pyramid.append(corr)
144 |
145 | out = torch.cat(out_pyramid, dim=-1)
146 | return out.permute(0, 3, 1, 2).contiguous().float()
147 |
148 | @staticmethod
149 | def corr(fmap1, fmap2):
150 | B, D, H, W1 = fmap1.shape
151 | _, _, _, W2 = fmap2.shape
152 | fmap1 = fmap1.view(B, D, H, W1)
153 | fmap2 = fmap2.view(B, D, H, W2)
154 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
155 | corr = corr.reshape(B, H, W1, 1, W2).contiguous()
156 | return corr / torch.sqrt(torch.tensor(D).float())
157 |
158 |
159 | class AlternateCorrBlock:
160 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
161 | raise NotImplementedError
162 | self.num_levels = num_levels
163 | self.radius = radius
164 |
165 | self.pyramid = [(fmap1, fmap2)]
166 | for i in range(self.num_levels):
167 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
168 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
169 | self.pyramid.append((fmap1, fmap2))
170 |
171 | def __call__(self, coords):
172 | coords = coords.permute(0, 2, 3, 1)
173 | B, H, W, _ = coords.shape
174 | dim = self.pyramid[0][0].shape[1]
175 |
176 | corr_list = []
177 | for i in range(self.num_levels):
178 | r = self.radius
179 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
180 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
181 |
182 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
183 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
184 | corr_list.append(corr.squeeze(1))
185 |
186 | corr = torch.stack(corr_list, dim=1)
187 | corr = corr.reshape(B, -1, H, W)
188 | return corr / torch.sqrt(torch.tensor(dim).float())
189 |
--------------------------------------------------------------------------------
/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not (stride == 1 and in_planes == planes):
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not (stride == 1 and in_planes == planes):
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not (stride == 1 and in_planes == planes):
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not (stride == 1 and in_planes == planes):
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1 and in_planes == planes:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.conv1(y)
51 | y = self.norm1(y)
52 | y = self.relu(y)
53 | y = self.conv2(y)
54 | y = self.norm2(y)
55 | y = self.relu(y)
56 |
57 | if self.downsample is not None:
58 | x = self.downsample(x)
59 |
60 | return self.relu(x+y)
61 |
62 |
63 | class UnetExtractor(nn.Module):
64 | def __init__(self, in_channel=3, encoder_dim=[64, 96, 128], norm_fn='group'):
65 | super().__init__()
66 | self.in_ds = nn.Sequential(
67 | nn.Conv2d(in_channel, 32, kernel_size=5, stride=2, padding=2),
68 | nn.GroupNorm(num_groups=8, num_channels=32),
69 | nn.ReLU(inplace=True)
70 | )
71 |
72 | self.res1 = nn.Sequential(
73 | ResidualBlock(32, encoder_dim[0], norm_fn=norm_fn),
74 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn)
75 | )
76 | self.res2 = nn.Sequential(
77 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn),
78 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn)
79 | )
80 | self.res3 = nn.Sequential(
81 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn),
82 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn),
83 | )
84 |
85 | def forward(self, x):
86 | x = self.in_ds(x)
87 | x1 = self.res1(x)
88 | x2 = self.res2(x1)
89 | x3 = self.res3(x2)
90 |
91 | return x1, x2, x3
92 |
93 |
94 | class MultiBasicEncoder(nn.Module):
95 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]):
96 | super(MultiBasicEncoder, self).__init__()
97 |
98 | # output convolution for feature
99 | self.conv2 = nn.Sequential(
100 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
101 | nn.Conv2d(encoder_dim[2], encoder_dim[2]*2, 3, padding=1))
102 |
103 | # output convolution for context
104 | output_list = []
105 | for dim in output_dim:
106 | conv_out = nn.Sequential(
107 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
108 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1))
109 | output_list.append(conv_out)
110 |
111 | self.outputs08 = nn.ModuleList(output_list)
112 |
113 | def forward(self, x):
114 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
115 |
116 | outputs08 = [f(x) for f in self.outputs08]
117 | return outputs08, feat1, feat2
118 |
119 |
120 | if __name__ == '__main__':
121 |
122 | data = torch.ones((1, 3, 1024, 1024))
123 |
124 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128])
125 |
126 | x1, x2, x3 = model(data)
127 | print(x1.shape, x2.shape, x3.shape)
128 |
--------------------------------------------------------------------------------
/core/raft_stereo_human.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from core.update import BasicMultiUpdateBlock
5 | from core.extractor import MultiBasicEncoder
6 | from core.corr import CorrBlock1D, CorrBlockFast1D
7 | from core.utils.utils import coords_grid, downflow8
8 | from torch.cuda.amp import autocast as autocast
9 |
10 |
11 | class RAFTStereoHuman(nn.Module):
12 | def __init__(self, args):
13 | super().__init__()
14 | self.args = args
15 |
16 | context_dims = args.hidden_dims
17 | self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], encoder_dim=args.encoder_dims)
18 | self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
19 | self.update_module = FlowUpdateModule(self.args)
20 |
21 | def freeze_bn(self):
22 | for m in self.modules():
23 | if isinstance(m, nn.BatchNorm2d):
24 | m.eval()
25 |
26 | def forward(self, image_pair, iters=12, flow_init=None, test_mode=False):
27 | """ Estimate optical flow between pair of frames """
28 |
29 | if flow_init is not None:
30 | flow_init = downflow8(flow_init)
31 | flow_init = torch.cat([flow_init, torch.zeros_like(flow_init)], dim=1)
32 |
33 | # run the context network
34 | with autocast(enabled=self.args.mixed_precision):
35 | *cnet_list, fmap1, fmap2 = self.cnet(image_pair)
36 | fmap12 = torch.cat((fmap1, fmap2), dim=0)
37 | fmap21 = torch.cat((fmap2, fmap1), dim=0)
38 |
39 | net_list = [torch.tanh(x[0]) for x in cnet_list]
40 | inp_list = [torch.relu(x[1]) for x in cnet_list]
41 |
42 | # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
43 | inp_list = [list(conv(i).split(split_size=conv.out_channels // 3, dim=1)) for i, conv in zip(inp_list, self.context_zqr_convs)]
44 |
45 | # run update module
46 | flow_pred = self.update_module(fmap12, fmap21, net_list, inp_list, iters, flow_init, test_mode)
47 |
48 | if not test_mode:
49 | return flow_pred
50 | else:
51 | return flow_pred.split(dim=0, split_size=flow_pred.shape[0]//2)
52 |
53 |
54 | class FlowUpdateModule(nn.Module):
55 | def __init__(self, args):
56 | super().__init__()
57 | self.args = args
58 | self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
59 |
60 | def initialize_flow(self, img):
61 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
62 | N, _, H, W = img.shape
63 |
64 | coords0 = coords_grid(N, H, W).to(img.device)
65 | coords1 = coords_grid(N, H, W).to(img.device)
66 |
67 | return coords0, coords1
68 |
69 | def upsample_flow(self, flow, mask):
70 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
71 | N, D, H, W = flow.shape
72 | factor = 2 ** self.args.n_downsample
73 | mask = mask.view(N, 1, 9, factor, factor, H, W)
74 | mask = torch.softmax(mask, dim=2)
75 |
76 | up_flow = F.unfold(factor * flow, [3,3], padding=1)
77 | up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
78 |
79 | up_flow = torch.sum(mask * up_flow, dim=2)
80 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
81 | return up_flow.reshape(N, D, factor*H, factor*W)
82 |
83 | def forward(self, fmap1, fmap2, net_list, inp_list, iters=12, flow_init=None, test_mode=False):
84 | if self.args.corr_implementation == "reg": # Default
85 | corr_block = CorrBlock1D
86 | fmap1, fmap2 = fmap1.float(), fmap2.float()
87 | elif self.args.corr_implementation == "reg_cuda": # Faster version of reg
88 | corr_block = CorrBlockFast1D
89 | corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)
90 |
91 | coords0, coords1 = self.initialize_flow(net_list[0])
92 |
93 | if flow_init is not None:
94 | coords1 = coords1 + flow_init
95 |
96 | flow_predictions = []
97 | for itr in range(iters):
98 | coords1 = coords1.detach()
99 | corr = corr_fn(coords1) # index correlation volume
100 | flow = coords1 - coords0
101 | with autocast(enabled=self.args.mixed_precision):
102 | if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res GRU
103 | net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
104 | if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru: # Update low-res GRU and mid-res GRU
105 | net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers==3, iter16=True, iter08=False, update=False)
106 | net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, corr, flow, iter32=self.args.n_gru_layers==3, iter16=self.args.n_gru_layers>=2)
107 |
108 | # in stereo mode, project flow onto epipolar
109 | delta_flow[:, 1] = 0.0
110 |
111 | # F(t+1) = F(t) + \Delta(t)
112 | coords1 = coords1 + delta_flow
113 |
114 | # We do not need to upsample or output intermediate results in test_mode
115 | if test_mode and itr < iters-1:
116 | continue
117 |
118 | # upsample predictions
119 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
120 | flow_up = flow_up[:, :1]
121 |
122 | flow_predictions.append(flow_up)
123 |
124 | if test_mode:
125 | return flow_up
126 |
127 | return flow_predictions
128 |
--------------------------------------------------------------------------------
/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from opt_einsum import contract
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim, input_dim, kernel_size=3):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
22 |
23 | def forward(self, h, cz, cr, cq, *x_list):
24 | x = torch.cat(x_list, dim=1)
25 | hx = torch.cat([h, x], dim=1)
26 |
27 | z = torch.sigmoid(self.convz(hx) + cz)
28 | r = torch.sigmoid(self.convr(hx) + cr)
29 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
30 |
31 | h = (1-z) * h + z * q
32 | return h
33 |
34 | class SepConvGRU(nn.Module):
35 | def __init__(self, hidden_dim=128, input_dim=192+128):
36 | super(SepConvGRU, self).__init__()
37 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
40 |
41 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
44 |
45 |
46 | def forward(self, h, *x):
47 | # horizontal
48 | x = torch.cat(x, dim=1)
49 | hx = torch.cat([h, x], dim=1)
50 | z = torch.sigmoid(self.convz1(hx))
51 | r = torch.sigmoid(self.convr1(hx))
52 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
53 | h = (1-z) * h + z * q
54 |
55 | # vertical
56 | hx = torch.cat([h, x], dim=1)
57 | z = torch.sigmoid(self.convz2(hx))
58 | r = torch.sigmoid(self.convr2(hx))
59 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
60 | h = (1-z) * h + z * q
61 |
62 | return h
63 |
64 | class BasicMotionEncoder(nn.Module):
65 | def __init__(self, args):
66 | super(BasicMotionEncoder, self).__init__()
67 | self.args = args
68 |
69 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)
70 |
71 | self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
72 | self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
73 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
74 | self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
75 | self.conv = nn.Conv2d(64+64, 128-2, 3, padding=1)
76 |
77 | def forward(self, flow, corr):
78 | cor = F.relu(self.convc1(corr))
79 | cor = F.relu(self.convc2(cor))
80 | flo = F.relu(self.convf1(flow))
81 | flo = F.relu(self.convf2(flo))
82 |
83 | cor_flo = torch.cat([cor, flo], dim=1)
84 | out = F.relu(self.conv(cor_flo))
85 | return torch.cat([out, flow], dim=1)
86 |
87 | def pool2x(x):
88 | return F.avg_pool2d(x, 3, stride=2, padding=1)
89 |
90 | def pool4x(x):
91 | return F.avg_pool2d(x, 5, stride=4, padding=1)
92 |
93 | def interp(x, dest):
94 | interp_args = {'mode': 'bilinear', 'align_corners': True}
95 | return F.interpolate(x, dest.shape[2:], **interp_args)
96 |
97 | class BasicMultiUpdateBlock(nn.Module):
98 | def __init__(self, args, hidden_dims=[]):
99 | super().__init__()
100 | self.args = args
101 | self.encoder = BasicMotionEncoder(args)
102 | encoder_output_dim = 128
103 |
104 | self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1))
105 | self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2])
106 | self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1])
107 | self.flow_head = FlowHead(hidden_dims[2], hidden_dim=256, output_dim=2)
108 | factor = 2**self.args.n_downsample
109 |
110 | self.mask = nn.Sequential(
111 | nn.Conv2d(hidden_dims[2], 256, 3, padding=1),
112 | nn.ReLU(inplace=True),
113 | nn.Conv2d(256, (factor**2)*9, 1, padding=0))
114 |
115 | def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
116 |
117 | if iter32:
118 | net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
119 | if iter16:
120 | if self.args.n_gru_layers > 2:
121 | net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1]))
122 | else:
123 | net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0]))
124 | if iter08:
125 | motion_features = self.encoder(flow, corr)
126 | if self.args.n_gru_layers > 1:
127 | net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
128 | else:
129 | net[0] = self.gru08(net[0], *(inp[0]), motion_features)
130 |
131 | if not update:
132 | return net
133 |
134 | delta_flow = self.flow_head(net[0])
135 |
136 | # scale mask to balance gradients
137 | mask = .25 * self.mask(net[0])
138 | return net, mask, delta_flow
139 |
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aipixel/GPS-Gaussian/0024776deee4824f270d4bb534a17ffd85f63cf2/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import warnings
4 | import os
5 | import time
6 | from glob import glob
7 | from skimage import color, io
8 | from PIL import Image
9 |
10 | import cv2
11 | cv2.setNumThreads(0)
12 | cv2.ocl.setUseOpenCL(False)
13 |
14 | import torch
15 | from torchvision.transforms import ColorJitter, functional, Compose
16 | import torch.nn.functional as F
17 |
18 | def get_middlebury_images():
19 | root = "datasets/Middlebury/MiddEval3"
20 | with open(os.path.join(root, "official_train.txt"), 'r') as f:
21 | lines = f.read().splitlines()
22 | return sorted([os.path.join(root, 'trainingQ', f'{name}/im0.png') for name in lines])
23 |
24 | def get_eth3d_images():
25 | return sorted(glob('datasets/ETH3D/two_view_training/*/im0.png'))
26 |
27 | def get_kitti_images():
28 | return sorted(glob('datasets/KITTI/training/image_2/*_10.png'))
29 |
30 | def transfer_color(image, style_mean, style_stddev):
31 | reference_image_lab = color.rgb2lab(image)
32 | reference_stddev = np.std(reference_image_lab, axis=(0,1), keepdims=True)# + 1
33 | reference_mean = np.mean(reference_image_lab, axis=(0,1), keepdims=True)
34 |
35 | reference_image_lab = reference_image_lab - reference_mean
36 | lamb = style_stddev/reference_stddev
37 | style_image_lab = lamb * reference_image_lab
38 | output_image_lab = style_image_lab + style_mean
39 | l, a, b = np.split(output_image_lab, 3, axis=2)
40 | l = l.clip(0, 100)
41 | output_image_lab = np.concatenate((l,a,b), axis=2)
42 | with warnings.catch_warnings():
43 | warnings.simplefilter("ignore", category=UserWarning)
44 | output_image_rgb = color.lab2rgb(output_image_lab) * 255
45 | return output_image_rgb
46 |
47 | class AdjustGamma(object):
48 |
49 | def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
50 | self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = gamma_min, gamma_max, gain_min, gain_max
51 |
52 | def __call__(self, sample):
53 | gain = random.uniform(self.gain_min, self.gain_max)
54 | gamma = random.uniform(self.gamma_min, self.gamma_max)
55 | return functional.adjust_gamma(sample, gamma, gain)
56 |
57 | def __repr__(self):
58 | return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
59 |
60 | class FlowAugmentor:
61 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, yjitter=False, saturation_range=[0.6,1.4], gamma=[1,1,1,1]):
62 |
63 | # spatial augmentation params
64 | self.crop_size = crop_size
65 | self.min_scale = min_scale
66 | self.max_scale = max_scale
67 | self.spatial_aug_prob = 1.0
68 | self.stretch_prob = 0.8
69 | self.max_stretch = 0.2
70 |
71 | # flip augmentation params
72 | self.yjitter = yjitter
73 | self.do_flip = do_flip
74 | self.h_flip_prob = 0.5
75 | self.v_flip_prob = 0.1
76 |
77 | # photometric augmentation params
78 | self.photo_aug = Compose([ColorJitter(brightness=0.4, contrast=0.4, saturation=saturation_range, hue=0.5/3.14), AdjustGamma(*gamma)])
79 | self.asymmetric_color_aug_prob = 0.2
80 | self.eraser_aug_prob = 0.5
81 |
82 | def color_transform(self, img1, img2):
83 | """ Photometric augmentation """
84 |
85 | # asymmetric
86 | if np.random.rand() < self.asymmetric_color_aug_prob:
87 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
88 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
89 |
90 | # symmetric
91 | else:
92 | image_stack = np.concatenate([img1, img2], axis=0)
93 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
94 | img1, img2 = np.split(image_stack, 2, axis=0)
95 |
96 | return img1, img2
97 |
98 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
99 | """ Occlusion augmentation """
100 |
101 | ht, wd = img1.shape[:2]
102 | if np.random.rand() < self.eraser_aug_prob:
103 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
104 | for _ in range(np.random.randint(1, 3)):
105 | x0 = np.random.randint(0, wd)
106 | y0 = np.random.randint(0, ht)
107 | dx = np.random.randint(bounds[0], bounds[1])
108 | dy = np.random.randint(bounds[0], bounds[1])
109 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
110 |
111 | return img1, img2
112 |
113 | def spatial_transform(self, img1, img2, flow):
114 | # randomly sample scale
115 | ht, wd = img1.shape[:2]
116 | min_scale = np.maximum(
117 | (self.crop_size[0] + 8) / float(ht),
118 | (self.crop_size[1] + 8) / float(wd))
119 |
120 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
121 | scale_x = scale
122 | scale_y = scale
123 | if np.random.rand() < self.stretch_prob:
124 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
125 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
126 |
127 | scale_x = np.clip(scale_x, min_scale, None)
128 | scale_y = np.clip(scale_y, min_scale, None)
129 |
130 | if np.random.rand() < self.spatial_aug_prob:
131 | # rescale the images
132 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
133 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
134 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
135 | flow = flow * [scale_x, scale_y]
136 |
137 | if self.do_flip:
138 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
139 | img1 = img1[:, ::-1]
140 | img2 = img2[:, ::-1]
141 | flow = flow[:, ::-1] * [-1.0, 1.0]
142 |
143 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
144 | tmp = img1[:, ::-1]
145 | img1 = img2[:, ::-1]
146 | img2 = tmp
147 |
148 | if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
149 | img1 = img1[::-1, :]
150 | img2 = img2[::-1, :]
151 | flow = flow[::-1, :] * [1.0, -1.0]
152 |
153 | if self.yjitter:
154 | y0 = np.random.randint(2, img1.shape[0] - self.crop_size[0] - 2)
155 | x0 = np.random.randint(2, img1.shape[1] - self.crop_size[1] - 2)
156 |
157 | y1 = y0 + np.random.randint(-2, 2 + 1)
158 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
159 | img2 = img2[y1:y1+self.crop_size[0], x0:x0+self.crop_size[1]]
160 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
161 |
162 | else:
163 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
164 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
165 |
166 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
167 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
168 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
169 |
170 | return img1, img2, flow
171 |
172 |
173 | def __call__(self, img1, img2, flow):
174 | img1, img2 = self.color_transform(img1, img2)
175 | img1, img2 = self.eraser_transform(img1, img2)
176 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
177 |
178 | img1 = np.ascontiguousarray(img1)
179 | img2 = np.ascontiguousarray(img2)
180 | flow = np.ascontiguousarray(flow)
181 |
182 | return img1, img2, flow
183 |
184 | class SparseFlowAugmentor:
185 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, yjitter=False, saturation_range=[0.7,1.3], gamma=[1,1,1,1]):
186 | # spatial augmentation params
187 | self.crop_size = crop_size
188 | self.min_scale = min_scale
189 | self.max_scale = max_scale
190 | self.spatial_aug_prob = 0.8
191 | self.stretch_prob = 0.8
192 | self.max_stretch = 0.2
193 |
194 | # flip augmentation params
195 | self.do_flip = do_flip
196 | self.h_flip_prob = 0.5
197 | self.v_flip_prob = 0.1
198 |
199 | # photometric augmentation params
200 | self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
201 | self.asymmetric_color_aug_prob = 0.2
202 | self.eraser_aug_prob = 0.5
203 |
204 | def color_transform(self, img1, img2):
205 | image_stack = np.concatenate([img1, img2], axis=0)
206 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
207 | img1, img2 = np.split(image_stack, 2, axis=0)
208 | return img1, img2
209 |
210 | def eraser_transform(self, img1, img2):
211 | ht, wd = img1.shape[:2]
212 | if np.random.rand() < self.eraser_aug_prob:
213 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
214 | for _ in range(np.random.randint(1, 3)):
215 | x0 = np.random.randint(0, wd)
216 | y0 = np.random.randint(0, ht)
217 | dx = np.random.randint(50, 100)
218 | dy = np.random.randint(50, 100)
219 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
220 |
221 | return img1, img2
222 |
223 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
224 | ht, wd = flow.shape[:2]
225 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
226 | coords = np.stack(coords, axis=-1)
227 |
228 | coords = coords.reshape(-1, 2).astype(np.float32)
229 | flow = flow.reshape(-1, 2).astype(np.float32)
230 | valid = valid.reshape(-1).astype(np.float32)
231 |
232 | coords0 = coords[valid>=1]
233 | flow0 = flow[valid>=1]
234 |
235 | ht1 = int(round(ht * fy))
236 | wd1 = int(round(wd * fx))
237 |
238 | coords1 = coords0 * [fx, fy]
239 | flow1 = flow0 * [fx, fy]
240 |
241 | xx = np.round(coords1[:,0]).astype(np.int32)
242 | yy = np.round(coords1[:,1]).astype(np.int32)
243 |
244 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
245 | xx = xx[v]
246 | yy = yy[v]
247 | flow1 = flow1[v]
248 |
249 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
250 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
251 |
252 | flow_img[yy, xx] = flow1
253 | valid_img[yy, xx] = 1
254 |
255 | return flow_img, valid_img
256 |
257 | def spatial_transform(self, img1, img2, flow, valid):
258 | # randomly sample scale
259 |
260 | ht, wd = img1.shape[:2]
261 | min_scale = np.maximum(
262 | (self.crop_size[0] + 1) / float(ht),
263 | (self.crop_size[1] + 1) / float(wd))
264 |
265 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
266 | scale_x = np.clip(scale, min_scale, None)
267 | scale_y = np.clip(scale, min_scale, None)
268 |
269 | if np.random.rand() < self.spatial_aug_prob:
270 | # rescale the images
271 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
272 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
273 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
274 |
275 | if self.do_flip:
276 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
277 | img1 = img1[:, ::-1]
278 | img2 = img2[:, ::-1]
279 | flow = flow[:, ::-1] * [-1.0, 1.0]
280 |
281 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
282 | tmp = img1[:, ::-1]
283 | img1 = img2[:, ::-1]
284 | img2 = tmp
285 |
286 | if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
287 | img1 = img1[::-1, :]
288 | img2 = img2[::-1, :]
289 | flow = flow[::-1, :] * [1.0, -1.0]
290 |
291 | margin_y = 20
292 | margin_x = 50
293 |
294 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
295 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
296 |
297 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
298 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
299 |
300 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
301 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
302 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
303 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
304 | return img1, img2, flow, valid
305 |
306 |
307 | def __call__(self, img1, img2, flow, valid):
308 | img1, img2 = self.color_transform(img1, img2)
309 | img1, img2 = self.eraser_transform(img1, img2)
310 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
311 |
312 | img1 = np.ascontiguousarray(img1)
313 | img2 = np.ascontiguousarray(img2)
314 | flow = np.ascontiguousarray(flow)
315 | valid = np.ascontiguousarray(valid)
316 |
317 | return img1, img2, flow, valid
318 |
--------------------------------------------------------------------------------
/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 | import json
6 | import imageio
7 | import cv2
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 | def readFlow(fn):
14 | """ Read .flo file in Middlebury format"""
15 | # Code adapted from:
16 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
17 |
18 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
19 | # print 'fn = %s'%(fn)
20 | with open(fn, 'rb') as f:
21 | magic = np.fromfile(f, np.float32, count=1)
22 | if 202021.25 != magic:
23 | print('Magic number incorrect. Invalid .flo file')
24 | return None
25 | else:
26 | w = np.fromfile(f, np.int32, count=1)
27 | h = np.fromfile(f, np.int32, count=1)
28 | # print 'Reading %d x %d flo file\n' % (w, h)
29 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
30 | # Reshape data into 3D array (columns, rows, bands)
31 | # The reshape here is for visualization, the original code is (w,h,2)
32 | return np.resize(data, (int(h), int(w), 2))
33 |
34 | def readPFM(file):
35 | file = open(file, 'rb')
36 |
37 | color = None
38 | width = None
39 | height = None
40 | scale = None
41 | endian = None
42 |
43 | header = file.readline().rstrip()
44 | if header == b'PF':
45 | color = True
46 | elif header == b'Pf':
47 | color = False
48 | else:
49 | raise Exception('Not a PFM file.')
50 |
51 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
52 | if dim_match:
53 | width, height = map(int, dim_match.groups())
54 | else:
55 | raise Exception('Malformed PFM header.')
56 |
57 | scale = float(file.readline().rstrip())
58 | if scale < 0: # little-endian
59 | endian = '<'
60 | scale = -scale
61 | else:
62 | endian = '>' # big-endian
63 |
64 | data = np.fromfile(file, endian + 'f')
65 | shape = (height, width, 3) if color else (height, width)
66 |
67 | data = np.reshape(data, shape)
68 | data = np.flipud(data)
69 | return data
70 |
71 | def writePFM(file, array):
72 | import os
73 | assert type(file) is str and type(array) is np.ndarray and \
74 | os.path.splitext(file)[1] == ".pfm"
75 | with open(file, 'wb') as f:
76 | H, W = array.shape
77 | headers = ["Pf\n", f"{W} {H}\n", "-1\n"]
78 | for header in headers:
79 | f.write(str.encode(header))
80 | array = np.flip(array, axis=0).astype(np.float32)
81 | f.write(array.tobytes())
82 |
83 |
84 |
85 | def writeFlow(filename,uv,v=None):
86 | """ Write optical flow to file.
87 |
88 | If v is None, uv is assumed to contain both u and v channels,
89 | stacked in depth.
90 | Original code by Deqing Sun, adapted from Daniel Scharstein.
91 | """
92 | nBands = 2
93 |
94 | if v is None:
95 | assert(uv.ndim == 3)
96 | assert(uv.shape[2] == 2)
97 | u = uv[:,:,0]
98 | v = uv[:,:,1]
99 | else:
100 | u = uv
101 |
102 | assert(u.shape == v.shape)
103 | height,width = u.shape
104 | f = open(filename,'wb')
105 | # write the header
106 | f.write(TAG_CHAR)
107 | np.array(width).astype(np.int32).tofile(f)
108 | np.array(height).astype(np.int32).tofile(f)
109 | # arrange into matrix form
110 | tmp = np.zeros((height, width*nBands))
111 | tmp[:,np.arange(width)*2] = u
112 | tmp[:,np.arange(width)*2 + 1] = v
113 | tmp.astype(np.float32).tofile(f)
114 | f.close()
115 |
116 |
117 | def readFlowKITTI(filename):
118 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
119 | flow = flow[:,:,::-1].astype(np.float32)
120 | flow, valid = flow[:, :, :2], flow[:, :, 2]
121 | flow = (flow - 2**15) / 64.0
122 | return flow, valid
123 |
124 | def readDispKITTI(filename):
125 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
126 | valid = disp > 0.0
127 | return disp, valid
128 |
129 | # Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py
130 | def readDispSintelStereo(file_name):
131 | a = np.array(Image.open(file_name))
132 | d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3)
133 | disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0]
134 | mask = np.array(Image.open(file_name.replace('disparities', 'occlusions')))
135 | valid = ((mask == 0) & (disp > 0))
136 | return disp, valid
137 |
138 | # Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
139 | def readDispFallingThings(file_name):
140 | a = np.array(Image.open(file_name))
141 | with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f:
142 | intrinsics = json.load(f)
143 | fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx']
144 | disp = (fx * 6.0 * 100) / a.astype(np.float32)
145 | valid = disp > 0
146 | return disp, valid
147 |
148 | # Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md
149 | def readDispTartanAir(file_name):
150 | depth = np.load(file_name)
151 | disp = 80.0 / depth
152 | valid = disp > 0
153 | return disp, valid
154 |
155 |
156 | def readDispMiddlebury(file_name):
157 | if basename(file_name) == 'disp0GT.pfm':
158 | disp = readPFM(file_name).astype(np.float32)
159 | assert len(disp.shape) == 2
160 | nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png')
161 | assert exists(nocc_pix)
162 | nocc_pix = imageio.imread(nocc_pix) == 255
163 | assert np.any(nocc_pix)
164 | return disp, nocc_pix
165 | elif basename(file_name) == 'disp0.pfm':
166 | disp = readPFM(file_name).astype(np.float32)
167 | valid = disp < 1e3
168 | return disp, valid
169 |
170 | def writeFlowKITTI(filename, uv):
171 | uv = 64.0 * uv + 2**15
172 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
173 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
174 | cv2.imwrite(filename, uv[..., ::-1])
175 |
176 |
177 | def read_gen(file_name, pil=False):
178 | ext = splitext(file_name)[-1]
179 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
180 | return Image.open(file_name)
181 | elif ext == '.bin' or ext == '.raw':
182 | return np.load(file_name)
183 | elif ext == '.flo':
184 | return readFlow(file_name).astype(np.float32)
185 | elif ext == '.pfm':
186 | flow = readPFM(file_name).astype(np.float32)
187 | if len(flow.shape) == 2:
188 | return flow
189 | else:
190 | return flow[:, :, :-1]
191 | return []
--------------------------------------------------------------------------------
/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel', divis_by=8):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
12 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | assert all((x.ndim == 4) for x in inputs)
20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
21 |
22 | def unpad(self, x):
23 | assert x.ndim == 4
24 | ht, wd = x.shape[-2:]
25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
26 | return x[..., c[0]:c[1], c[2]:c[3]]
27 |
28 | def forward_interpolate(flow):
29 | flow = flow.detach().cpu().numpy()
30 | dx, dy = flow[0], flow[1]
31 |
32 | ht, wd = dx.shape
33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
34 |
35 | x1 = x0 + dx
36 | y1 = y0 + dy
37 |
38 | x1 = x1.reshape(-1)
39 | y1 = y1.reshape(-1)
40 | dx = dx.reshape(-1)
41 | dy = dy.reshape(-1)
42 |
43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
44 | x1 = x1[valid]
45 | y1 = y1[valid]
46 | dx = dx[valid]
47 | dy = dy[valid]
48 |
49 | flow_x = interpolate.griddata(
50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
51 |
52 | flow_y = interpolate.griddata(
53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
54 |
55 | flow = np.stack([flow_x, flow_y], axis=0)
56 | return torch.from_numpy(flow).float()
57 |
58 |
59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
60 | """ Wrapper for grid_sample, uses pixel coordinates """
61 | H, W = img.shape[-2:]
62 | xgrid, ygrid = coords.split([1,1], dim=-1)
63 | xgrid = 2*xgrid/(W-1) - 1
64 | if H > 1:
65 | ygrid = 2*ygrid/(H-1) - 1
66 |
67 | grid = torch.cat([xgrid, ygrid], dim=-1)
68 | img = F.grid_sample(img, grid, align_corners=True)
69 |
70 | if mask:
71 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
72 | return img, mask.float()
73 |
74 | return img
75 |
76 |
77 | def coords_grid(batch, ht, wd):
78 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
79 | coords = torch.stack(coords[::-1], dim=0).float()
80 | return coords[None].repeat(batch, 1, 1, 1)
81 |
82 |
83 | def upflow8(flow, mode='bilinear'):
84 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
85 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
86 |
87 | def downflow8(flow, mode='bilinear'):
88 | new_size = (int(flow.shape[2]/8), int(flow.shape[3]/8))
89 | return 0.125 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
90 |
91 | def gauss_blur(input, N=5, std=1):
92 | B, D, H, W = input.shape
93 | x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2)
94 | unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2))
95 | weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4)
96 | weights = weights.view(1,1,N,N).to(input)
97 | output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2)
98 | return output.view(B, D, H, W)
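 99 |
100 |
101 | if __name__ == '__main__':
102 |     # A minimal usage sketch of InputPadder (hedged example, the sizes are arbitrary):
103 |     # pad a tensor whose spatial size is not divisible by 8, then undo the padding.
104 |     img = torch.randn(1, 3, 375, 1242)
105 |     padder = InputPadder(img.shape, divis_by=8)
106 |     padded, = padder.pad(img)            # 375x1242 -> 376x1248
107 |     assert padded.shape[-2] % 8 == 0 and padded.shape[-1] % 8 == 0
108 |     restored = padder.unpad(padded)      # back to the original 375x1242
109 |     assert restored.shape == img.shape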
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gps_gaussian
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python=3.10.13
7 | - pip=23.3.1
8 | - pytorch=2.0.1
9 | - torchvision=0.15.2
10 | - torchaudio=2.0.2
11 | - pytorch-cuda=11.8
12 | - tqdm
13 | - tensorboard
14 | - scipy
15 | - pip:
16 | - opencv-python
17 | - taichi==1.5.0
18 | - yacs==0.1.8
19 |
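20 | # Standard conda usage: `conda env create -f environment.yml`, then `conda activate gps_gaussian`.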
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 |
16 |
17 | def render(data, idx, pts_xyz, pts_rgb, rotations, scales, opacity, bg_color):
18 | """
19 | Render the scene.
20 |
21 | Background tensor (bg_color) must be on GPU!
22 | """
23 | bg_color = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
24 |
25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
26 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True, device="cuda") + 0
27 | try:
28 | screenspace_points.retain_grad()
29 | except:
30 | pass
31 |
32 | # Set up rasterization configuration
33 | tanfovx = math.tan(data['novel_view']['FovX'][idx] * 0.5)
34 | tanfovy = math.tan(data['novel_view']['FovY'][idx] * 0.5)
35 |
36 | raster_settings = GaussianRasterizationSettings(
37 | image_height=int(data['novel_view']['height'][idx]),
38 | image_width=int(data['novel_view']['width'][idx]),
39 | tanfovx=tanfovx,
40 | tanfovy=tanfovy,
41 | bg=bg_color,
42 | scale_modifier=1.0,
43 | viewmatrix=data['novel_view']['world_view_transform'][idx],
44 | projmatrix=data['novel_view']['full_proj_transform'][idx],
45 | sh_degree=3,
46 | campos=data['novel_view']['camera_center'][idx],
47 | prefiltered=False,
48 | debug=False
49 | )
50 |
51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
52 |
53 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
54 | rendered_image, _ = rasterizer(
55 | means3D=pts_xyz,
56 | means2D=screenspace_points,
57 | shs=None,
58 | colors_precomp=pts_rgb,
59 | opacities=opacity,
60 | scales=scales,
61 | rotations=rotations,
62 | cov3D_precomp=None)
63 |
64 |     # Gaussians that were frustum culled or had a radius of 0 are not visible.
65 |     # Their radii are discarded above, since this feed-forward pipeline does not densify or split Gaussians.
66 |
67 | return rendered_image
68 |
--------------------------------------------------------------------------------
/lib/GaussianRender.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from gaussian_renderer import render
4 |
5 |
6 | def pts2render(data, bg_color):
7 | bs = data['lmain']['img'].shape[0]
8 | render_novel_list = []
9 | for i in range(bs):
10 | xyz_i_valid = []
11 | rgb_i_valid = []
12 | rot_i_valid = []
13 | scale_i_valid = []
14 | opacity_i_valid = []
15 | for view in ['lmain', 'rmain']:
16 | valid_i = data[view]['pts_valid'][i, :]
17 | xyz_i = data[view]['xyz'][i, :, :]
18 | rgb_i = data[view]['img'][i, :, :, :].permute(1, 2, 0).view(-1, 3)
19 | rot_i = data[view]['rot_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 4)
20 | scale_i = data[view]['scale_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 3)
21 | opacity_i = data[view]['opacity_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 1)
22 |
23 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3))
24 | rgb_i_valid.append(rgb_i[valid_i].view(-1, 3))
25 | rot_i_valid.append(rot_i[valid_i].view(-1, 4))
26 | scale_i_valid.append(scale_i[valid_i].view(-1, 3))
27 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1))
28 |
29 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0)
30 | pts_rgb_i = torch.concat(rgb_i_valid, dim=0)
31 | pts_rgb_i = pts_rgb_i * 0.5 + 0.5
32 | rot_i = torch.concat(rot_i_valid, dim=0)
33 | scale_i = torch.concat(scale_i_valid, dim=0)
34 | opacity_i = torch.concat(opacity_i_valid, dim=0)
35 |
36 | render_novel_i = render(data, i, pts_xyz_i, pts_rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color)
37 | render_novel_list.append(render_novel_i.unsqueeze(0))
38 |
39 | data['novel_view']['img_pred'] = torch.concat(render_novel_list, dim=0)
40 | return data
41 |
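42 | # Note (a summary of the calling convention used elsewhere in this repo): pts2render expects
43 | # data['lmain'] / data['rmain'] to carry 'img', 'pts_valid', 'xyz', 'rot_maps', 'scale_maps' and
44 | # 'opacity_maps' (filled by RtStereoHumanModel.flow2gsparms in lib/network.py), and
45 | # data['novel_view'] to carry the target image size plus the camera tensors filled by
46 | # lib.utils.get_novel_calib (FovX/FovY, view/projection matrices and camera centers).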
--------------------------------------------------------------------------------
/lib/TaichiRender.py:
--------------------------------------------------------------------------------
1 |
2 | import taichi as ti
3 | from lib.utils import *
4 | ti.init(ti.cuda)
5 |
6 |
7 | @ti.data_oriented
8 | class TaichiRenderBatch:
9 | def __init__(self, bs, res):
10 | self.res = res
11 | self.coord = ti.Vector.field(n=1, dtype=ti.i32, shape=(bs, res * res))
12 |
13 | @ti.kernel
14 | def render_respective_color(self, pts: ti.types.ndarray(), pts_mask: ti.types.ndarray(),
15 | render_depth: ti.types.ndarray(), render_color: ti.types.ndarray()):
16 | for B, i in self.coord:
17 | if pts_mask[B, i, 0] < 0.5:
18 | continue
19 | IX, IY = ti.cast(pts[B, i, 0], ti.i32), ti.cast(pts[B, i, 1], ti.i32)
20 | IX = ti.min(self.res - 1, ti.max(IX, 0))
21 | IY = ti.min(self.res - 1, ti.max(IY, 0))
22 | if pts[B, i, 2] >= ti.atomic_max(render_depth[B, 0, IY, IX], pts[B, i, 2]):
23 | for k in ti.static(range(3)):
24 | render_color[B, k, IY, IX] = pts[B, i, k + 3]
25 |
26 | def flow2render(self, data):
27 | novel_view_calib = torch.matmul(data['novel_view']['intr'], data['novel_view']['extr'])
28 | B = novel_view_calib.shape[0]
29 |
30 | taichi_pts_list = []
31 | taichi_mask_list = []
32 |
33 | for view in ['lmain', 'rmain']:
34 | data_select = data[view]
35 | depth_pred = flow2depth(data_select).clone()
36 | valid = depth_pred != 0
37 |
38 | pts = depth2pc(depth_pred, data_select['extr'], data_select['intr'])
39 | valid = valid.view(B, -1, 1).squeeze(2)
40 | pts_valid = torch.zeros_like(pts)
41 | pts_valid[valid] = pts[valid]
42 |
43 | pts_valid = perspective(pts_valid, novel_view_calib)
44 | pts_valid[:, :, 2:] = 1.0 / (pts_valid[:, :, 2:] + 1e-8)
45 |
46 | img_valid = torch.zeros_like(pts_valid)
47 | img_valid[valid] = data_select['img'].permute(0, 2, 3, 1).view(B, -1, 3)[valid]
48 | taichi_pts = torch.cat((pts_valid, img_valid), dim=2)
49 | taichi_mask = valid.view(B, -1, 1).float()
50 | taichi_pts_list.append(taichi_pts)
51 | taichi_mask_list.append(taichi_mask)
52 |
53 | render_depth = torch.zeros((B, 1, self.res, self.res), device=pts.device).float()
54 | min_value = -1
55 | render_color = min_value + torch.zeros((B, 3, self.res, self.res), device=pts.device).float()
56 | for i in range(2):
57 | self.render_respective_color(taichi_pts_list[i], taichi_mask_list[i], render_depth, render_color)
58 | data['novel_view']['img_pred'] = render_color
59 |
60 | return data
61 |
--------------------------------------------------------------------------------
/lib/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 |
16 |
17 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
18 | Rt = np.zeros((4, 4))
19 | Rt[:3, :3] = R.transpose()
20 | Rt[:3, 3] = t
21 | Rt[3, 3] = 1.0
22 |
23 | C2W = np.linalg.inv(Rt)
24 | cam_center = C2W[:3, 3]
25 | cam_center = (cam_center + translate) * scale
26 | C2W[:3, 3] = cam_center
27 | Rt = np.linalg.inv(C2W)
28 | return np.float32(Rt)
29 |
30 |
31 | def getProjectionMatrix(znear, zfar, K, h, w):
32 | near_fx = znear / K[0, 0]
33 | near_fy = znear / K[1, 1]
34 | left = - (w - K[0, 2]) * near_fx
35 | right = K[0, 2] * near_fx
36 | bottom = (K[1, 2] - h) * near_fy
37 | top = K[1, 2] * near_fy
38 |
39 | P = torch.zeros(4, 4)
40 | z_sign = 1.0
41 | P[0, 0] = 2.0 * znear / (right - left)
42 | P[1, 1] = 2.0 * znear / (top - bottom)
43 | P[0, 2] = (right + left) / (right - left)
44 | P[1, 2] = (top + bottom) / (top - bottom)
45 | P[3, 2] = z_sign
46 | P[2, 2] = z_sign * zfar / (zfar - znear)
47 | P[2, 3] = -(zfar * znear) / (zfar - znear)
48 | return P
49 |
50 |
51 | def focal2fov(focal, pixels):
52 | return 2*math.atan(pixels/(2*focal))
53 |
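54 | # A hedged worked example of focal2fov: the rendered training views in
55 | # prepare_data/render_data.py use fx = 0.8 * width, so for a 1024-pixel-wide image
56 | # focal2fov(819.2, 1024) = 2 * atan(0.625) ≈ 1.117 rad ≈ 64 degrees.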
--------------------------------------------------------------------------------
/lib/gs_parm_network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | from core.extractor import UnetExtractor, ResidualBlock
5 |
6 |
7 | class GSRegresser(nn.Module):
8 | def __init__(self, cfg, rgb_dim=3, depth_dim=1, norm_fn='group'):
9 | super().__init__()
10 | self.rgb_dims = cfg.raft.encoder_dims
11 | self.depth_dims = cfg.gsnet.encoder_dims
12 | self.decoder_dims = cfg.gsnet.decoder_dims
13 | self.head_dim = cfg.gsnet.parm_head_dim
14 | self.depth_encoder = UnetExtractor(in_channel=depth_dim, encoder_dim=self.depth_dims)
15 |
16 | self.decoder3 = nn.Sequential(
17 | ResidualBlock(self.rgb_dims[2]+self.depth_dims[2], self.decoder_dims[2], norm_fn=norm_fn),
18 | ResidualBlock(self.decoder_dims[2], self.decoder_dims[2], norm_fn=norm_fn)
19 | )
20 |
21 | self.decoder2 = nn.Sequential(
22 | ResidualBlock(self.rgb_dims[1]+self.depth_dims[1]+self.decoder_dims[2], self.decoder_dims[1], norm_fn=norm_fn),
23 | ResidualBlock(self.decoder_dims[1], self.decoder_dims[1], norm_fn=norm_fn)
24 | )
25 |
26 | self.decoder1 = nn.Sequential(
27 | ResidualBlock(self.rgb_dims[0]+self.depth_dims[0]+self.decoder_dims[1], self.decoder_dims[0], norm_fn=norm_fn),
28 | ResidualBlock(self.decoder_dims[0], self.decoder_dims[0], norm_fn=norm_fn)
29 | )
30 | self.up = nn.Upsample(scale_factor=2, mode="bilinear")
31 | self.out_conv = nn.Conv2d(self.decoder_dims[0]+rgb_dim+depth_dim, self.head_dim, kernel_size=3, padding=1)
32 | self.out_relu = nn.ReLU(inplace=True)
33 |
34 | self.rot_head = nn.Sequential(
35 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(self.head_dim, 4, kernel_size=1),
38 | )
39 | self.scale_head = nn.Sequential(
40 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(self.head_dim, 3, kernel_size=1),
43 | nn.Softplus(beta=100)
44 | )
45 | self.opacity_head = nn.Sequential(
46 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(self.head_dim, 1, kernel_size=1),
49 | nn.Sigmoid()
50 | )
51 |
52 | def forward(self, img, depth, img_feat):
53 | img_feat1, img_feat2, img_feat3 = img_feat
54 | depth_feat1, depth_feat2, depth_feat3 = self.depth_encoder(depth)
55 |
56 | feat3 = torch.concat([img_feat3, depth_feat3], dim=1)
57 | feat2 = torch.concat([img_feat2, depth_feat2], dim=1)
58 | feat1 = torch.concat([img_feat1, depth_feat1], dim=1)
59 |
60 | up3 = self.decoder3(feat3)
61 | up3 = self.up(up3)
62 | up2 = self.decoder2(torch.cat([up3, feat2], dim=1))
63 | up2 = self.up(up2)
64 | up1 = self.decoder1(torch.cat([up2, feat1], dim=1))
65 |
66 | up1 = self.up(up1)
67 | out = torch.cat([up1, img, depth], dim=1)
68 | out = self.out_conv(out)
69 | out = self.out_relu(out)
70 |
71 | # rot head
72 | rot_out = self.rot_head(out)
73 | rot_out = torch.nn.functional.normalize(rot_out, dim=1)
74 |
75 | # scale head
76 | scale_out = torch.clamp_max(self.scale_head(out), 0.01)
77 |
78 | # opacity head
79 | opacity_out = self.opacity_head(out)
80 |
81 | return rot_out, scale_out, opacity_out
82 |
--------------------------------------------------------------------------------
/lib/human_loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 | import numpy as np
4 | import os
5 | from PIL import Image
6 | import cv2
7 | import torch
8 | from lib.graphics_utils import getWorld2View2, getProjectionMatrix, focal2fov
9 | from pathlib import Path
10 | import logging
11 | import json
12 | from tqdm import tqdm
13 |
14 |
15 | def save_np_to_json(parm, save_name):
16 | for key in parm.keys():
17 | parm[key] = parm[key].tolist()
18 | with open(save_name, 'w') as file:
19 | json.dump(parm, file, indent=1)
20 |
21 |
22 | def load_json_to_np(parm_name):
23 | with open(parm_name, 'r') as f:
24 | parm = json.load(f)
25 | for key in parm.keys():
26 | parm[key] = np.array(parm[key])
27 | return parm
28 |
29 |
30 | def depth2pts(depth, extrinsic, intrinsic):
31 | # depth H W extrinsic 3x4 intrinsic 3x3 pts map H W 3
32 | rot = extrinsic[:3, :3]
33 | trans = extrinsic[:3, 3:]
34 | S, S = depth.shape
35 |
36 | y, x = torch.meshgrid(torch.linspace(0.5, S-0.5, S, device=depth.device),
37 | torch.linspace(0.5, S-0.5, S, device=depth.device))
38 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1) # H W 3
39 |
40 | pts_2d[..., 2] = 1.0 / (depth + 1e-8)
41 | pts_2d[..., 0] -= intrinsic[0, 2]
42 | pts_2d[..., 1] -= intrinsic[1, 2]
43 | pts_2d_xy = pts_2d[..., :2] * pts_2d[..., 2:]
44 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1)
45 |
46 | pts_2d[..., 0] /= intrinsic[0, 0]
47 | pts_2d[..., 1] /= intrinsic[1, 1]
48 | pts_2d = pts_2d.reshape(-1, 3).T
49 | pts = rot.T @ pts_2d - rot.T @ trans
50 | return pts.T.view(S, S, 3)
51 |
52 |
53 | def pts2depth(ptsmap, extrinsic, intrinsic):
54 | S, S, _ = ptsmap.shape
55 | pts = ptsmap.view(-1, 3).T
56 | calib = intrinsic @ extrinsic
57 | pts = calib[:3, :3] @ pts
58 | pts = pts + calib[:3, 3:4]
59 | pts[:2, :] /= (pts[2:, :] + 1e-8)
60 | depth = 1.0 / (pts[2, :].view(S, S) + 1e-8)
61 | return depth
62 |
63 |
64 | def stereo_pts2flow(pts0, pts1, rectify0, rectify1, Tf_x):
65 | new_extr0, new_intr0, rectify_mat0_x, rectify_mat0_y = rectify0
66 | new_extr1, new_intr1, rectify_mat1_x, rectify_mat1_y = rectify1
67 | new_depth0 = pts2depth(torch.FloatTensor(pts0), torch.FloatTensor(new_extr0), torch.FloatTensor(new_intr0))
68 | new_depth1 = pts2depth(torch.FloatTensor(pts1), torch.FloatTensor(new_extr1), torch.FloatTensor(new_intr1))
69 | new_depth0 = new_depth0.detach().numpy()
70 | new_depth1 = new_depth1.detach().numpy()
71 | new_depth0 = cv2.remap(new_depth0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR)
72 | new_depth1 = cv2.remap(new_depth1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR)
73 |
74 | offset0 = new_intr1[0, 2] - new_intr0[0, 2]
75 | disparity0 = -new_depth0 * Tf_x
76 | flow0 = offset0 - disparity0
77 |
78 | offset1 = new_intr0[0, 2] - new_intr1[0, 2]
79 | disparity1 = -new_depth1 * (-Tf_x)
80 | flow1 = offset1 - disparity1
81 |
82 | flow0[new_depth0 < 0.05] = 0
83 | flow1[new_depth1 < 0.05] = 0
84 |
85 | return flow0, flow1
86 |
87 |
88 | def read_img(name):
89 | img = np.array(Image.open(name))
90 | return img
91 |
92 |
93 | def read_depth(name):
94 | return cv2.imread(name, cv2.IMREAD_UNCHANGED).astype(np.float32) / 2.0 ** 15
95 |
96 |
97 | class StereoHumanDataset(Dataset):
98 | def __init__(self, opt, phase='train'):
99 | self.opt = opt
100 | self.use_processed_data = opt.use_processed_data
101 | self.phase = phase
102 | if self.phase == 'train':
103 | self.data_root = os.path.join(opt.data_root, 'train')
104 | elif self.phase == 'val':
105 | self.data_root = os.path.join(opt.data_root, 'val')
106 | elif self.phase == 'test':
107 | self.data_root = opt.test_data_root
108 |
109 | self.img_path = os.path.join(self.data_root, 'img/%s/%d.jpg')
110 | self.img_hr_path = os.path.join(self.data_root, 'img/%s/%d_hr.jpg')
111 | self.mask_path = os.path.join(self.data_root, 'mask/%s/%d.png')
112 | self.depth_path = os.path.join(self.data_root, 'depth/%s/%d.png')
113 | self.intr_path = os.path.join(self.data_root, 'parm/%s/%d_intrinsic.npy')
114 | self.extr_path = os.path.join(self.data_root, 'parm/%s/%d_extrinsic.npy')
115 | self.sample_list = sorted(list(os.listdir(os.path.join(self.data_root, 'img'))))
116 |
117 | if self.use_processed_data:
118 | self.local_data_root = os.path.join(opt.data_root, 'rectified_local', self.phase)
119 | self.local_img_path = os.path.join(self.local_data_root, 'img/%s/%d.jpg')
120 | self.local_mask_path = os.path.join(self.local_data_root, 'mask/%s/%d.png')
121 | self.local_flow_path = os.path.join(self.local_data_root, 'flow/%s/%d.npy')
122 | self.local_valid_path = os.path.join(self.local_data_root, 'valid/%s/%d.png')
123 | self.local_parm_path = os.path.join(self.local_data_root, 'parm/%s/%d_%d.json')
124 |
125 | if os.path.exists(self.local_data_root):
126 | assert len(os.listdir(os.path.join(self.local_data_root, 'img'))) == len(self.sample_list)
127 | logging.info(f"Using local data in {self.local_data_root} ...")
128 | else:
129 | self.save_local_stereo_data()
130 |
131 | def save_local_stereo_data(self):
132 | logging.info(f"Generating data to {self.local_data_root} ...")
133 | for sample_name in tqdm(self.sample_list):
134 | view0_data = self.load_single_view(sample_name, self.opt.source_id[0], hr_img=False,
135 | require_mask=True, require_pts=True)
136 | view1_data = self.load_single_view(sample_name, self.opt.source_id[1], hr_img=False,
137 | require_mask=True, require_pts=True)
138 | lmain_stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data)
139 |
140 | for sub_dir in ['/img/', '/mask/', '/flow/', '/valid/', '/parm/']:
141 | Path(self.local_data_root + sub_dir + str(sample_name)).mkdir(exist_ok=True, parents=True)
142 |
143 | img0_save_name = self.local_img_path % (sample_name, self.opt.source_id[0])
144 | mask0_save_name = self.local_mask_path % (sample_name, self.opt.source_id[0])
145 | img1_save_name = self.local_img_path % (sample_name, self.opt.source_id[1])
146 | mask1_save_name = self.local_mask_path % (sample_name, self.opt.source_id[1])
147 | flow0_save_name = self.local_flow_path % (sample_name, self.opt.source_id[0])
148 | valid0_save_name = self.local_valid_path % (sample_name, self.opt.source_id[0])
149 | flow1_save_name = self.local_flow_path % (sample_name, self.opt.source_id[1])
150 | valid1_save_name = self.local_valid_path % (sample_name, self.opt.source_id[1])
151 | parm_save_name = self.local_parm_path % (sample_name, self.opt.source_id[0], self.opt.source_id[1])
152 |
153 | Image.fromarray(lmain_stereo_np['img0']).save(img0_save_name, quality=95)
154 | Image.fromarray(lmain_stereo_np['mask0']).save(mask0_save_name)
155 | Image.fromarray(lmain_stereo_np['img1']).save(img1_save_name, quality=95)
156 | Image.fromarray(lmain_stereo_np['mask1']).save(mask1_save_name)
157 | np.save(flow0_save_name, lmain_stereo_np['flow0'].astype(np.float16))
158 | Image.fromarray(lmain_stereo_np['valid0']).save(valid0_save_name)
159 | np.save(flow1_save_name, lmain_stereo_np['flow1'].astype(np.float16))
160 | Image.fromarray(lmain_stereo_np['valid1']).save(valid1_save_name)
161 | save_np_to_json(lmain_stereo_np['camera'], parm_save_name)
162 |
163 | logging.info("Generating data Done!")
164 |
165 | def load_local_stereo_data(self, sample_name):
166 | img0_name = self.local_img_path % (sample_name, self.opt.source_id[0])
167 | mask0_name = self.local_mask_path % (sample_name, self.opt.source_id[0])
168 | img1_name = self.local_img_path % (sample_name, self.opt.source_id[1])
169 | mask1_name = self.local_mask_path % (sample_name, self.opt.source_id[1])
170 | flow0_name = self.local_flow_path % (sample_name, self.opt.source_id[0])
171 | flow1_name = self.local_flow_path % (sample_name, self.opt.source_id[1])
172 | valid0_name = self.local_valid_path % (sample_name, self.opt.source_id[0])
173 | valid1_name = self.local_valid_path % (sample_name, self.opt.source_id[1])
174 | parm_name = self.local_parm_path % (sample_name, self.opt.source_id[0], self.opt.source_id[1])
175 |
176 | stereo_data = {
177 | 'img0': read_img(img0_name),
178 | 'mask0': read_img(mask0_name),
179 | 'img1': read_img(img1_name),
180 | 'mask1': read_img(mask1_name),
181 | 'camera': load_json_to_np(parm_name),
182 | 'flow0': np.load(flow0_name),
183 | 'valid0': read_img(valid0_name),
184 | 'flow1': np.load(flow1_name),
185 | 'valid1': read_img(valid1_name)
186 | }
187 |
188 | return stereo_data
189 |
190 | def load_single_view(self, sample_name, source_id, hr_img=False, require_mask=True, require_pts=True):
191 | img_name = self.img_path % (sample_name, source_id)
192 | image_hr_name = self.img_hr_path % (sample_name, source_id)
193 | mask_name = self.mask_path % (sample_name, source_id)
194 | depth_name = self.depth_path % (sample_name, source_id)
195 | intr_name = self.intr_path % (sample_name, source_id)
196 | extr_name = self.extr_path % (sample_name, source_id)
197 |
198 | intr, extr = np.load(intr_name), np.load(extr_name)
199 | mask, pts = None, None
200 | if hr_img:
201 | img = read_img(image_hr_name)
202 | intr[:2] *= 2
203 | else:
204 | img = read_img(img_name)
205 | if require_mask:
206 | mask = read_img(mask_name)
207 | if require_pts and os.path.exists(depth_name):
208 | depth = read_depth(depth_name)
209 | pts = depth2pts(torch.FloatTensor(depth), torch.FloatTensor(extr), torch.FloatTensor(intr))
210 |
211 | return img, mask, intr, extr, pts
212 |
213 | def get_novel_view_tensor(self, sample_name, view_id):
214 | img, _, intr, extr, _ = self.load_single_view(sample_name, view_id, hr_img=self.opt.use_hr_img,
215 | require_mask=False, require_pts=False)
216 | height, width = img.shape[:2]
217 | img = torch.from_numpy(img).permute(2, 0, 1)
218 | img = img / 255.0
219 |
220 | R = np.array(extr[:3, :3], np.float32).reshape(3, 3).transpose(1, 0)
221 | T = np.array(extr[:3, 3], np.float32)
222 |
223 | FovX = focal2fov(intr[0, 0], width)
224 | FovY = focal2fov(intr[1, 1], height)
225 | projection_matrix = getProjectionMatrix(znear=self.opt.znear, zfar=self.opt.zfar, K=intr, h=height, w=width).transpose(0, 1)
226 | world_view_transform = torch.tensor(getWorld2View2(R, T, np.array(self.opt.trans), self.opt.scale)).transpose(0, 1)
227 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
228 | camera_center = world_view_transform.inverse()[3, :3]
229 |
230 | novel_view_data = {
231 | 'view_id': torch.IntTensor([view_id]),
232 | 'img': img,
233 | 'extr': torch.FloatTensor(extr),
234 | 'FovX': FovX,
235 | 'FovY': FovY,
236 | 'width': width,
237 | 'height': height,
238 | 'world_view_transform': world_view_transform,
239 | 'full_proj_transform': full_proj_transform,
240 | 'camera_center': camera_center
241 | }
242 |
243 | return novel_view_data
244 |
245 | def get_rectified_stereo_data(self, main_view_data, ref_view_data):
246 | img0, mask0, intr0, extr0, pts0 = main_view_data
247 | img1, mask1, intr1, extr1, pts1 = ref_view_data
248 |
249 | H, W = self.opt.src_res, self.opt.src_res
250 | r0, t0 = extr0[:3, :3], extr0[:3, 3:]
251 | r1, t1 = extr1[:3, :3], extr1[:3, 3:]
252 | inv_r0 = r0.T
253 | inv_t0 = - r0.T @ t0
254 | E0 = np.eye(4)
255 | E0[:3, :3], E0[:3, 3:] = inv_r0, inv_t0
256 | E1 = np.eye(4)
257 | E1[:3, :3], E1[:3, 3:] = r1, t1
258 | E = E1 @ E0
259 | R, T = E[:3, :3], E[:3, 3]
260 | dist0, dist1 = np.zeros(4), np.zeros(4)
261 |
262 | R0, R1, P0, P1, _, _, _ = cv2.stereoRectify(intr0, dist0, intr1, dist1, (W, H), R, T, flags=0)
263 |
264 | new_extr0 = R0 @ extr0
265 | new_intr0 = P0[:3, :3]
266 | new_extr1 = R1 @ extr1
267 | new_intr1 = P1[:3, :3]
268 | Tf_x = np.array(P1[0, 3])
269 |
270 | camera = {
271 | 'intr0': new_intr0,
272 | 'intr1': new_intr1,
273 | 'extr0': new_extr0,
274 | 'extr1': new_extr1,
275 | 'Tf_x': Tf_x
276 | }
277 |
278 | rectify_mat0_x, rectify_mat0_y = cv2.initUndistortRectifyMap(intr0, dist0, R0, P0, (W, H), cv2.CV_32FC1)
279 | new_img0 = cv2.remap(img0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR)
280 | new_mask0 = cv2.remap(mask0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR)
281 | rectify_mat1_x, rectify_mat1_y = cv2.initUndistortRectifyMap(intr1, dist1, R1, P1, (W, H), cv2.CV_32FC1)
282 | new_img1 = cv2.remap(img1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR)
283 | new_mask1 = cv2.remap(mask1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR)
284 | rectify0 = new_extr0, new_intr0, rectify_mat0_x, rectify_mat0_y
285 | rectify1 = new_extr1, new_intr1, rectify_mat1_x, rectify_mat1_y
286 |
287 | stereo_data = {
288 | 'img0': new_img0,
289 | 'mask0': new_mask0,
290 | 'img1': new_img1,
291 | 'mask1': new_mask1,
292 | 'camera': camera
293 | }
294 |
295 | if pts0 is not None:
296 | flow0, flow1 = stereo_pts2flow(pts0, pts1, rectify0, rectify1, Tf_x)
297 |
298 | kernel = np.ones((3, 3), dtype=np.uint8)
299 | flow_eroded, valid_eroded = [], []
300 | for (flow, new_mask) in [(flow0, new_mask0), (flow1, new_mask1)]:
301 | valid = (new_mask.copy()[:, :, 0] / 255.0).astype(np.float32)
302 |                 valid = cv2.erode(valid, kernel, iterations=1)
303 | valid[valid >= 0.66] = 1.0
304 | valid[valid < 0.66] = 0.0
305 | flow *= valid
306 | valid *= 255.0
307 | flow_eroded.append(flow)
308 | valid_eroded.append(valid)
309 |
310 | stereo_data.update({
311 | 'flow0': flow_eroded[0],
312 | 'valid0': valid_eroded[0].astype(np.uint8),
313 | 'flow1': flow_eroded[1],
314 | 'valid1': valid_eroded[1].astype(np.uint8)
315 | })
316 |
317 | return stereo_data
318 |
319 | def stereo_to_dict_tensor(self, stereo_data, subject_name):
320 | img_tensor, mask_tensor = [], []
321 | for (img_view, mask_view) in [('img0', 'mask0'), ('img1', 'mask1')]:
322 | img = torch.from_numpy(stereo_data[img_view]).permute(2, 0, 1)
323 | img = 2 * (img / 255.0) - 1.0
324 | mask = torch.from_numpy(stereo_data[mask_view]).permute(2, 0, 1).float()
325 | mask = mask / 255.0
326 |
327 | img = img * mask
328 | mask[mask < 0.5] = 0.0
329 | mask[mask >= 0.5] = 1.0
330 | img_tensor.append(img)
331 | mask_tensor.append(mask)
332 |
333 | lmain_data = {
334 | 'img': img_tensor[0],
335 | 'mask': mask_tensor[0],
336 | 'intr': torch.FloatTensor(stereo_data['camera']['intr0']),
337 | 'ref_intr': torch.FloatTensor(stereo_data['camera']['intr1']),
338 | 'extr': torch.FloatTensor(stereo_data['camera']['extr0']),
339 | 'Tf_x': torch.FloatTensor(stereo_data['camera']['Tf_x'])
340 | }
341 |
342 | rmain_data = {
343 | 'img': img_tensor[1],
344 | 'mask': mask_tensor[1],
345 | 'intr': torch.FloatTensor(stereo_data['camera']['intr1']),
346 | 'ref_intr': torch.FloatTensor(stereo_data['camera']['intr0']),
347 | 'extr': torch.FloatTensor(stereo_data['camera']['extr1']),
348 | 'Tf_x': -torch.FloatTensor(stereo_data['camera']['Tf_x'])
349 | }
350 |
351 | if 'flow0' in stereo_data:
352 | flow_tensor, valid_tensor = [], []
353 | for (flow_view, valid_view) in [('flow0', 'valid0'), ('flow1', 'valid1')]:
354 | flow = torch.from_numpy(stereo_data[flow_view])
355 | flow = torch.unsqueeze(flow, dim=0)
356 | flow_tensor.append(flow)
357 |
358 | valid = torch.from_numpy(stereo_data[valid_view])
359 | valid = torch.unsqueeze(valid, dim=0)
360 | valid = valid / 255.0
361 | valid_tensor.append(valid)
362 |
363 | lmain_data['flow'], lmain_data['valid'] = flow_tensor[0], valid_tensor[0]
364 | rmain_data['flow'], rmain_data['valid'] = flow_tensor[1], valid_tensor[1]
365 |
366 | return {'name': subject_name, 'lmain': lmain_data, 'rmain': rmain_data}
367 |
368 | def get_item(self, index, novel_id=None):
369 | sample_id = index % len(self.sample_list)
370 | sample_name = self.sample_list[sample_id]
371 |
372 | if self.use_processed_data:
373 | stereo_np = self.load_local_stereo_data(sample_name)
374 | else:
375 | view0_data = self.load_single_view(sample_name, self.opt.source_id[0], hr_img=False,
376 | require_mask=True, require_pts=True)
377 | view1_data = self.load_single_view(sample_name, self.opt.source_id[1], hr_img=False,
378 | require_mask=True, require_pts=True)
379 | stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data)
380 | dict_tensor = self.stereo_to_dict_tensor(stereo_np, sample_name)
381 |
382 | if novel_id:
383 | novel_id = np.random.choice(novel_id)
384 | dict_tensor.update({
385 | 'novel_view': self.get_novel_view_tensor(sample_name, novel_id)
386 | })
387 |
388 | return dict_tensor
389 |
390 | def get_test_item(self, index, source_id):
391 | sample_id = index % len(self.sample_list)
392 | sample_name = self.sample_list[sample_id]
393 |
394 | if self.use_processed_data:
395 |             logging.error('test data loader does not support processed data')
396 |
397 | view0_data = self.load_single_view(sample_name, source_id[0], hr_img=False, require_mask=True, require_pts=False)
398 | view1_data = self.load_single_view(sample_name, source_id[1], hr_img=False, require_mask=True, require_pts=False)
399 | lmain_intr_ori, lmain_extr_ori = view0_data[2], view0_data[3]
400 | rmain_intr_ori, rmain_extr_ori = view1_data[2], view1_data[3]
401 | stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data)
402 | dict_tensor = self.stereo_to_dict_tensor(stereo_np, sample_name)
403 |
404 | dict_tensor['lmain']['intr_ori'] = torch.FloatTensor(lmain_intr_ori)
405 | dict_tensor['rmain']['intr_ori'] = torch.FloatTensor(rmain_intr_ori)
406 | dict_tensor['lmain']['extr_ori'] = torch.FloatTensor(lmain_extr_ori)
407 | dict_tensor['rmain']['extr_ori'] = torch.FloatTensor(rmain_extr_ori)
408 |
409 | img_len = self.opt.src_res * 2 if self.opt.use_hr_img else self.opt.src_res
410 | novel_dict = {
411 | 'height': torch.IntTensor([img_len]),
412 | 'width': torch.IntTensor([img_len])
413 | }
414 |
415 | dict_tensor.update({
416 | 'novel_view': novel_dict
417 | })
418 |
419 | return dict_tensor
420 |
421 | def __getitem__(self, index):
422 | if self.phase == 'train':
423 | return self.get_item(index, novel_id=self.opt.train_novel_id)
424 | elif self.phase == 'val':
425 | return self.get_item(index, novel_id=self.opt.val_novel_id)
426 |
427 | def __len__(self):
428 | self.train_boost = 50
429 | self.val_boost = 200
430 | if self.phase == 'train':
431 | return len(self.sample_list) * self.train_boost
432 | elif self.phase == 'val':
433 | return len(self.sample_list) * self.val_boost
434 | else:
435 | return len(self.sample_list)
436 |
--------------------------------------------------------------------------------
/lib/loss.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from math import exp
6 |
7 |
8 | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9):
9 | """ Loss function defined over sequence of flow predictions """
10 |
11 | n_predictions = len(flow_preds)
12 | flow_loss = 0.0
13 |
14 | valid = (valid >= 0.5)
15 | assert not torch.isinf(flow_gt[valid.bool()]).any()
16 |
17 | for i in range(n_predictions):
18 | # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
19 | adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
20 | i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
21 | i_loss = (flow_preds[i] - flow_gt).abs()
22 | flow_loss += i_weight * i_loss[valid.bool()].mean()
23 |
24 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
25 | epe = epe.view(-1)[valid.view(-1)]
26 |
27 | metrics = {
28 | 'train_epe': epe.mean().item(),
29 | 'train_1px': (epe < 1).float().mean().item(),
30 | 'train_3px': (epe < 3).float().mean().item()
31 | }
32 |
33 | return flow_loss, metrics
34 |
35 |
36 | def l1_loss(network_output, gt):
37 | return torch.abs((network_output - gt)).mean()
38 |
39 |
40 | def gaussian(window_size, sigma):
41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
42 | return gauss / gauss.sum()
43 |
44 |
45 | def create_window(window_size, channel):
46 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
47 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
48 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
49 | return window
50 |
51 |
52 | def ssim(img1, img2, window_size=11, size_average=True):
53 | channel = img1.size(-3)
54 | window = create_window(window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | return _ssim(img1, img2, window, window_size, channel, size_average)
61 |
62 |
63 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
64 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
65 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
66 |
67 | mu1_sq = mu1.pow(2)
68 | mu2_sq = mu2.pow(2)
69 | mu1_mu2 = mu1 * mu2
70 |
71 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
72 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
73 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
74 |
75 | C1 = 0.01 ** 2
76 | C2 = 0.03 ** 2
77 |
78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
79 |
80 | if size_average:
81 | return ssim_map.mean()
82 | else:
83 | return ssim_map.mean(1).mean(1).mean(1)
84 |
85 |
86 | def psnr(img1, img2):
87 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
88 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
89 |
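 90 |
 91 | if __name__ == '__main__':
 92 |     # A minimal sketch (random tensors) of the image metrics above:
 93 |     # ssim() averages the SSIM map, psnr() assumes intensities in [0, 1].
 94 |     pred = torch.rand(1, 3, 64, 64)
 95 |     gt = torch.rand(1, 3, 64, 64)
 96 |     print('l1  :', l1_loss(pred, gt).item())
 97 |     print('ssim:', ssim(pred, gt).item())
 98 |     print('psnr:', psnr(pred, gt).mean().item())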
--------------------------------------------------------------------------------
/lib/network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | from core.raft_stereo_human import RAFTStereoHuman
5 | from core.extractor import UnetExtractor
6 | from lib.gs_parm_network import GSRegresser
7 | from lib.loss import sequence_loss
8 | from lib.utils import flow2depth, depth2pc
9 | from torch.cuda.amp import autocast as autocast
10 |
11 |
12 | class RtStereoHumanModel(nn.Module):
13 | def __init__(self, cfg, with_gs_render=False):
14 | super().__init__()
15 | self.cfg = cfg
16 | self.with_gs_render = with_gs_render
17 | self.train_iters = self.cfg.raft.train_iters
18 | self.val_iters = self.cfg.raft.val_iters
19 |
20 | self.img_encoder = UnetExtractor(in_channel=3, encoder_dim=self.cfg.raft.encoder_dims)
21 | self.raft_stereo = RAFTStereoHuman(self.cfg.raft)
22 | if self.with_gs_render:
23 | self.gs_parm_regresser = GSRegresser(self.cfg, rgb_dim=3, depth_dim=1)
24 |
25 | def forward(self, data, is_train=True):
26 | bs = data['lmain']['img'].shape[0]
27 |
28 | image = torch.cat([data['lmain']['img'], data['rmain']['img']], dim=0)
29 | flow = torch.cat([data['lmain']['flow'], data['rmain']['flow']], dim=0) if is_train else None
30 | valid = torch.cat([data['lmain']['valid'], data['rmain']['valid']], dim=0) if is_train else None
31 |
32 | with autocast(enabled=self.cfg.raft.mixed_precision):
33 | img_feat = self.img_encoder(image)
34 |
35 | if is_train:
36 | flow_predictions = self.raft_stereo(img_feat[2], iters=self.train_iters)
37 | flow_loss, metrics = sequence_loss(flow_predictions, flow, valid)
38 | flow_pred_lmain, flow_pred_rmain = torch.split(flow_predictions[-1], [bs, bs])
39 |
40 | if not self.with_gs_render:
41 | data['lmain']['flow_pred'] = flow_pred_lmain.detach()
42 | data['rmain']['flow_pred'] = flow_pred_rmain.detach()
43 | return data, flow_loss, metrics
44 |
45 | data['lmain']['flow_pred'] = flow_pred_lmain
46 | data['rmain']['flow_pred'] = flow_pred_rmain
47 | data = self.flow2gsparms(image, img_feat, data, bs)
48 |
49 | return data, flow_loss, metrics
50 |
51 | else:
52 | flow_up = self.raft_stereo(img_feat[2], iters=self.val_iters, test_mode=True)
53 | flow_loss, metrics = None, None
54 |
55 | data['lmain']['flow_pred'] = flow_up[0]
56 | data['rmain']['flow_pred'] = flow_up[1]
57 |
58 | if not self.with_gs_render:
59 | return data, flow_loss, metrics
60 | data = self.flow2gsparms(image, img_feat, data, bs)
61 |
62 | return data, flow_loss, metrics
63 |
64 | def flow2gsparms(self, lr_img, lr_img_feat, data, bs):
65 | for view in ['lmain', 'rmain']:
66 | data[view]['depth'] = flow2depth(data[view])
67 | data[view]['xyz'] = depth2pc(data[view]['depth'], data[view]['extr'], data[view]['intr']).view(bs, -1, 3)
68 | valid = data[view]['depth'] != 0.0
69 | data[view]['pts_valid'] = valid.view(bs, -1)
70 |
71 | # regress gaussian parms
72 | lr_depth = torch.concat([data['lmain']['depth'], data['rmain']['depth']], dim=0)
73 | rot_maps, scale_maps, opacity_maps = self.gs_parm_regresser(lr_img, lr_depth, lr_img_feat)
74 |
75 | data['lmain']['rot_maps'], data['rmain']['rot_maps'] = torch.split(rot_maps, [bs, bs])
76 | data['lmain']['scale_maps'], data['rmain']['scale_maps'] = torch.split(scale_maps, [bs, bs])
77 | data['lmain']['opacity_maps'], data['rmain']['opacity_maps'] = torch.split(opacity_maps, [bs, bs])
78 |
79 | return data
80 |
81 |
--------------------------------------------------------------------------------
/lib/train_recoder.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import shutil
5 | import logging
6 | from pathlib import Path
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 |
10 | def file_backup(exp_path, cfg, train_script):
11 | shutil.copy(train_script, exp_path)
12 | shutil.copytree('core', os.path.join(exp_path, 'core'), dirs_exist_ok=True)
13 | shutil.copytree('config', os.path.join(exp_path, 'config'), dirs_exist_ok=True)
14 | shutil.copytree('gaussian_renderer', os.path.join(exp_path, 'gaussian_renderer'), dirs_exist_ok=True)
15 | for sub_dir in ['lib']:
16 | files = os.listdir(sub_dir)
17 | for file in files:
18 | Path(os.path.join(exp_path, sub_dir)).mkdir(exist_ok=True, parents=True)
19 | if file[-3:] == '.py':
20 | shutil.copy(os.path.join(sub_dir, file), os.path.join(exp_path, sub_dir))
21 |
22 | json_file_name = exp_path + '/cfg.json'
23 | with open(json_file_name, 'w') as json_file:
24 | json.dump(cfg, json_file, indent=2)
25 |
26 |
27 | class Logger:
28 | def __init__(self, scheduler, cfg):
29 | self.scheduler = scheduler
30 | self.sum_freq = cfg.loss_freq
31 | self.log_dir = cfg.logs_path
32 | self.total_steps = 0
33 | self.running_loss = {}
34 | self.writer = SummaryWriter(log_dir=self.log_dir)
35 |
36 | def _print_training_status(self):
37 | metrics_data = [self.running_loss[k] / self.sum_freq for k in sorted(self.running_loss.keys())]
38 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps, self.scheduler.get_last_lr()[0])
39 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
40 |
41 | # print the training status
42 | logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
43 |
44 | if self.writer is None:
45 | self.writer = SummaryWriter(log_dir=self.log_dir)
46 |
47 | for k in self.running_loss:
48 | self.writer.add_scalar(k, self.running_loss[k] / self.sum_freq, self.total_steps)
49 | self.running_loss[k] = 0.0
50 |
51 | def push(self, metrics):
52 | for key in metrics:
53 | if key not in self.running_loss:
54 | self.running_loss[key] = 0.0
55 |
56 | self.running_loss[key] += metrics[key]
57 |
58 | if self.total_steps and self.total_steps % self.sum_freq == 0:
59 | self._print_training_status()
60 | self.running_loss = {}
61 |
62 | self.total_steps += 1
63 |
64 | def write_dict(self, results, write_step):
65 | if self.writer is None:
66 | self.writer = SummaryWriter(log_dir=self.log_dir)
67 |
68 | for key in results:
69 | self.writer.add_scalar(key, results[key], write_step)
70 |
71 | def close(self):
72 | self.writer.close()
73 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as Rot
5 | from scipy.spatial.transform import Slerp
6 | from lib.graphics_utils import getWorld2View2, getProjectionMatrix, focal2fov
7 |
8 |
9 | def get_novel_calib(data, opt, ratio=0.5, intr_key='intr', extr_key='extr'):
10 | bs = data['lmain'][intr_key].shape[0]
11 | fovx_list, fovy_list, world_view_transform_list, full_proj_transform_list, camera_center_list = [], [], [], [], []
12 | for i in range(bs):
13 | intr0 = data['lmain'][intr_key][i, ...].cpu().numpy()
14 | intr1 = data['rmain'][intr_key][i, ...].cpu().numpy()
15 | extr0 = data['lmain'][extr_key][i, ...].cpu().numpy()
16 | extr1 = data['rmain'][extr_key][i, ...].cpu().numpy()
17 |
18 | rot0 = extr0[:3, :3]
19 | rot1 = extr1[:3, :3]
20 | rots = Rot.from_matrix(np.stack([rot0, rot1]))
21 | key_times = [0, 1]
22 | slerp = Slerp(key_times, rots)
23 | rot = slerp(ratio)
24 | npose = np.diag([1.0, 1.0, 1.0, 1.0])
25 | npose = npose.astype(np.float32)
26 | npose[:3, :3] = rot.as_matrix()
27 | npose[:3, 3] = ((1.0 - ratio) * extr0 + ratio * extr1)[:3, 3]
28 | extr_new = npose[:3, :]
29 | intr_new = ((1.0 - ratio) * intr0 + ratio * intr1)
30 |
31 | if opt.use_hr_img:
32 | intr_new[:2] *= 2
33 | width, height = data['novel_view']['width'][i], data['novel_view']['height'][i]
34 | R = np.array(extr_new[:3, :3], np.float32).reshape(3, 3).transpose(1, 0)
35 | T = np.array(extr_new[:3, 3], np.float32)
36 |
37 | FovX = focal2fov(intr_new[0, 0], width)
38 | FovY = focal2fov(intr_new[1, 1], height)
39 | projection_matrix = getProjectionMatrix(znear=opt.znear, zfar=opt.zfar, K=intr_new, h=height, w=width).transpose(0, 1)
40 | world_view_transform = torch.tensor(getWorld2View2(R, T, np.array(opt.trans), opt.scale)).transpose(0, 1)
41 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
42 | camera_center = world_view_transform.inverse()[3, :3]
43 |
44 | fovx_list.append(FovX)
45 | fovy_list.append(FovY)
46 | world_view_transform_list.append(world_view_transform.unsqueeze(0))
47 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0))
48 | camera_center_list.append(camera_center.unsqueeze(0))
49 |
50 | data['novel_view']['FovX'] = torch.FloatTensor(np.array(fovx_list)).cuda()
51 | data['novel_view']['FovY'] = torch.FloatTensor(np.array(fovy_list)).cuda()
52 | data['novel_view']['world_view_transform'] = torch.concat(world_view_transform_list).cuda()
53 | data['novel_view']['full_proj_transform'] = torch.concat(full_proj_transform_list).cuda()
54 | data['novel_view']['camera_center'] = torch.concat(camera_center_list).cuda()
55 | return data
56 |
57 |
58 | def get_novel_calib_for_show(data, ratio=0.5, intr_key='intr', extr_key='extr'):
59 | bs = data['lmain'][intr_key].shape[0]
60 | intr_list, extr_list = [], []
61 | data['novel_view'] = {}
62 | for i in range(bs):
63 | intr0 = data['lmain'][intr_key][i, ...].cpu().numpy()
64 | intr1 = data['rmain'][intr_key][i, ...].cpu().numpy()
65 | extr0 = data['lmain'][extr_key][i, ...].cpu().numpy()
66 | extr1 = data['rmain'][extr_key][i, ...].cpu().numpy()
67 |
68 | rot0 = extr0[:3, :3]
69 | rot1 = extr1[:3, :3]
70 | rots = Rot.from_matrix(np.stack([rot0, rot1]))
71 | key_times = [0, 1]
72 | slerp = Slerp(key_times, rots)
73 | rot = slerp(ratio)
74 | npose = np.diag([1.0, 1.0, 1.0, 1.0])
75 | npose = npose.astype(np.float32)
76 | npose[:3, :3] = rot.as_matrix()
77 | npose[:3, 3] = ((1.0 - ratio) * extr0 + ratio * extr1)[:3, 3]
78 | extr_new = npose[:3, :]
79 |
80 | intr_new = ((1.0 - ratio) * intr0 + ratio * intr1)
81 | intr_list.append(intr_new)
82 | extr_list.append(extr_new)
83 | data['novel_view']['intr'] = torch.FloatTensor(np.array(intr_list)).cuda()
84 | data['novel_view']['extr'] = torch.FloatTensor(np.array(extr_list)).cuda()
85 | return data
86 |
87 |
88 | def depth2pc(depth, extrinsic, intrinsic):
89 | B, C, S, S = depth.shape
90 | depth = depth[:, 0, :, :]
91 | rot = extrinsic[:, :3, :3]
92 | trans = extrinsic[:, :3, 3:]
93 |
94 | y, x = torch.meshgrid(torch.linspace(0.5, S-0.5, S, device=depth.device), torch.linspace(0.5, S-0.5, S, device=depth.device))
95 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # B S S 3
96 |
97 | pts_2d[..., 2] = 1.0 / (depth + 1e-8)
98 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2]
99 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2]
100 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:]
101 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1)
102 |
103 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None]
104 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None]
105 |
106 | pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1)
107 | rot_t = rot.permute(0, 2, 1)
108 | pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans)
109 |
110 | return pts.permute(0, 2, 1)
111 |
112 |
113 | def flow2depth(data):
114 | offset = data['ref_intr'][:, 0, 2] - data['intr'][:, 0, 2]
115 | offset = torch.broadcast_to(offset[:, None, None, None], data['flow_pred'].shape)
116 | disparity = offset - data['flow_pred']
117 | depth = -disparity / data['Tf_x'][:, None, None, None]
118 | depth *= data['mask'][:, :1, :, :]
119 |
120 | return depth
121 |
122 | def perspective(pts, calibs):
123 | pts = pts.permute(0, 2, 1)
124 | pts = torch.bmm(calibs[:, :3, :3], pts)
125 | pts = pts + calibs[:, :3, 3:4]
126 | pts[:, :2, :] /= pts[:, 2:, :]
127 | pts = pts.permute(0, 2, 1)
128 | return pts
129 |
--------------------------------------------------------------------------------
/prepare_data/MAKE_DATA.md:
--------------------------------------------------------------------------------
1 | # Data Documentation
2 |
3 | We provide a script for rendering training data from human scans; many thanks to [Ruizhi Shao](https://dsaurus.github.io/saurus/) for sharing this code. We take [THuman2.0](https://github.com/ytrock/THuman2.0-Dataset) as an example.
4 |
5 | - Download the THuman2.0 scan data from [This Link](https://github.com/ytrock/THuman2.0-Dataset) and the SMPL-X fitting parameters from [This Link](https://drive.google.com/file/d/1rnkGomScq3yxyM9auA-oHW6m_OJ5mlGL/view?usp=sharing). Then split the THuman2.0 scans into a train set and a validation set. We use the SMPL-X parameters to normalize the orientation of the human scans, but this step is not essential: comment out L133-L140 in [render_data.py](render_data.py#L133-L140) if you do not need to normalize the orientation of the THuman2.0 scans.
6 | ```
7 | ./Thuman2.0
8 | ├── THuman2.0_Smpl_X_Paras/
9 | ├── train/
10 | │ ├── 0004/
11 | │ │ ├── 0004.obj
12 | │ │ ├── material0.jpeg
13 | │ │ └── material0.mtl
14 | │ ├── 0005
15 | │ ├── 0007
16 | │ └── ...
17 | └── val
18 | ├── 0000
19 | ├── 0001
20 | ├── 0002
21 | └── ...
22 | ```
23 |
24 | - Set the correct ```thuman_root``` and ```save_root``` in [render_data.py](render_data.py#L214-L215). Adjust ```cam_nums```, ```scene_radius``` and the camera parameters so that the training data resembles your target real-world scenario; see the sketch below.
25 |
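26 | - As a minimal sketch (the paths below are placeholders, replace them with your own), the settings at the top of the ```__main__``` block in [render_data.py](render_data.py) might look like this after editing:
27 | ```python
28 | if __name__ == '__main__':
29 |     cam_nums = 16                         # number of camera positions around the subject
30 |     scene_radius = 2.0                    # camera distance to the look-at center
31 |     res = (1024, 1024)                    # source-view resolution
32 |     thuman_root = '/data/THuman2.0'       # folder containing the train/ and val/ scan splits
33 |     save_root = '/data/THuman2.0_render'  # output folder for the rendered img/depth/mask/parm data
34 | ```
35 |   Running ```python render_data.py``` from the ```prepare_data``` directory then writes ```train``` and ```val``` subfolders under ```save_root```.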
--------------------------------------------------------------------------------
/prepare_data/render_data.py:
--------------------------------------------------------------------------------
1 | import taichi_three as t3
2 | import numpy as np
3 | from taichi_three.transform import *
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | import os
7 | import cv2
8 | import pickle
9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
10 |
11 |
12 | def save(pid, data_id, vid, save_path, extr, intr, depth, img, mask, img_hr=None):
13 | img_save_path = os.path.join(save_path, 'img', data_id + '_' + '%03d' % pid)
14 | depth_save_path = os.path.join(save_path, 'depth', data_id + '_' + '%03d' % pid)
15 | mask_save_path = os.path.join(save_path, 'mask', data_id + '_' + '%03d' % pid)
16 | parm_save_path = os.path.join(save_path, 'parm', data_id + '_' + '%03d' % pid)
17 | Path(img_save_path).mkdir(exist_ok=True, parents=True)
18 | Path(parm_save_path).mkdir(exist_ok=True, parents=True)
19 | Path(mask_save_path).mkdir(exist_ok=True, parents=True)
20 | Path(depth_save_path).mkdir(exist_ok=True, parents=True)
21 |
22 | depth = depth * 2.0 ** 15
23 | cv2.imwrite(os.path.join(depth_save_path, '{}.png'.format(vid)), depth.astype(np.uint16))
24 | img = (np.clip(img, 0, 1) * 255.0 + 0.5).astype(np.uint8)[:, :, ::-1]
25 | mask = (np.clip(mask, 0, 1) * 255.0 + 0.5).astype(np.uint8)
26 | cv2.imwrite(os.path.join(img_save_path, '{}.jpg'.format(vid)), img)
27 | if img_hr is not None:
28 | img_hr = (np.clip(img_hr, 0, 1) * 255.0 + 0.5).astype(np.uint8)[:, :, ::-1]
29 | cv2.imwrite(os.path.join(img_save_path, '{}_hr.jpg'.format(vid)), img_hr)
30 | cv2.imwrite(os.path.join(mask_save_path, '{}.png'.format(vid)), mask)
31 | np.save(os.path.join(parm_save_path, '{}_intrinsic.npy'.format(vid)), intr)
32 | np.save(os.path.join(parm_save_path, '{}_extrinsic.npy'.format(vid)), extr)
33 |
34 |
35 | class StaticRenderer:
36 | def __init__(self, src_res):
37 | ti.init(arch=ti.cuda, device_memory_fraction=0.8)
38 | self.scene = t3.Scene()
39 | self.N = 10
40 | self.src_res = src_res
41 | self.hr_res = (src_res[0] * 2, src_res[1] * 2)
42 |
43 | def change_all(self):
44 | save_obj = []
45 | save_tex = []
46 | for model in self.scene.models:
47 | save_obj.append(model.init_obj)
48 | save_tex.append(model.init_tex)
49 | ti.init(arch=ti.cuda, device_memory_fraction=0.8)
50 | print('init')
51 | self.scene = t3.Scene()
52 | for i in range(len(save_obj)):
53 | model = t3.StaticModel(self.N, obj=save_obj[i], tex=save_tex[i])
54 | self.scene.add_model(model)
55 |
56 | def check_update(self, obj):
57 | temp_n = self.N
58 | self.N = max(obj['vi'].shape[0], self.N)
59 | self.N = max(obj['f'].shape[0], self.N)
60 | if not (obj['vt'] is None):
61 | self.N = max(obj['vt'].shape[0], self.N)
62 |
63 | if self.N > temp_n:
64 | self.N *= 2
65 | self.change_all()
66 | self.camera_light()
67 |
68 | def add_model(self, obj, tex=None):
69 | self.check_update(obj)
70 | model = t3.StaticModel(self.N, obj=obj, tex=tex)
71 | self.scene.add_model(model)
72 |
73 | def modify_model(self, index, obj, tex=None):
74 | self.check_update(obj)
75 | self.scene.models[index].init_obj = obj
76 | self.scene.models[index].init_tex = tex
77 | self.scene.models[index]._init()
78 |
79 | def camera_light(self):
80 | camera = t3.Camera(res=self.src_res)
81 | self.scene.add_camera(camera)
82 |
83 | camera_hr = t3.Camera(res=self.hr_res)
84 | self.scene.add_camera(camera_hr)
85 |
86 | light_dir = np.array([0, 0, 1])
87 | light_list = []
88 | for l in range(6):
89 | rotate = np.matmul(rotationX(math.radians(np.random.uniform(-30, 30))),
90 | rotationY(math.radians(360 // 6 * l)))
91 | dir = [*np.matmul(rotate, light_dir)]
92 | light = t3.Light(dir, color=[1.0, 1.0, 1.0])
93 | light_list.append(light)
94 | lights = t3.Lights(light_list)
95 | self.scene.add_lights(lights)
96 |
97 |
98 | def render_data(renderer, data_path, phase, data_id, save_path, cam_nums, res, dis=1.0, is_thuman=False):
99 | obj_path = os.path.join(data_path, phase, data_id, '%s.obj' % data_id)
100 | texture_path = data_path
101 | img_path = os.path.join(texture_path, phase, data_id, 'material0.jpeg')
102 | texture = cv2.imread(img_path)[:, :, ::-1]
103 | texture = np.ascontiguousarray(texture)
104 | texture = texture.swapaxes(0, 1)[:, ::-1, :]
105 | obj = t3.readobj(obj_path, scale=1)
106 |
107 | # height normalization
108 | vy_max = np.max(obj['vi'][:, 1])
109 | vy_min = np.min(obj['vi'][:, 1])
110 | human_height = 1.80 + np.random.uniform(-0.05, 0.05, 1)
111 | obj['vi'][:, :3] = obj['vi'][:, :3] / (vy_max - vy_min) * human_height
112 | obj['vi'][:, 1] -= np.min(obj['vi'][:, 1])
113 | look_at_center = np.array([0, 0.85, 0])
114 | base_cam_pitch = -8
115 |
116 | # randomly move the scan
117 | move_range = 0.1 if human_height < 1.80 else 0.05
118 | delta_x = np.max(obj['vi'][:, 0]) - np.min(obj['vi'][:, 0])
119 | delta_z = np.max(obj['vi'][:, 2]) - np.min(obj['vi'][:, 2])
120 | if delta_x > 1.0 or delta_z > 1.0:
121 | move_range = 0.01
122 | obj['vi'][:, 0] += np.random.uniform(-move_range, move_range, 1)
123 | obj['vi'][:, 2] += np.random.uniform(-move_range, move_range, 1)
124 |
125 | if len(renderer.scene.models) >= 1:
126 | renderer.modify_model(0, obj, texture)
127 | else:
128 | renderer.add_model(obj, texture)
129 |
130 | degree_interval = 360 / cam_nums
131 | angle_list1 = list(range(360-int(degree_interval//2), 360))
132 | angle_list2 = list(range(0, 0+int(degree_interval//2)))
133 | angle_list = angle_list1 + angle_list2
134 | angle_base = np.random.choice(angle_list, 1)[0]
135 | if is_thuman:
136 | # thuman needs a normalization of orientation
137 | smpl_path = os.path.join(data_path, 'THuman2.0_Smpl_X_Paras', data_id, 'smplx_param.pkl')
138 | with open(smpl_path, 'rb') as f:
139 | smpl_para = pickle.load(f)
140 |
141 | y_orient = smpl_para['global_orient'][0][1]
142 | angle_base += (y_orient*180.0/np.pi)
143 |
144 | for pid in range(cam_nums):
145 | angle = angle_base + pid * degree_interval
146 |
147 | def render(dis, angle, look_at_center, p, renderer, render_hr=False):
148 | ori_vec = np.array([0, 0, dis])
149 | rotate = np.matmul(rotationY(math.radians(angle)), rotationX(math.radians(p)))
150 | fwd = np.matmul(rotate, ori_vec)
151 | cam_pos = look_at_center + fwd
152 |
153 | x_min = 0
154 | y_min = -25
155 | cx = res[0] * 0.5
156 | cy = res[1] * 0.5
157 | fx = res[0] * 0.8
158 | fy = res[1] * 0.8
159 | _cx = cx - x_min
160 | _cy = cy - y_min
161 | renderer.scene.cameras[0].set_intrinsic(fx, fy, _cx, _cy)
162 | renderer.scene.cameras[0].set(pos=cam_pos, target=look_at_center)
163 | renderer.scene.cameras[0]._init()
164 |
165 | if render_hr:
166 | fx = res[0] * 0.8 * 2
167 | fy = res[1] * 0.8 * 2
168 | _cx = (res[0] * 0.5 - x_min) * 2
169 | _cy = (res[1] * 0.5 - y_min) * 2
170 | renderer.scene.cameras[1].set_intrinsic(fx, fy, _cx, _cy)
171 | renderer.scene.cameras[1].set(pos=cam_pos, target=look_at_center)
172 | renderer.scene.cameras[1]._init()
173 |
174 | renderer.scene.render()
175 | camera = renderer.scene.cameras[0]
176 | camera_hr = renderer.scene.cameras[1]
177 | extrinsic = camera.export_extrinsic()
178 | intrinsic = camera.export_intrinsic()
179 | depth_map = camera.zbuf.to_numpy().swapaxes(0, 1)
180 | img = camera.img.to_numpy().swapaxes(0, 1)
181 | img_hr = camera_hr.img.to_numpy().swapaxes(0, 1)
182 | mask = camera.mask.to_numpy().swapaxes(0, 1)
183 | return extrinsic, intrinsic, depth_map, img, mask, img_hr
184 |
185 | renderer.scene.render()
186 | camera = renderer.scene.cameras[0]
187 | extrinsic = camera.export_extrinsic()
188 | intrinsic = camera.export_intrinsic()
189 | depth_map = camera.zbuf.to_numpy().swapaxes(0, 1)
190 | img = camera.img.to_numpy().swapaxes(0, 1)
191 | mask = camera.mask.to_numpy().swapaxes(0, 1)
192 | return extrinsic, intrinsic, depth_map, img, mask
193 |
194 |
195 | extr, intr, depth, img, mask = render(dis, angle, look_at_center, base_cam_pitch, renderer)
196 | save(pid, data_id, 0, save_path, extr, intr, depth, img, mask)
197 | extr, intr, depth, img, mask = render(dis, (angle+degree_interval) % 360, look_at_center, base_cam_pitch, renderer)
198 | save(pid, data_id, 1, save_path, extr, intr, depth, img, mask)
199 |
200 | # three novel viewpoints between source views
201 | angle1 = (angle + (np.random.uniform() * degree_interval / 2)) % 360
202 | angle2 = (angle + degree_interval / 2) % 360
203 | angle3 = (angle + degree_interval - (np.random.uniform() * degree_interval / 2)) % 360
204 |
205 | extr, intr, depth, img, mask, img_hr = render(dis, angle1, look_at_center, base_cam_pitch, renderer, render_hr=True)
206 | save(pid, data_id, 2, save_path, extr, intr, depth, img, mask, img_hr)
207 | extr, intr, depth, img, mask, img_hr = render(dis, angle2, look_at_center, base_cam_pitch, renderer, render_hr=True)
208 | save(pid, data_id, 3, save_path, extr, intr, depth, img, mask, img_hr)
209 | extr, intr, depth, img, mask, img_hr = render(dis, angle3, look_at_center, base_cam_pitch, renderer, render_hr=True)
210 | save(pid, data_id, 4, save_path, extr, intr, depth, img, mask, img_hr)
211 |
212 |
213 | if __name__ == '__main__':
214 | cam_nums = 16
215 | scene_radius = 2.0
216 | res = (1024, 1024)
217 | thuman_root = 'PATH/TO/THuman2.0'
218 | save_root = 'PATH/TO/SAVE/RENDERED/DATA'
219 |
220 | np.random.seed(1314)
221 | renderer = StaticRenderer(src_res=res)
222 |
223 | for phase in ['train', 'val']:
224 | thuman_list = sorted(os.listdir(os.path.join(thuman_root, phase)))
225 | save_path = os.path.join(save_root, phase)
226 |
227 | for data_id in tqdm(thuman_list):
228 | render_data(renderer, thuman_root, phase, data_id, save_path, cam_nums, res, dis=scene_radius, is_thuman=True)
229 |
--------------------------------------------------------------------------------
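
A minimal sketch of the camera ring that render_data.py sweeps above, in plain NumPy: for each pid, views 0 and 1 are the two source cameras one interval apart, and views 2 to 4 are the novel supervision views between them, two of them randomly jittered. The degree_interval = 360 / cam_nums derivation and the zero angle_base are assumptions made for this sketch (the script itself offsets angle_base by the SMPL global orientation).

import numpy as np

cam_nums = 16                        # as in the __main__ block above
degree_interval = 360.0 / cam_nums   # assumption: one interval per camera pair
angle_base = 0.0                     # the SMPL y-orientation offset is ignored here

np.random.seed(1314)
for pid in range(cam_nums):
    angle = angle_base + pid * degree_interval
    src0 = angle % 360
    src1 = (angle + degree_interval) % 360
    novel1 = (angle + np.random.uniform() * degree_interval / 2) % 360
    novel2 = (angle + degree_interval / 2) % 360
    novel3 = (angle + degree_interval - np.random.uniform() * degree_interval / 2) % 360
    print(pid, [round(a, 1) for a in (src0, src1, novel1, novel2, novel3)])
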
/prepare_data/taichi_three/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import version as __version__
2 | from .scene import *
3 | from .model import *
4 | from .scatter import *
5 | from .geometry import *
6 | from .loader import *
7 | from .light import *
8 |
--------------------------------------------------------------------------------
/prepare_data/taichi_three/common.py:
--------------------------------------------------------------------------------
1 | class AutoInit:
2 | def init(self):
3 | if not hasattr(self, '_AutoInit_had_init'):
4 | self._init()
5 | self._AutoInit_had_init = True
6 |
7 | def _init(self):
8 | raise NotImplementedError
9 |
--------------------------------------------------------------------------------
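
AutoInit gives every subclass a lazy, run-once init() on top of its own _init(); a trivial usage sketch (the import assumes the script sits next to the taichi_three package):

from taichi_three.common import AutoInit

class Thing(AutoInit):
    def _init(self):
        print('expensive one-time setup')

t = Thing()
t.init()   # runs _init and sets the _AutoInit_had_init flag
t.init()   # no-op on the second call
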
/prepare_data/taichi_three/geometry.py:
--------------------------------------------------------------------------------
1 | import math
2 | import taichi as ti
3 | import taichi.math as ts
4 |
5 |
6 | @ti.func
7 | def render_triangle(model, camera, face, lights):
8 | scene = model.scene
9 | _1 = ti.static(min(1, model.faces.m - 1))
10 | _2 = ti.static(min(2, model.faces.m - 1))
11 | ia, ib, ic = model.vi[face[0, 0]], model.vi[face[1, 0]], model.vi[face[2, 0]]
12 | ca, cb, cc = ia, ib, ic
13 | if model.type[None] >= 1:
14 | ca, cb, cc = model.vc[face[0, 0]], model.vc[face[1, 0]], model.vc[face[2, 0]]
15 | ta, tb, tc = model.vt[face[0, _1]], model.vt[face[1, _1]], model.vt[face[2, _1]]
16 | na, nb, nc = model.vn[face[0, _2]], model.vn[face[1, _2]], model.vn[face[2, _2]]
17 | a = camera.untrans_pos(ia)
18 | b = camera.untrans_pos(ib)
19 | c = camera.untrans_pos(ic)
20 |
21 | # NOTE: the normal computation indicates that a front-facing face should
22 | # be COUNTER-CLOCKWISE, i.e., glFrontFace(GL_CCW);
23 | # this is to be compatible with obj model loading.
24 | normal = (a - b).cross(a - c).normalized()
25 | pos = (a + b + c) / 3
26 | view_pos = (a + b + c) / 3
27 | shading = (view_pos.normalized().dot(normal) - 0.5)*2.0
28 | if ti.static(camera.type == camera.ORTHO):
29 | view_pos = ti.Vector([0.0, 0.0, 1.0], ti.f32)
30 | reverse_v = 1
31 | if ti.static(model.reverse == True):
32 | reverse_v = -1
33 | if ts.dot(view_pos, normal)*reverse_v <= 0:
34 | # shading
35 | color = ti.Vector([0.0, 0.0, 0.0], ti.f32)
36 | for i in range(lights.n):
37 | light_dir = lights.light_dirs[i]
38 | # print(light_dir)
39 | light_dir = camera.untrans_dir(light_dir)
40 | light_color = lights.light_colors[i]
41 | color += scene.opt.render_func(pos, normal, ti.Vector([0.0, 0.0, 0.0], ti.f32), light_dir, light_color)
42 | color = scene.opt.pre_process(color)
43 | A = camera.uncook(a)
44 | B = camera.uncook(b)
45 | C = camera.uncook(c)
46 | scr_norm = 1 / (A - C).cross(B - A)
47 | B_A = (B - A) * scr_norm
48 | C_B = (C - B) * scr_norm
49 | A_C = (A - C) * scr_norm
50 |
51 | W = 1
52 | # screen space bounding box
53 | M, N = int(ti.floor(min(A, B, C) - W)), int(ti.ceil(max(A, B, C) + W))
54 | M.x, N.x = min(max(M.x, 0), camera.img.shape[0]), min(max(N.x, 0), camera.img.shape[0])
55 | M.y, N.y = min(max(M.y, 0), camera.img.shape[1]), min(max(N.y, 0), camera.img.shape[1])
56 | for X in ti.grouped(ti.ndrange((M.x, N.x), (M.y, N.y))):
57 | # barycentric coordinates using the area method
58 | X_A = X - A
59 | w_C = B_A.cross(X_A)
60 | w_B = A_C.cross(X_A)
61 | w_A = 1 - w_C - w_B
62 | # draw
63 | in_screen = w_A >= 0 and w_B >= 0 and w_C >= 0 and 0 < X[0] < camera.img.shape[0] and 0 < X[1] < camera.img.shape[1]
64 | if not in_screen:
65 | continue
66 | zindex = 0.0
67 | if ti.static(model.reverse == True):
68 | zindex = (a.z * w_A + b.z * w_B + c.z * w_C)
69 | else:
70 | zindex = 1 / (a.z * w_A + b.z * w_B + c.z * w_C)
71 | if zindex < ti.atomic_max(camera.zbuf[X], zindex):
72 | continue
73 |
74 | coor = (ta * w_A + tb * w_B + tc * w_C)
75 | if model.type[None] == 1:
76 | camera.img[X] = (ca * w_A + cb * w_B + cc * w_C)
77 | elif model.type[None] == 2:
78 | camera.img[X] = color * (ca * w_A + cb * w_B + cc * w_C)
79 | else:
80 | camera.img[X] = color * model.texSample(coor)
81 | camera.mask[X] = ti.Vector([1.0, 1.0, 1.0], ti.f32)
82 | camera.normal_map[X] = -normal / 2.0 + 0.5
83 | # camera.normal_map[X] = ts.vec3(shading / 2.0 + 0.5)
84 |
85 | @ti.func
86 | def render_particle(model, camera, vertex, radius):
87 | scene = model.scene
88 | L2W = model.L2W
89 | a = camera.untrans_pos(L2W @ vertex)
90 | A = camera.uncook(a)
91 |
92 | M = int(ti.floor(A - radius))
93 | N = int(ti.ceil(A + radius))
94 |
95 | for X in ti.grouped(ti.ndrange((M.x, N.x), (M.y, N.y))):
96 | if X.x < 0 or X.x >= camera.res[0] or X.y < 0 or X.y >= camera.res[1]:
97 | continue
98 | if (X - A).norm_sqr() > radius**2:
99 | continue
100 |
101 | camera.img[X] = ti.Vector([1.0, 1.0, 1.0], ti.f32)
102 |
--------------------------------------------------------------------------------
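
render_triangle fills pixels using barycentric weights built from the three precomputed edge vectors; the same weight computation, pulled out into a standalone NumPy sketch for a single pixel (illustrative only, not part of the renderer):

import numpy as np

def cross2(u, v):
    return u[0] * v[1] - u[1] * v[0]

def barycentric(A, B, C, X):
    # A, B, C: 2D screen-space vertices, X: pixel position
    A, B, C, X = map(np.asarray, (A, B, C, X))
    scr_norm = 1.0 / cross2(A - C, B - A)   # reciprocal of twice the signed triangle area
    B_A = (B - A) * scr_norm
    A_C = (A - C) * scr_norm
    X_A = X - A
    w_C = cross2(B_A, X_A)
    w_B = cross2(A_C, X_A)
    w_A = 1.0 - w_C - w_B
    return w_A, w_B, w_C                    # all >= 0 exactly when X lies inside the triangle

print(barycentric([0, 0], [4, 0], [0, 4], [1, 1]))   # (0.5, 0.25, 0.25)
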
/prepare_data/taichi_three/light.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | from .common import *
3 | import math
4 |
5 | '''
6 | The base light class represents a directional light.
7 | '''
8 | @ti.data_oriented
9 | class Light(AutoInit):
10 |
11 | def __init__(self, dir=None, color=None):
12 | dir = dir or [0, 0, 1]
13 | norm = math.sqrt(sum(x ** 2 for x in dir))
14 | dir = [x / norm for x in dir]
15 |
16 | self.dir_py = [-x for x in dir]
17 | self.color_py = color or [1, 1, 1]
18 |
19 | self.dir = ti.Vector.field(3, ti.float32, ())
20 | self.color = ti.Vector.field(3, ti.float32, ())
21 | # store the current light direction in the view space
22 | # so that we don't have to compute it for each vertex
23 | self.viewdir = ti.Vector.field(3, ti.float32, ())
24 |
25 | def set(self, dir=[0, 0, 1], color=[1, 1, 1]):
26 | norm = math.sqrt(sum(x**2 for x in dir))
27 | dir = [x / norm for x in dir]
28 | self.dir_py = [-x for x in dir]  # match the sign convention used in __init__
29 | self.color_py = color  # update the python-scope value; _init pushes it into the field
30 |
31 | def _init(self):
32 | self.dir[None] = self.dir_py
33 | self.color[None] = self.color_py
34 |
35 | @ti.func
36 | def intensity(self, pos):
37 | return 1
38 |
39 | @ti.func
40 | def get_color(self):
41 | return self.color[None]
42 |
43 | @ti.func
44 | def get_dir(self):
45 | return self.viewdir[None]
46 |
47 | @ti.func
48 | def set_view(self, camera):
49 | self.viewdir[None] = camera.untrans_dir(self.dir[None])
50 |
51 |
52 | @ti.data_oriented
53 | class Lights(AutoInit):
54 |
55 | def __init__(self, light_list):
56 | n = len(light_list)
57 | self.n = n
58 | self.light_list = light_list
59 | self.light_dirs = ti.Vector.field(3, ti.float32, n)
60 | self.light_colors = ti.Vector.field(3, ti.float32, n)
61 |
62 | def init_data(self):
63 | index = 0
64 | for light in self.light_list:
65 | self.light_dirs[index] = self.light_list[index].dir_py
66 | self.light_colors[index] = self.light_list[index].color_py
67 | index += 1
68 |
69 | class PointLight(Light):
70 | pass
71 |
72 |
--------------------------------------------------------------------------------
/prepare_data/taichi_three/loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 |
5 | def _append(faces, indices):
6 | if len(indices) == 4:
7 | faces.append([indices[0], indices[1], indices[2]])
8 | faces.append([indices[2], indices[3], indices[0]])
9 | elif len(indices) == 3:
10 | faces.append(indices)
11 | else:
12 | assert False, len(indices)
13 |
14 |
15 | def readobj(path, scale=1):
16 | vi = []
17 | vt = []
18 | vn = []
19 | faces = []
20 |
21 | with open(path, 'r') as myfile:
22 | lines = myfile.readlines()
23 |
24 | # cache vertices
25 | for line in lines:
26 | try:
27 | type, fields = line.split(maxsplit=1)
28 | fields = [float(_) for _ in fields.split()]
29 | except ValueError:
30 | continue
31 |
32 | if type == 'v':
33 | vi.append(fields)
34 | elif type == 'vt':
35 | vt.append(fields)
36 | elif type == 'vn':
37 | vn.append(fields)
38 |
39 | # cache faces
40 | for line in lines:
41 | try:
42 | type, fields = line.split(maxsplit=1)
43 | fields = fields.split()
44 | except ValueError:
45 | continue
46 |
47 | # line looks like 'f 5/1/1 1/2/1 4/3/1'
48 | # or 'f 314/380/494 382/400/494 388/550/494 506/551/494' for quads
49 | if type != 'f':
50 | continue
51 |
52 | # a field should look like '5/1/1'
53 | # for vertex/vertex UV coords/vertex Normal (indexes number in the list)
54 | # the index in 'f 5/1/1 1/2/1 4/3/1' STARTS AT 1 !!!
55 |
56 | indices = [[int(_) - 1 if _ != '' else 0 for _ in field.split('/')] for field in fields]
57 |
58 | if len(indices) == 4:
59 | faces.append([indices[0], indices[1], indices[2]])
60 | faces.append([indices[2], indices[3], indices[0]])
61 | elif len(indices) == 3:
62 | faces.append(indices)
63 | else:
64 | assert False, len(indices)
65 |
66 | ret = {}
67 | ret['vi'] = None if len(vi) == 0 else np.array(vi).astype(np.float32) * scale
68 | ret['vt'] = None if len(vt) == 0 else np.array(vt).astype(np.float32)
69 | ret['vn'] = None if len(vn) == 0 else np.array(vn).astype(np.float32)
70 | ret['f'] = None if len(faces) == 0 else np.array(faces).astype(np.int32)
71 | return ret
72 |
--------------------------------------------------------------------------------
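
A small usage sketch for readobj; the .obj path is a placeholder and the import assumes the script sits next to the taichi_three package, as render_data.py does:

from taichi_three.loader import readobj

obj = readobj('PATH/TO/mesh.obj', scale=1.0)                 # placeholder path
print(obj['vi'].shape)    # (num_vertices, 3) float32 positions
print(obj['vt'].shape if obj['vt'] is not None else None)    # (num_uvs, 2) or None
print(obj['f'].shape)     # (num_triangles, 3, k): per corner [vertex, uv, normal] indices, k in 1..3
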
/prepare_data/taichi_three/meshgen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import taichi as ti
3 | import taichi_glsl as tl
4 | import math
5 |
6 |
7 | def _pre(x):
8 | if not isinstance(x, ti.Matrix):
9 | x = ti.Vector(x)
10 | return x
11 |
12 |
13 | def _ser(foo):
14 | def wrapped(self, *args, **kwargs):
15 | foo(self, *args, **kwargs)
16 | return self
17 |
18 | return wrapped
19 |
20 |
21 | def _mparg(foo):
22 | def wrapped(self, *args):
23 | if len(args) > 1:
24 | return [foo(self, x) for x in args]
25 | else:
26 | return foo(self, args[0])
27 |
28 | return wrapped
29 |
30 |
31 | class MeshGen:
32 | def __init__(self):
33 | self.v = []
34 | self.f = []
35 |
36 | @_ser
37 | def quad(self, a, b, c, d):
38 | a, b, c, d = self.add_v(a, b, c, d)
39 | self.add_f([a, b, c], [c, d, a])
40 |
41 | @_ser
42 | def cube(self, a, b):
43 | aaa = self.add_v(tl.mix(a, b, tl.D.yyy))
44 | baa = self.add_v(tl.mix(a, b, tl.D.xyy))
45 | aba = self.add_v(tl.mix(a, b, tl.D.yxy))
46 | aab = self.add_v(tl.mix(a, b, tl.D.yyx))
47 | bba = self.add_v(tl.mix(a, b, tl.D.xxy))
48 | abb = self.add_v(tl.mix(a, b, tl.D.yxx))
49 | bab = self.add_v(tl.mix(a, b, tl.D.xyx))
50 | bbb = self.add_v(tl.mix(a, b, tl.D.xxx))
51 |
52 | self.add_f4([aaa, aba, bba, baa]) # back
53 | self.add_f4([aab, bab, bbb, abb]) # front
54 | self.add_f4([aaa, aab, abb, aba]) # left
55 | self.add_f4([baa, bba, bbb, bab]) # right
56 | self.add_f4([aaa, baa, bab, aab]) # bottom
57 | self.add_f4([aba, abb, bbb, bba]) # top
58 |
59 | @_ser
60 | def cylinder(self, bottom, top, dir1, dir2, N):
61 | bottom = _pre(bottom)
62 | top = _pre(top)
63 | dir1 = _pre(dir1)
64 | dir2 = _pre(dir2)
65 |
66 | B, T = [], []
67 | for i in range(N):
68 | disp = tl.mat(dir1.entries, dir2.entries).T() @ tl.vecAngle(tl.math.tau * i / N)
69 | B.append(self.add_v(bottom + disp))
70 | T.append(self.add_v(top + disp))
71 |
72 | BC = self.add_v(bottom)
73 | TC = self.add_v(top)
74 |
75 | for i in range(N):
76 | j = (i + 1) % N
77 | self.add_f4([B[i], B[j], T[j], T[i]])
78 |
79 | for i in range(N):
80 | j = (i + 1) % N
81 | self.add_f([B[j], B[i], BC])
82 | self.add_f([T[i], T[j], TC])
83 |
84 | @_ser
85 | def tri(self, a, b, c):
86 | a, b, c = self.add_v(a, b, c)
87 | self.add_f([a, b, c])
88 |
89 | @_mparg
90 | def add_v(self, v):
91 | if isinstance(v, ti.Matrix):
92 | v = v.entries
93 | ret = len(self.v)
94 | self.v.append(v)
95 | return ret
96 |
97 | @_mparg
98 | def add_f(self, f):
99 | ret = len(self.f)
100 | self.f.append(f)
101 | return ret
102 |
103 | @_mparg
104 | def add_f4(self, f):
105 | a, b, c, d = f
106 | return self.add_f([a, b, c], [c, d, a])
107 |
108 |
109 | def __getitem__(self, key):
110 | if key == 'v':
111 | return np.array(self.v)
112 | if key == 'f':
113 | return np.array(self.f)
114 |
--------------------------------------------------------------------------------
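
MeshGen accumulates vertex and face lists through chained calls; a minimal sketch using only quad and tri, which accept plain Python lists (taichi and taichi_glsl still need to be installed, since meshgen.py imports both at module level; the import assumes the script sits next to the taichi_three package):

from taichi_three.meshgen import MeshGen

m = MeshGen().quad([0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]).tri([0, 0, 1], [1, 0, 1], [0, 1, 1])
print(m['v'].shape)   # (7, 3): four quad corners plus three triangle corners
print(m['f'])         # [[0 1 2] [2 3 0] [4 5 6]]: the quad is split into two triangles
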
/prepare_data/taichi_three/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import taichi as ti
3 | import taichi.math as ts
4 | from .geometry import *
5 | from .transform import *
6 | from .common import *
7 | import math
8 |
9 |
10 | @ti.func
11 | def sample(field: ti.template(), P):
12 | '''
13 | Sampling a field with indices clamped into the field shape.
14 | :parameter field: (Tensor)
15 | Specify the field to sample.
16 | :parameter P: (Vector)
17 | Specify the index in field.
18 | :return:
19 | The return value is calculated as::
20 | P = clamp(P, 0, vec(*field.shape) - 1)
21 | return field[int(P)]
22 | '''
23 | shape = ti.Vector(field.shape)
24 | P = ts.clamp(P, 0, shape - 1)
25 | return field[int(P)]
26 |
27 | @ti.func
28 | def bilerp(field: ti.template(), P):
29 | '''
30 | Bilinear sampling of a 2D field with a real index.
31 | :parameter field: (2D Tensor)
32 | Specify the field to sample.
33 | :parameter P: (2D Vector of float)
34 | Specify the index in field.
35 | :note:
36 | If one of the elements to be accessed is out of `field.shape`, then
37 | `bilerp` will automatically do a clamp for you, see :func:`sample`.
38 | :return:
39 | The return value is calculated as::
40 | I = int(P)
41 | x = fract(P)
42 | y = 1 - x
43 | return (sample(field, I + D.xx) * x.x * x.y +
44 | sample(field, I + D.xy) * x.x * y.y +
45 | sample(field, I + D.yy) * y.x * y.y +
46 | sample(field, I + D.yx) * y.x * x.y)
47 | .. where D = vec(1, 0, -1)
48 | '''
49 | I = int(P)
50 | x = ts.fract(P)
51 | y = 1 - x
52 | D = ts.vec3(1, 0, -1)
53 | return (sample(field, I + D.xx) * x.x * x.y +
54 | sample(field, I + D.xy) * x.x * y.y +
55 | sample(field, I + D.yy) * y.x * y.y +
56 | sample(field, I + D.yx) * y.x * x.y)
57 |
58 | @ti.data_oriented
59 | class Model(AutoInit):
60 | TEX = 0
61 | COLOR = 1
62 |
63 | def __init__(self, f_n=None, f_m=None,
64 | vi_n=None, vt_n=None, vn_n=None, tex_n=None, col_n=None,
65 | obj=None, tex=None):
66 |
67 | self.faces = None
68 | self.vi = None
69 | self.vt = None
70 | self.vn = None
71 | self.tex = None
72 | self.type = ti.field(dtype=ti.int32, shape=())
73 | self.reverse = False
74 |
75 | if obj is not None:
76 | f_n = None if obj['f'] is None else obj['f'].shape[0]
77 | vi_n = None if obj['vi'] is None else obj['vi'].shape[0]
78 | vt_n = None if obj['vt'] is None else obj['vt'].shape[0]
79 | vn_n = None if obj['vn'] is None else obj['vn'].shape[0]
80 |
81 | if tex is not None:
82 | tex_n = tex.shape[:2]
83 |
84 | if f_m is None:
85 | f_m = 1
86 | if vt_n is not None:
87 | f_m = 2
88 | if vn_n is not None:
89 | f_m = 3
90 |
91 | if vi_n is None:
92 | vi_n = 1
93 | if vt_n is None:
94 | vt_n = 1
95 | if vn_n is None:
96 | vn_n = 1
97 | if col_n is None:
98 | col_n = 1
99 |
100 | if f_n is not None:
101 | self.faces = ti.Matrix.field(3, f_m, ti.i32, f_n)
102 | if vi_n is not None:
103 | self.vi = ti.Vector.field(3, ti.f32, vi_n)
104 | if vt_n is not None:
105 | self.vt = ti.Vector.field(2, ti.f32, vt_n)
106 | if vn_n is not None:
107 | self.vn = ti.Vector.field(3, ti.f32, vn_n)
108 | if tex_n is not None:
109 | self.tex = ti.Vector.field(3, ti.f32, tex_n)
110 | if col_n is not None:
111 | self.vc = ti.Vector.field(3, ti.f32, col_n)
112 |
113 | if obj is not None:
114 | self.init_obj = obj
115 | if tex is not None:
116 | self.init_tex = tex
117 |
118 | def from_obj(self, obj):
119 | if obj['f'] is not None:
120 | self.faces.from_numpy(obj['f'])
121 | if obj['vi'] is not None:
122 | self.vi.from_numpy(obj['vi'])
123 | if obj['vt'] is not None:
124 | self.vt.from_numpy(obj['vt'])
125 | if obj['vn'] is not None:
126 | self.vn.from_numpy(obj['vn'])
127 |
128 | def _init(self):
129 | self.type[None] = 0
130 | if hasattr(self, 'init_obj'):
131 | self.from_obj(self.init_obj)
132 | if hasattr(self, 'init_tex'):
133 | self.tex.from_numpy(self.init_tex.astype(np.float32) / 255)
134 |
135 | @ti.func
136 | def render(self, camera):
137 | for i in ti.grouped(self.faces):
138 | render_triangle(self, camera, self.faces[i])
139 |
140 | @ti.func
141 | def texSample(self, coor):
142 | if ti.static(self.tex is not None):
143 | return bilerp(self.tex, coor * ts.vec2(*self.tex.shape))
144 | else:
145 | return 1
146 |
147 |
148 | @ti.data_oriented
149 | class StaticModel(AutoInit):
150 | TEX = 0
151 | COLOR = 1
152 |
153 | def __init__(self, N, f_m=None, col_n=None,
154 | obj=None, tex=None):
155 | self.faces = None
156 | self.vi = None
157 | self.vt = None
158 | self.vn = None
159 | self.tex = None
160 | # render type: 0 = shaded texture, 1 = pure vertex color, 2 = shaded vertex color
161 | self.type = ti.field(dtype=ti.int32, shape=())
162 | self.f_n = ti.field(dtype=ti.int32, shape=())
163 | self.reverse = False
164 | self.N = N
165 |
166 | if obj is not None:
167 | f_n = None if obj['f'] is None else obj['f'].shape[0]
168 | vi_n = None if obj['vi'] is None else obj['vi'].shape[0]
169 | vt_n = None if obj['vt'] is None else obj['vt'].shape[0]
170 | vn_n = None if obj['vn'] is None else obj['vn'].shape[0]
171 |
172 | if not (tex is None):
173 | tex_n = tex.shape[:2]
174 | else:
175 | tex_n = None
176 |
177 | if f_m is None:
178 | f_m = 1
179 | if vt_n is not None:
180 | f_m = 2
181 | if vn_n is not None:
182 | f_m = 3
183 |
184 | if vi_n is None:
185 | vi_n = 1
186 | if vt_n is None:
187 | vt_n = 1
188 | if vn_n is None:
189 | vn_n = 1
190 | if col_n is None:
191 | col_n = 1
192 |
193 | if f_n is not None:
194 | self.faces = ti.Matrix.field(3, f_m, ti.i32, N)
195 | if vi_n is not None:
196 | self.vi = ti.Vector.field(3, ti.f32, N)
197 | if vt_n is not None:
198 | self.vt = ti.Vector.field(2, ti.f32, N)
199 | if vn_n is not None:
200 | self.vn = ti.Vector.field(3, ti.f32, N)
201 | if not (tex_n is None):
202 | self.tex = ti.Vector.field(3, ti.f32, tex_n)
203 | if col_n is not None:
204 | self.vc = ti.Vector.field(3, ti.f32, N)
205 |
206 | if obj is not None:
207 | self.init_obj = obj
208 | if tex is not None:
209 | self.init_tex = tex
210 |
211 | def modify_color(self, color):
212 | s_color = np.zeros((self.N, 3)).astype(np.float32)
213 | s_color[:color.shape[0]] = color
214 | self.vc.from_numpy(s_color)
215 |
216 | def from_obj(self, obj):
217 | N = self.N
218 | if obj['f'] is not None:
219 | s_faces = np.zeros((N, obj['f'].shape[1], obj['f'].shape[2])).astype(int)
220 | s_faces[:obj['f'].shape[0]] = obj['f']
221 | self.f_n[None] = obj['f'].shape[0]
222 | self.faces.from_numpy(s_faces)
223 | if obj['vi'] is not None:
224 | s_vi = np.zeros((N, 3)).astype(np.float32)
225 | s_vi[:obj['vi'].shape[0]] = obj['vi'][:, :3]
226 | self.vi.from_numpy(s_vi)
227 | if obj['vt'] is not None:
228 | s_vt = np.zeros((N, 2)).astype(np.float32)
229 | s_vt[:obj['vt'].shape[0]] = obj['vt']
230 | self.vt.from_numpy(s_vt)
231 | if obj['vn'] is not None:
232 | s_vn = np.zeros((N, 3)).astype(np.float32)
233 | s_vn[:obj['vn'].shape[0]] = obj['vn']
234 | self.vn.from_numpy(s_vn)
235 |
236 | def _init(self):
237 | self.type[None] = 0
238 | if hasattr(self, 'init_obj'):
239 | self.from_obj(self.init_obj)
240 | if hasattr(self, 'init_tex') and (self.init_tex is not None):
241 | self.tex.from_numpy(self.init_tex.astype(np.float32) / 255)
242 |
243 | @ti.func
244 | def render(self, camera, lights):
245 | for i in ti.grouped(self.faces):
246 | render_triangle(self, camera, self.faces[i], lights)
247 |
248 | @ti.func
249 | def texSample(self, coor):
250 | if ti.static(self.tex is not None):
251 | return bilerp(self.tex, coor * ts.vec2(*self.tex.shape))
252 | else:
253 | return 1
254 |
--------------------------------------------------------------------------------
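
The bilerp docstring above can be checked against a plain NumPy version of the same clamped bilinear sample (illustrative only, the renderer uses the Taichi function):

import numpy as np

def sample(field, P):
    # clamp integer indices into the field shape, as in model.sample
    P = np.clip(P, 0, np.array(field.shape) - 1).astype(int)
    return field[tuple(P)]

def bilerp(field, P):
    I = np.floor(P).astype(int)
    x = P - I          # fractional part
    y = 1.0 - x
    return (sample(field, I + [1, 1]) * x[0] * x[1] +
            sample(field, I + [1, 0]) * x[0] * y[1] +
            sample(field, I + [0, 0]) * y[0] * y[1] +
            sample(field, I + [0, 1]) * y[0] * x[1])

tex = np.arange(16, dtype=np.float32).reshape(4, 4)
print(bilerp(tex, np.array([1.5, 2.5])))   # 8.5, the average of tex[1:3, 2:4]
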
/prepare_data/taichi_three/raycast.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import taichi.math as ts
3 | from .scene import *
4 | import math
5 |
6 |
7 | EPS = 1e-3
8 | INF = 1e3
9 |
10 |
11 | @ti.data_oriented
12 | class ObjectRT(ts.TaichiClass):
13 | @ti.func
14 | def calc_sdf(self, p):
15 | ret = INF
16 | for I in ti.grouped(ti.ndrange(*self.pos.shape())):
17 | ret = min(ret, self.make_one(I).do_calc_sdf(p))
18 | return ret
19 |
20 | @ti.func
21 | def intersect(self, orig, dir):
22 | ret, normal = INF, ts.vec3(0.0)
23 | for I in ti.grouped(ti.ndrange(*self.pos.shape())):
24 | t, n = self.make_one(I).do_intersect(orig, dir)
25 | if t < ret:
26 | ret, normal = t, n
27 | return ret, normal
28 |
29 | def do_calc_sdf(self, p):
30 | raise NotImplementedError
31 |
32 | def do_intersect(self, orig, dir):
33 | raise NotImplementedError
34 |
35 |
36 | @ti.data_oriented
37 | class Ball(ObjectRT):
38 | def __init__(self, pos, radius):
39 | self.pos = pos
40 | self.radius = radius
41 |
42 | @ti.func
43 | def make_one(self, I):
44 | return Ball(self.pos[I], self.radius[I])
45 |
46 | @ti.func
47 | def do_calc_sdf(self, p):
48 | return ts.distance(self.pos, p) - self.radius
49 |
50 | @ti.func
51 | def do_intersect(self, orig, dir):
52 | op = self.pos - orig
53 | b = op.dot(dir)
54 | det = b ** 2 - op.norm_sqr() + self.radius ** 2
55 | ret = INF
56 | if det > 0.0:
57 | det = ti.sqrt(det)
58 | t = b - det
59 | if t > EPS:
60 | ret = t
61 | else:
62 | t = b + det
63 | if t > EPS:
64 | ret = t
65 | return ret, ts.normalize(dir * ret - op)
66 |
67 |
68 | @ti.data_oriented
69 | class SceneRTBase(Scene):
70 | def __init__(self):
71 | super(SceneRTBase, self).__init__()
72 | self.balls = []
73 |
74 | def trace(self, pos, dir):
75 | raise NotImplementedError
76 |
77 | @ti.func
78 | def color_at(self, coor, camera):
79 | orig, dir = camera.generate(coor)
80 |
81 | pos, normal = self.trace(orig, dir)
82 | light_dir = self.light_dir[None]
83 |
84 | color = self.opt.render_func(pos, normal, dir, light_dir)
85 | color = self.opt.pre_process(color)
86 | return color
87 |
88 | @ti.kernel
89 | def _render(self):
90 | if ti.static(len(self.cameras)):
91 | for camera in ti.static(self.cameras):
92 | for I in ti.grouped(camera.img):
93 | coor = self.cook_coor(I, camera)
94 | color = self.color_at(coor, camera)
95 | camera.img[I] = color
96 |
97 | def add_ball(self, pos, radius):
98 | b = Ball(pos, radius)
99 | self.balls.append(b)
100 |
101 |
102 | @ti.data_oriented
103 | class SceneRT(SceneRTBase):
104 | @ti.func
105 | def intersect(self, orig, dir):
106 | ret, normal = INF, ts.vec3(0.0)
107 | for b in ti.static(self.balls):
108 | t, n = b.intersect(orig, dir)
109 | if t < ret:
110 | ret, normal = t, n
111 | return ret, normal
112 |
113 | @ti.func
114 | def trace(self, orig, dir):
115 | depth, normal = self.intersect(orig, dir)
116 | pos = orig + dir * depth
117 | return pos, normal
118 |
119 |
120 | @ti.data_oriented
121 | class SceneSDF(SceneRTBase):
122 | @ti.func
123 | def calc_sdf(self, p):
124 | ret = INF
125 | for b in ti.static(self.balls):
126 | ret = min(ret, b.calc_sdf(p))
127 | return ret
128 |
129 | @ti.func
130 | def calc_grad(self, p):
131 | return ts.vec(
132 | self.calc_sdf(p + ts.vec(EPS, 0, 0)),
133 | self.calc_sdf(p + ts.vec(0, EPS, 0)),
134 | self.calc_sdf(p + ts.vec(0, 0, EPS)))
135 |
136 | @ti.func
137 | def trace(self, orig, dir):
138 | pos = orig
139 | color = ts.vec3(0.0)
140 | normal = ts.vec3(0.0)
141 | for s in range(100):
142 | t = self.calc_sdf(pos)
143 | if t <= 0:
144 | normal = ts.normalize(self.calc_grad(pos) - t)
145 | break
146 | pos += dir * t
147 | return pos, normal
148 |
--------------------------------------------------------------------------------
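
Ball.do_intersect solves the usual ray-sphere quadratic and keeps the nearest hit in front of the origin; the same math as a standalone NumPy sketch:

import numpy as np

def ray_sphere(orig, dir, center, radius, eps=1e-3):
    # t = b +/- sqrt(b^2 - |op|^2 + r^2); keep the smallest t > eps, as in Ball.do_intersect
    op = center - orig
    b = np.dot(op, dir)
    det = b * b - np.dot(op, op) + radius * radius
    if det <= 0.0:
        return None
    det = np.sqrt(det)
    for t in (b - det, b + det):
        if t > eps:
            return t
    return None

print(ray_sphere(np.array([0.0, 0.0, -5.0]), np.array([0.0, 0.0, 1.0]),
                 np.array([0.0, 0.0, 0.0]), 1.0))   # 4.0: the ray hits the front of the unit sphere
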
/prepare_data/taichi_three/scatter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import taichi as ti
3 | import taichi.math as ts
4 | from .geometry import *
5 | from .transform import *
6 | from .common import *
7 | import math
8 |
9 |
10 | @ti.data_oriented
11 | class ScatterModel(AutoInit):
12 | def __init__(self, num=None, radius=2):
13 | self.L2W = Affine.field(())
14 |
15 | self.num = num
16 | self.radius = radius
17 |
18 | if num is not None:
19 | self.particles = ti.Vector.field(3, ti.i32, num)
20 |
21 | def _init(self):
22 | self.L2W.init()
23 |
24 | @ti.func
25 | def render(self, camera):
26 | for i in ti.grouped(self.particles):
27 | render_particle(self, camera, self.particles[i], self.radius)
28 |
--------------------------------------------------------------------------------
/prepare_data/taichi_three/scene.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | from .transform import *
3 | from .shading import *
4 | from .light import *
5 |
6 |
7 | @ti.data_oriented
8 | class Scene(AutoInit):
9 | def __init__(self):
10 | self.lights = []
11 | self.cameras = []
12 | self.opt = Shading()
13 | self.models = []
14 |
15 | @ti.func
16 | def cook_coor(self, I, camera):
17 | scale = ti.static(2 / min(*camera.img.shape()))
18 | coor = (I - ts.vec2(*camera.img.shape()) / 2) * scale
19 | return coor
20 |
21 | @ti.func
22 | def uncook_coor(self, coor, camera):
23 | scale = ti.static(min(*camera.img.shape()) / 2)
24 | I = coor.xy * scale + ts.vec2(*camera.img.shape()) / 2
25 | return I
26 |
27 | def add_model(self, model):
28 | model.scene = self
29 | self.models.append(model)
30 |
31 | def add_camera(self, camera):
32 | camera.scene = self
33 | self.cameras.append(camera)
34 |
35 | def add_lights(self, lights):
36 | lights.scene = self
37 | self.lights = lights
38 |
39 | def _init(self):
40 | for camera in self.cameras:
41 | camera.init()
42 | for model in self.models:
43 | model.init()
44 | self.lights.init_data()
45 |
46 |
47 | @ti.kernel
48 | def _single_render(self, num : ti.template()):
49 | if ti.static(len(self.cameras)):
50 | for camera in ti.static(self.cameras):
51 | camera.clear_buffer()
52 | self.models[num].render(camera, self.lights)
53 |
54 | def single_render(self, num):
55 | self.lights.init_data()
56 | for camera in self.cameras:
57 | camera.init()
58 | self.models[num].init()
59 | self._single_render(num)
60 |
61 | def render(self):
62 | self.init()
63 | self._render()
64 |
65 | @ti.kernel
66 | def _render(self):
67 | if ti.static(len(self.cameras)):
68 | for camera in ti.static(self.cameras):
69 | camera.clear_buffer()
70 | # sets up light directions
71 | if ti.static(len(self.models)):
72 | for model in ti.static(self.models):
73 | model.render(camera, self.lights)
74 |
--------------------------------------------------------------------------------
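
Putting the pieces together, a minimal usage sketch of this vendored taichi_three fork. The .obj path and the padded field size are placeholders, the import assumes the script sits next to the taichi_three package, and the reference usage remains the StaticRenderer wrapper that render_data.py instantiates above (its definition is earlier in that file):

import taichi as ti
import taichi_three as t3

ti.init(arch=ti.cpu)

obj = t3.readobj('PATH/TO/mesh.obj', scale=1.0)               # placeholder path
counts = [v.shape[0] for v in (obj['f'], obj['vi'], obj['vt'], obj['vn']) if v is not None]

scene = t3.Scene()
scene.add_model(t3.StaticModel(max(counts), obj=obj))         # StaticModel pads every field to one size
scene.add_camera(t3.Camera(res=(512, 512)))
scene.add_lights(t3.Lights([t3.Light(dir=[0.0, -1.0, 1.0])]))

scene.render()                                                # rasterizes every model into every camera
img = scene.cameras[0].img.to_numpy().swapaxes(0, 1)          # H x W x 3 float image, as in render_data.py
depth = scene.cameras[0].zbuf.to_numpy().swapaxes(0, 1)
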
/prepare_data/taichi_three/shading.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import taichi.math as tm
3 | from .transform import *
4 | import math
5 |
6 | class Shading:
7 | def __init__(self, **kwargs):
8 | self.is_normal_map = False
9 | self.lambert = 0.58
10 | self.half_lambert = 0.04
11 | self.blinn_phong = 0.3
12 | self.phong = 0.0
13 | self.shineness = 10
14 | self.__dict__.update(kwargs)
15 |
16 | @ti.func
17 | def render_func(self, pos, normal, dir, light_dir, light_color):
18 | color = ti.Vector([0.0, 0.0, 0.0], ti.f32)
19 | shineness = self.shineness
20 | half_lambert = normal.dot(light_dir) * 0.5 + 0.5
21 | lambert = max(0, normal.dot(light_dir))
22 | blinn_phong = normal.dot(tm.mix(light_dir, -dir, 0.5))
23 | blinn_phong = pow(max(blinn_phong, 0), shineness)
24 | refl_dir = tm.reflect(light_dir, normal)
25 | phong = -tm.dot(normal, refl_dir)
26 | phong = pow(max(phong, 0), shineness)
27 |
28 | strength = 0.0
29 | if ti.static(self.lambert != 0.0):
30 | strength += lambert * self.lambert
31 | if ti.static(self.half_lambert != 0.0):
32 | strength += half_lambert * self.half_lambert
33 | if ti.static(self.blinn_phong != 0.0):
34 | strength += blinn_phong * self.blinn_phong
35 | if ti.static(self.phong != 0.0):
36 | strength += phong * self.phong
37 | color = tm.vec3(strength)
38 |
39 | if ti.static(self.is_normal_map):
40 | color = normal * 0.5 + 0.5
41 | return color * light_color
42 |
43 | @ti.func
44 | def pre_process(self, color):
45 | blue = tm.vec3(0.00, 0.01, 0.05)
46 | orange = tm.vec3(1.19, 1.04, 0.98)
47 | return ti.sqrt(tm.mix(blue, orange, color))
48 |
--------------------------------------------------------------------------------
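
render_func blends fixed-weight Lambert, half-Lambert and Blinn-Phong terms per light; the same strength computation as a NumPy sketch (normal-map and Phong branches omitted, default weights kept):

import numpy as np

def shade(normal, light_dir, view_dir,
          lambert_w=0.58, half_lambert_w=0.04, blinn_phong_w=0.3, shineness=10):
    # normal and light_dir are unit vectors; view_dir plays the role of -dir in render_func
    lambert = max(0.0, np.dot(normal, light_dir))
    half_lambert = np.dot(normal, light_dir) * 0.5 + 0.5
    half_vec = 0.5 * (light_dir + view_dir)          # tm.mix(light_dir, -dir, 0.5), left unnormalized
    blinn_phong = max(0.0, np.dot(normal, half_vec)) ** shineness
    return lambert_w * lambert + half_lambert_w * half_lambert + blinn_phong_w * blinn_phong

n = l = v = np.array([0.0, 0.0, 1.0])
print(shade(n, l, v))   # 0.92 = 0.58 + 0.04 + 0.3 for a head-on light and view
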
/prepare_data/taichi_three/transform.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import taichi.math as ts
3 | from .common import *
4 | import math
5 |
6 |
7 | def rotationX(angle):
8 | return [
9 | [1, 0, 0],
10 | [0, math.cos(angle), -math.sin(angle)],
11 | [0, math.sin(angle), math.cos(angle)],
12 | ]
13 |
14 | def rotationY(angle):
15 | return [
16 | [ math.cos(angle), 0, math.sin(angle)],
17 | [ 0, 1, 0],
18 | [-math.sin(angle), 0, math.cos(angle)],
19 | ]
20 |
21 | def rotationZ(angle):
22 | return [
23 | [math.cos(angle), -math.sin(angle), 0],
24 | [math.sin(angle), math.cos(angle), 0],
25 | [ 0, 0, 1],
26 | ]
27 |
28 |
29 | @ti.data_oriented
30 | class Affine(AutoInit):
31 | @property
32 | def matrix(self):
33 | return self.entries[0]
34 |
35 | @property
36 | def offset(self):
37 | return self.entries[1]
38 |
39 | @classmethod
40 | def _field(cls, shape=None):
41 | return ti.Matrix.field(3, 3, ti.f32, shape), ti.Vector.field(3, ti.f32, shape)
42 |
43 | @ti.func
44 | def loadIdentity(self):
45 | self.matrix = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
46 | self.offset = [0, 0, 0]
47 |
48 | @ti.kernel
49 | def _init(self):
50 | self.loadIdentity()
51 |
52 | @ti.func
53 | def __matmul__(self, other):
54 | return self.matrix @ other + self.offset
55 |
56 | @ti.func
57 | def inverse(self):
58 | # TODO: incorrect:
59 | return Affine(self.matrix.inverse(), -self.offset)
60 |
61 | def loadOrtho(self, fwd=[0, 0, 1], up=[0, 1, 0]):
62 | # fwd = target - pos
63 | # fwd = fwd.normalized()
64 | fwd_len = math.sqrt(sum(x**2 for x in fwd))
65 | fwd = [x / fwd_len for x in fwd]
66 | # right = fwd.cross(up)
67 | right = [
68 | fwd[2] * up[1] - fwd[1] * up[2],
69 | fwd[0] * up[2] - fwd[2] * up[0],
70 | fwd[1] * up[0] - fwd[0] * up[1],
71 | ]
72 | # right = right.normalized()
73 | right_len = math.sqrt(sum(x**2 for x in right))
74 | right = [x / right_len for x in right]
75 | # up = right.cross(fwd)
76 | up = [
77 | right[2] * fwd[1] - right[1] * fwd[2],
78 | right[0] * fwd[2] - right[2] * fwd[0],
79 | right[1] * fwd[0] - right[0] * fwd[1],
80 | ]
81 |
82 | # trans = ti.Matrix.cols([right, up, fwd])
83 | trans = [right, up, fwd]
84 | trans = [[trans[i][j] for i in range(3)] for j in range(3)]
85 | self.matrix[None] = trans
86 |
87 | def from_mouse(self, mpos):
88 | if isinstance(mpos, ti.GUI):
89 | if mpos.is_pressed(ti.GUI.LMB):
90 | mpos = mpos.get_cursor_pos()
91 | else:
92 | mpos = (0, 0)
93 | a, t = mpos
94 | if a != 0 or t != 0:
95 | a, t = a * math.tau - math.pi, t * math.pi - math.pi / 2
96 | c = math.cos(t)
97 | self.loadOrtho(fwd=[c * math.sin(a), math.sin(t), c * math.cos(a)])
98 |
99 |
100 | @ti.data_oriented
101 | class Camera(AutoInit):
102 | ORTHO = 'Orthogonal'
103 | TAN_FOV = 'Tangent Perspective' # rectilinear perspective
104 | COS_FOV = 'Cosine Perspective' # curvilinear perspective, see en.wikipedia.org/wiki/Curvilinear_perspective
105 |
106 | def __init__(self, res=None, fx=None, fy=None, cx=None, cy=None,
107 | pos=[0, 0, -2], target=[0, 0, 0], up=[0, -1, 0], fov=30):
108 | self.res = res or (512, 512)
109 | self.img = ti.Vector.field(3, ti.f32, self.res)
110 | self.zbuf = ti.field(ti.f32, self.res)
111 | self.mask = ti.Vector.field(3, ti.f32, self.res)
112 | self.normal_map = ti.Vector.field(3, ti.f32, self.res)
113 | self.trans = ti.Matrix.field(3, 3, ti.f32, ())
114 | self.pos = ti.Vector.field(3, ti.f32, ())
115 | self.target = ti.Vector.field(3, ti.f32, ())
116 | self.intrinsic = ti.Matrix.field(3, 3, ti.f32, ())
117 | self.type = self.TAN_FOV
118 | self.fov = math.radians(fov)
119 |
120 | self.cx = cx or self.res[0] // 2
121 | self.cy = cy or self.res[1] // 2
122 | self.fx = fx or self.cx / math.tan(self.fov)
123 | self.fy = fy or self.cy / math.tan(self.fov)
124 | # python scope camera transformations
125 | self.pos_py = pos
126 | self.target_py = target
127 | self.trans_py = None
128 | self.up_py = up
129 | self.set(init=True)
130 | # mouse position for camera control
131 | self.mpos = (0, 0)
132 |
133 | def set_intrinsic(self, fx=None, fy=None, cx=None, cy=None):
134 | # see http://ais.informatik.uni-freiburg.de/teaching/ws09/robotics2/pdfs/rob2-08-camera-calibration.pdf
135 | self.fx = fx or self.fx
136 | self.fy = fy or self.fy
137 | self.cx = self.cx if cx is None else cx
138 | self.cy = self.cy if cy is None else cy
139 |
140 | '''
141 | NOTE: taichi_three uses a LEFT HANDED coordinate system.
142 | that is, the +Z axis points FROM the camera TOWARDS the scene,
143 | with X, Y being device coordinates
144 | '''
145 | def set(self, pos=None, target=None, up=None, init=False):
146 | pos = self.pos_py if pos is None else pos
147 | target = self.target_py if target is None else target
148 | up = self.up_py if up is None else up
149 | # fwd = target - pos
150 | fwd = [target[i] - pos[i] for i in range(3)]
151 | # fwd = fwd.normalized()
152 | fwd_len = math.sqrt(sum(x**2 for x in fwd))
153 | fwd = [x / fwd_len for x in fwd]
154 | # right = fwd.cross(up)
155 | # print(fwd, up)
156 | # left = [
157 | # fwd[2] * up[1] - fwd[1] * up[2],
158 | # fwd[0] * up[2] - fwd[2] * up[0],
159 | # fwd[1] * up[0] - fwd[0] * up[1],
160 | # ]
161 | right = [
162 | fwd[2] * up[1] - fwd[1] * up[2],
163 | fwd[0] * up[2] - fwd[2] * up[0],
164 | fwd[1] * up[0] - fwd[0] * up[1],
165 | ]
166 | right_len = math.sqrt(sum(x**2 for x in right))
167 | right = [x / right_len for x in right]
168 | # up = right.cross(fwd)
169 | up = [
170 | right[2] * fwd[1] - right[1] * fwd[2],
171 | right[0] * fwd[2] - right[2] * fwd[0],
172 | right[1] * fwd[0] - right[0] * fwd[1],
173 | ]
174 | # up = [
175 | # -left[2] * fwd[1] + left[1] * fwd[2],
176 | # -left[0] * fwd[2] + left[2] * fwd[0],
177 | # -left[1] * fwd[0] + left[0] * fwd[1],
178 | # ]
179 | # trans = ti.Matrix.cols([right, up, fwd])
180 | trans = [right, up, fwd]
181 | # print(right, up, fwd)
182 | self.trans_py = [[trans[i][j] for i in range(3)] for j in range(3)]
183 | # self.trans_py = [[trans[i][j] for j in range(3)] for i in range(3)]
184 | self.pos_py = pos
185 | self.target_py = target
186 | if not init:
187 | self.pos[None] = self.pos_py
188 | self.trans[None] = self.trans_py
189 | self.target[None] = self.target_py
190 |
191 | def set_extrinsic(self, trans, pose):
192 | trans = [[trans[i][j] for j in range(3)] for i in range(3)]
193 | self.trans_py = trans
194 | self.pos_py = pose
195 |
196 | def _init(self):
197 | self.pos[None] = self.pos_py
198 | self.trans[None] = self.trans_py
199 | self.target[None] = self.target_py
200 | self.intrinsic[None][0, 0] = self.fx
201 | self.intrinsic[None][0, 2] = self.cx
202 | self.intrinsic[None][1, 1] = self.fy
203 | self.intrinsic[None][1, 2] = self.cy
204 | self.intrinsic[None][2, 2] = 1.0
205 |
206 | @ti.func
207 | def clear_buffer(self):
208 | for I in ti.grouped(self.img):
209 | self.img[I] = ti.Vector([0.0, 0.0, 0.0], ti.f32)
210 | self.zbuf[I] = 0.0
211 | self.mask[I] = ti.Vector([0.0, 0.0, 0.0], ti.f32)
212 | self.normal_map[I] = ti.Vector([0.5, 0.5, 0.5], ti.f32)
213 |
214 | def from_mouse(self, gui):
215 | is_alter_move = gui.is_pressed(ti.GUI.CTRL)
216 | if gui.is_pressed(ti.GUI.LMB):
217 | mpos = gui.get_cursor_pos()
218 | if self.mpos != (0, 0):
219 | self.orbit((mpos[0] - self.mpos[0], mpos[1] - self.mpos[1]),
220 | pov=is_alter_move)
221 | self.mpos = mpos
222 | elif gui.is_pressed(ti.GUI.RMB):
223 | mpos = gui.get_cursor_pos()
224 | if self.mpos != (0, 0):
225 | self.zoom_by_mouse(mpos, (mpos[0] - self.mpos[0], mpos[1] - self.mpos[1]),
226 | dolly=is_alter_move)
227 | self.mpos = mpos
228 | elif gui.is_pressed(ti.GUI.MMB):
229 | mpos = gui.get_cursor_pos()
230 | if self.mpos != (0, 0):
231 | self.pan((mpos[0] - self.mpos[0], mpos[1] - self.mpos[1]))
232 | self.mpos = mpos
233 | else:
234 | if gui.event and gui.event.key == ti.GUI.WHEEL:
235 | # one mouse wheel unit is (0, 120)
236 | self.zoom(-gui.event.delta[1] / 1200,
237 | dolly=is_alter_move)
238 | gui.event = None
239 | mpos = (0, 0)
240 | self.mpos = mpos
241 |
242 |
243 | def orbit(self, delta, sensitivity=5, pov=False):
244 | ds, dt = delta
245 | if ds != 0 or dt != 0:
246 | dis = math.sqrt(sum((self.target_py[i] - self.pos_py[i]) ** 2 for i in range(3)))
247 | fov = self.fov
248 | ds, dt = ds * fov * sensitivity, dt * fov * sensitivity
249 | newdir = ti.Vector([ds, dt, 1], ti.f32).normalized()
250 | newdir = [sum(self.trans[None][i, j] * newdir[j] for j in range(3))\
251 | for i in range(3)]
252 | if pov:
253 | newtarget = [self.pos_py[i] + dis * newdir[i] for i in range(3)]
254 | self.set(target=newtarget)
255 | else:
256 | newpos = [self.target_py[i] - dis * newdir[i] for i in range(3)]
257 | self.set(pos=newpos)
258 |
259 | def zoom_by_mouse(self, pos, delta, sensitivity=3, dolly=False):
260 | ds, dt = delta
261 | if ds != 0 or dt != 0:
262 | z = math.sqrt(ds ** 2 + dt ** 2) * sensitivity
263 | if (pos[0] - 0.5) * ds + (pos[1] - 0.5) * dt > 0:
264 | z *= -1
265 | self.zoom(z, dolly)
266 |
267 | def zoom(self, z, dolly=False):
268 | newpos = [(1 + z) * self.pos_py[i] - z * self.target_py[i] for i in range(3)]
269 | if dolly:
270 | newtarget = [z * self.pos_py[i] + (1 - z) * self.target_py[i] for i in range(3)]
271 | self.set(pos=newpos, target=newtarget)
272 | else:
273 | self.set(pos=newpos)
274 |
275 | def pan(self, delta, sensitivity=3):
276 | ds, dt = delta
277 | if ds != 0 or dt != 0:
278 | dis = math.sqrt(sum((self.target_py[i] - self.pos_py[i]) ** 2 for i in range(3)))
279 | fov = self.fov
280 | ds, dt = ds * fov * sensitivity, dt * fov * sensitivity
281 | newdir = ti.Vector([-ds, -dt, 1], ti.f32).normalized()
282 | newdir = [sum(self.trans[None][i, j] * newdir[j] for j in range(3))\
283 | for i in range(3)]
284 | newtarget = [self.pos_py[i] + dis * newdir[i] for i in range(3)]
285 | newpos = [self.pos_py[i] + newtarget[i] - self.target_py[i] for i in range(3)]
286 | self.set(pos=newpos, target=newtarget)
287 |
288 | @ti.func
289 | def trans_pos(self, pos):
290 | return self.trans[None] @ pos + self.pos[None]
291 |
292 | @ti.func
293 | def trans_dir(self, pos):
294 | return self.trans[None] @ pos
295 |
296 | @ti.func
297 | def untrans_pos(self, pos):
298 | return self.trans[None].inverse() @ (pos - self.pos[None])
299 |
300 | @ti.func
301 | def untrans_dir(self, pos):
302 | return self.trans[None].inverse() @ pos
303 |
304 | @ti.func
305 | def uncook(self, pos):
306 | if ti.static(self.type == self.ORTHO):
307 | pos[0] *= self.intrinsic[None][0, 0]
308 | pos[1] *= self.intrinsic[None][1, 1]
309 | pos[0] += self.intrinsic[None][0, 2]
310 | pos[1] += self.intrinsic[None][1, 2]
311 | elif ti.static(self.type == self.TAN_FOV):
312 | pos = self.intrinsic[None] @ pos
313 | pos[0] /= abs(pos[2])
314 | pos[1] /= abs(pos[2])
315 | else:
316 | raise NotImplementedError("Curvilinear projection matrix not implemented!")
317 | return ti.Vector([pos[0], pos[1]], ti.f32)
318 |
319 | def export_intrinsic(self):
320 | import numpy as np
321 | intrinsic = np.zeros((3, 3))
322 | intrinsic[0, 0] = self.fx
323 | intrinsic[1, 1] = self.fy
324 | intrinsic[0, 2] = self.cx
325 | intrinsic[1, 2] = self.cy
326 | intrinsic[2, 2] = 1
327 | return intrinsic
328 |
329 | def export_extrinsic(self):
330 | import numpy as np
331 | trans = np.array(self.trans_py)
332 | pos = np.array(self.pos_py)
333 | extrinsic = np.zeros((3, 4))
334 |
335 | trans = np.transpose(trans)
336 | for i in range(3):
337 | for j in range(3):
338 | extrinsic[i][j] = trans[i, j]
339 | pos = -trans @ pos
340 | for i in range(3):
341 | extrinsic[i][3] = pos[i]
342 | return extrinsic
343 |
344 | @ti.func
345 | def generate(self, coor):
346 | fov = ti.static(self.fov)
347 | tan_fov = ti.static(math.tan(fov))
348 |
349 | orig = ti.Vector([0.0, 0.0, 0.0], ti.f32)
350 | dir = ti.Vector([0.0, 0.0, 1.0], ti.f32)
351 |
352 | if ti.static(self.type == self.ORTHO):
353 | orig = ti.Vector([coor[0], coor[1], 0.0], ti.f32)
354 | elif ti.static(self.type == self.TAN_FOV):
355 | uv = coor * fov
356 | dir = ti.Vector([uv[0], uv[1], 1], ti.f32).normalized()
357 | elif ti.static(self.type == self.COS_FOV):
358 | uv = coor * fov
359 | dir = ti.Vector([ti.sin(uv), ti.cos(uv.norm())], ti.f32).normalized()
360 |
361 | orig = self.trans_pos(orig)
362 | dir = self.trans_dir(dir)
363 |
364 | return orig, dir
365 |
--------------------------------------------------------------------------------
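
export_extrinsic and export_intrinsic produce a standard world-to-camera [R | t] (R is the transpose of the python-scope camera-to-world rotation, t = -R @ pos) and a pinhole K, so a world point projects the usual way. A small check sketch; the pose values are arbitrary and the import assumes the script sits next to the taichi_three package:

import numpy as np
import taichi as ti
import taichi_three as t3

ti.init(arch=ti.cpu)
cam = t3.Camera(res=(1024, 1024), pos=[0, 0, -2], target=[0, 0, 0])
cam._init()                                   # push the python-scope pose and intrinsics into the fields

E = cam.export_extrinsic()                    # 3x4 world-to-camera [R | t]
K = cam.export_intrinsic()                    # 3x3 intrinsics from fx, fy, cx, cy

X_world = np.array([0.1, 0.2, 1.0])
X_cam = E[:, :3] @ X_world + E[:, 3]          # camera-space point
uv = K @ X_cam
print(uv[0] / uv[2], uv[1] / uv[2])           # pixel coordinates, matching Camera.uncook for TAN_FOV
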
/prepare_data/taichi_three/version.py:
--------------------------------------------------------------------------------
1 | version = (0, 0, 4)
2 | taiglsl_version = (0, 0, 9)
3 | taichi_version = (0, 6, 28)
4 |
5 | print(f'[Tai3D] version {".".join(map(str, version))}')
6 | print(f'[Tai3D] Inputs are welcomed at https://github.com/taichi-dev/taichi_three')
7 | print(f'[Tai3D] Camera control hints: LMB to orbit, MMB to move, RMB to scale')
8 |
--------------------------------------------------------------------------------
/test_real_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import argparse
4 | import logging
5 | import numpy as np
6 | import cv2
7 | import os
8 | from pathlib import Path
9 | from tqdm import tqdm
10 |
11 | from lib.human_loader import StereoHumanDataset
12 | from lib.network import RtStereoHumanModel
13 | from config.stereo_human_config import ConfigStereoHuman as config
14 | from lib.utils import get_novel_calib
15 | from lib.GaussianRender import pts2render
16 |
17 | import torch
18 | import warnings
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 |
21 |
22 | class StereoHumanRender:
23 | def __init__(self, cfg_file, phase):
24 | self.cfg = cfg_file
25 | self.bs = self.cfg.batch_size
26 |
27 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True)
28 | self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase)
29 | self.model.cuda()
30 | if self.cfg.restore_ckpt:
31 | self.load_ckpt(self.cfg.restore_ckpt)
32 | self.model.eval()
33 |
34 | def infer_sequence(self, view_select, ratio=0.5):
35 | total_frames = len(os.listdir(os.path.join(self.cfg.dataset.test_data_root, 'img')))
36 | for idx in tqdm(range(total_frames)):
37 | item = self.dataset.get_test_item(idx, source_id=view_select)
38 | data = self.fetch_data(item)
39 | data = get_novel_calib(data, self.cfg.dataset, ratio=ratio, intr_key='intr_ori', extr_key='extr_ori')
40 | with torch.no_grad():
41 | data, _, _ = self.model(data, is_train=False)
42 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color)
43 |
44 | render_novel = self.tensor2np(data['novel_view']['img_pred'])
45 | cv2.imwrite(self.cfg.test_out_path + '/%s_novel.jpg' % (data['name']), render_novel)
46 |
47 | def tensor2np(self, img_tensor):
48 | img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy()
49 | img_np = img_np * 255
50 | img_np = img_np[:, :, ::-1].astype(np.uint8)
51 | return img_np
52 |
53 | def fetch_data(self, data):
54 | for view in ['lmain', 'rmain']:
55 | for item in data[view].keys():
56 | data[view][item] = data[view][item].cuda().unsqueeze(0)
57 | return data
58 |
59 | def load_ckpt(self, load_path):
60 | assert os.path.exists(load_path)
61 | logging.info(f"Loading checkpoint from {load_path} ...")
62 | ckpt = torch.load(load_path, map_location='cuda')
63 | self.model.load_state_dict(ckpt['network'], strict=True)
64 | logging.info(f"Parameter loading done")
65 |
66 |
67 | if __name__ == '__main__':
68 | logging.basicConfig(level=logging.INFO,
69 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--test_data_root', type=str, required=True)
72 | parser.add_argument('--ckpt_path', type=str, required=True)
73 | parser.add_argument('--src_view', type=int, nargs='+', required=True)
74 | parser.add_argument('--ratio', type=float, default=0.5)
75 | arg = parser.parse_args()
76 |
77 | cfg = config()
78 | cfg_for_train = os.path.join('./config', 'stage2.yaml')
79 | cfg.load(cfg_for_train)
80 | cfg = cfg.get_cfg()
81 |
82 | cfg.defrost()
83 | cfg.batch_size = 1
84 | cfg.dataset.test_data_root = arg.test_data_root
85 | cfg.dataset.use_processed_data = False
86 | cfg.restore_ckpt = arg.ckpt_path
87 | cfg.test_out_path = './test_out'
88 | Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True)
89 | cfg.freeze()
90 |
91 | render = StereoHumanRender(cfg, phase='test')
92 | render.infer_sequence(view_select=arg.src_view, ratio=arg.ratio)
93 |
--------------------------------------------------------------------------------
/test_view_interp.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import argparse
4 | import logging
5 | import numpy as np
6 | import cv2
7 | import os
8 | from pathlib import Path
9 | from tqdm import tqdm
10 |
11 | from lib.human_loader import StereoHumanDataset
12 | from lib.network import RtStereoHumanModel
13 | from config.stereo_human_config import ConfigStereoHuman as config
14 | from lib.utils import get_novel_calib
15 | from lib.GaussianRender import pts2render
16 |
17 | import torch
18 | import warnings
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 |
21 |
22 | class StereoHumanRender:
23 | def __init__(self, cfg_file, phase):
24 | self.cfg = cfg_file
25 | self.bs = self.cfg.batch_size
26 |
27 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True)
28 | self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase)
29 | self.model.cuda()
30 | if self.cfg.restore_ckpt:
31 | self.load_ckpt(self.cfg.restore_ckpt)
32 | self.model.eval()
33 |
34 | def infer_static(self, view_select, novel_view_nums):
35 | total_samples = len(os.listdir(os.path.join(self.cfg.dataset.test_data_root, 'img')))
36 | for idx in tqdm(range(total_samples)):
37 | item = self.dataset.get_test_item(idx, source_id=view_select)
38 | data = self.fetch_data(item)
39 | for i in range(novel_view_nums):
40 | ratio_tmp = (i+0.5)*(1/novel_view_nums)
41 | data_i = get_novel_calib(data, self.cfg.dataset, ratio=ratio_tmp, intr_key='intr_ori', extr_key='extr_ori')
42 | with torch.no_grad():
43 | data_i, _, _ = self.model(data_i, is_train=False)
44 | data_i = pts2render(data_i, bg_color=self.cfg.dataset.bg_color)
45 |
46 | render_novel = self.tensor2np(data_i['novel_view']['img_pred'])
47 | cv2.imwrite(self.cfg.test_out_path + '/%s_novel%s.jpg' % (data_i['name'], str(i).zfill(2)), render_novel)
48 |
49 | def tensor2np(self, img_tensor):
50 | img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy()
51 | img_np = img_np * 255
52 | img_np = img_np[:, :, ::-1].astype(np.uint8)
53 | return img_np
54 |
55 | def fetch_data(self, data):
56 | for view in ['lmain', 'rmain']:
57 | for item in data[view].keys():
58 | data[view][item] = data[view][item].cuda().unsqueeze(0)
59 | return data
60 |
61 | def load_ckpt(self, load_path):
62 | assert os.path.exists(load_path)
63 | logging.info(f"Loading checkpoint from {load_path} ...")
64 | ckpt = torch.load(load_path, map_location='cuda')
65 | self.model.load_state_dict(ckpt['network'], strict=True)
66 | logging.info(f"Parameter loading done")
67 |
68 |
69 | if __name__ == '__main__':
70 | logging.basicConfig(level=logging.INFO,
71 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--test_data_root', type=str, required=True)
74 | parser.add_argument('--ckpt_path', type=str, required=True)
75 | parser.add_argument('--novel_view_nums', type=int, default=5)
76 | arg = parser.parse_args()
77 |
78 | cfg = config()
79 | cfg_for_train = os.path.join('./config', 'stage2.yaml')
80 | cfg.load(cfg_for_train)
81 | cfg = cfg.get_cfg()
82 |
83 | cfg.defrost()
84 | cfg.batch_size = 1
85 | cfg.dataset.test_data_root = arg.test_data_root
86 | cfg.dataset.use_processed_data = False
87 | cfg.restore_ckpt = arg.ckpt_path
88 | cfg.test_out_path = './interp_out'
89 | Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True)
90 | cfg.freeze()
91 |
92 | render = StereoHumanRender(cfg, phase='test')
93 | render.infer_static(view_select=[0, 1], novel_view_nums=arg.novel_view_nums)
94 |
--------------------------------------------------------------------------------
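
For reference, the interpolation ratios used in infer_static are the midpoints of novel_view_nums equal slices of the baseline between the two source views, e.g. with the default of 5 novel views:

novel_view_nums = 5
ratios = [(i + 0.5) * (1 / novel_view_nums) for i in range(novel_view_nums)]
print(ratios)   # [0.1, 0.3, 0.5, 0.7, 0.9]
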
/train_stage1.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import logging
3 |
4 | import numpy as np
5 | import cv2
6 | import os
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | from datetime import datetime
10 |
11 | from lib.human_loader import StereoHumanDataset
12 | from lib.network import RtStereoHumanModel
13 | from config.stereo_human_config import ConfigStereoHuman as config
14 | from lib.train_recoder import Logger, file_backup
15 | from lib.utils import get_novel_calib_for_show as get_novel_calib
16 | from lib.TaichiRender import TaichiRenderBatch
17 |
18 | import torch
19 | import torch.optim as optim
20 | from torch.cuda.amp import GradScaler
21 | from torch.utils.data import DataLoader
22 | import warnings
23 | warnings.filterwarnings("ignore", category=UserWarning)
24 |
25 |
26 | class Trainer:
27 | def __init__(self, cfg_file):
28 | self.cfg = cfg_file
29 |
30 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=False)
31 | self.train_set = StereoHumanDataset(self.cfg.dataset, phase='train')
32 | self.train_loader = DataLoader(self.train_set, batch_size=self.cfg.batch_size, shuffle=True,
33 | num_workers=self.cfg.batch_size*2, pin_memory=True)
34 | self.train_iterator = iter(self.train_loader)
35 | self.val_set = StereoHumanDataset(self.cfg.dataset, phase='val')
36 | self.val_loader = DataLoader(self.val_set, batch_size=2, shuffle=False, num_workers=4, pin_memory=True)
37 | self.len_val = int(len(self.val_loader) / self.val_set.val_boost) # real length of val set
38 | self.val_iterator = iter(self.val_loader)
39 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.wdecay, eps=1e-8)
40 | self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, self.cfg.lr, 100100, pct_start=0.01,
41 | cycle_momentum=False, anneal_strategy='linear')
42 |
43 | self.logger = Logger(self.scheduler, cfg.record)
44 | self.total_steps = 0
45 |
46 | self.model.cuda()
47 | if self.cfg.restore_ckpt:
48 | self.load_ckpt(self.cfg.restore_ckpt)
49 | self.model.train()
50 | self.model.raft_stereo.freeze_bn()
51 | self.scaler = GradScaler(enabled=self.cfg.raft.mixed_precision)
52 | self.render = TaichiRenderBatch(bs=1, res=self.cfg.dataset.src_res)
53 |
54 | def train(self):
55 | for _ in tqdm(range(self.total_steps, self.cfg.num_steps)):
56 | self.optimizer.zero_grad()
57 | data = self.fetch_data(phase='train')
58 |
59 | # Raft Stereo
60 | _, flow_loss, metrics = self.model(data)
61 | loss = flow_loss
62 |
63 | if self.total_steps and self.total_steps % self.cfg.record.loss_freq == 0:
64 | self.logger.writer.add_scalar(f'lr', self.optimizer.param_groups[0]['lr'], self.total_steps)
65 | self.save_ckpt(save_path=Path('%s/%s_latest.pth' % (cfg.record.ckpt_path, cfg.name)), show_log=False)
66 | self.logger.push(metrics)
67 |
68 | self.scaler.scale(loss).backward()
69 | self.scaler.unscale_(self.optimizer)
70 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
71 |
72 | self.scaler.step(self.optimizer)
73 | self.scheduler.step()
74 | self.scaler.update()
75 |
76 | if self.total_steps and self.total_steps % self.cfg.record.eval_freq == 0:
77 | self.model.eval()
78 | self.run_eval()
79 | self.model.train()
80 | self.model.raft_stereo.freeze_bn()
81 |
82 | self.total_steps += 1
83 |
84 | print("FINISHED TRAINING")
85 | self.logger.close()
86 | self.save_ckpt(save_path=Path('%s/%s_final.pth' % (cfg.record.ckpt_path, cfg.name)))
87 |
88 | def run_eval(self):
89 | logging.info(f"Doing validation ...")
90 | torch.cuda.empty_cache()
91 | epe_list, one_pix_list = [], []
92 | show_idx = np.random.choice(list(range(self.len_val)), 1)
93 | for idx in range(self.len_val):
94 | data = self.fetch_data(phase='val')
95 | with torch.no_grad():
96 | data, _, _ = self.model(data, is_train=False)
97 |
98 | if idx == show_idx:
99 | data = get_novel_calib(data, ratio=0.5)
100 | data = self.render.flow2render(data)
101 | tmp_novel = data['novel_view']['img_pred'][0].detach()
102 | tmp_novel = (tmp_novel / 2.0 + 0.5) * 255
103 | tmp_novel = tmp_novel.permute(1, 2, 0).cpu().numpy()
104 | tmp_img_name = '%s/%s.jpg' % (cfg.record.show_path, self.total_steps)
105 | cv2.imwrite(tmp_img_name, tmp_novel[:, :, ::-1].astype(np.uint8))
106 |
107 | for view in ['lmain', 'rmain']:
108 | valid = (data[view]['valid'] >= 0.5)
109 | epe = torch.sum((data[view]['flow'] - data[view]['flow_pred']) ** 2, dim=1).sqrt()
110 | epe = epe.view(-1)[valid.view(-1)]
111 | one_pix = (epe < 1)
112 | epe_list.append(epe.mean().item())
113 | one_pix_list.append(one_pix.float().mean().item())
114 |
115 | val_epe = np.round(np.mean(np.array(epe_list)), 4)
116 | val_one_pix = np.round(np.mean(np.array(one_pix_list)), 4)
117 | logging.info(f"Validation Metrics ({self.total_steps}): epe {val_epe}, 1pix {val_one_pix}")
118 | self.logger.write_dict({'val_epe': val_epe, 'val_1pix': val_one_pix}, write_step=self.total_steps)
119 | torch.cuda.empty_cache()
120 |
121 | def fetch_data(self, phase):
122 | if phase == 'train':
123 | try:
124 | data = next(self.train_iterator)
125 | except StopIteration:
126 | self.train_iterator = iter(self.train_loader)
127 | data = next(self.train_iterator)
128 | elif phase == 'val':
129 | try:
130 | data = next(self.val_iterator)
131 | except StopIteration:
132 | self.val_iterator = iter(self.val_loader)
133 | data = next(self.val_iterator)
134 |
135 | for view in ['lmain', 'rmain']:
136 | for item in data[view].keys():
137 | data[view][item] = data[view][item].cuda()
138 | return data
139 |
140 | def load_ckpt(self, load_path, load_optimizer=True, strict=True):
141 | assert os.path.exists(load_path)
142 | logging.info(f"Loading checkpoint from {load_path} ...")
143 | ckpt = torch.load(load_path, map_location='cuda')
144 | self.model.load_state_dict(ckpt['network'], strict=strict)
145 | logging.info(f"Parameter loading done")
146 | if load_optimizer:
147 | self.total_steps = ckpt['total_steps'] + 1
148 | self.logger.total_steps = self.total_steps
149 | self.optimizer.load_state_dict(ckpt['optimizer'])
150 | self.scheduler.load_state_dict(ckpt['scheduler'])
151 | logging.info(f"Optimizer loading done")
152 |
153 | def save_ckpt(self, save_path, show_log=True):
154 | if show_log:
155 | logging.info(f"Save checkpoint to {save_path} ...")
156 | torch.save({
157 | 'total_steps': self.total_steps,
158 | 'network': self.model.state_dict(),
159 | 'optimizer': self.optimizer.state_dict(),
160 | 'scheduler': self.scheduler.state_dict()
161 | }, save_path)
162 |
163 |
164 | if __name__ == '__main__':
165 | logging.basicConfig(level=logging.INFO,
166 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
167 |
168 | cfg = config()
169 | cfg.load("config/stage1.yaml")
170 | cfg = cfg.get_cfg()
171 |
172 | cfg.defrost()
173 | dt = datetime.today()
174 | cfg.exp_name = '%s_%s%s' % (cfg.name, str(dt.month).zfill(2), str(dt.day).zfill(2))
175 | cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name
176 | cfg.record.show_path = "experiments/%s/show" % cfg.exp_name
177 | cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name
178 | cfg.record.file_path = "experiments/%s/file" % cfg.exp_name
179 | cfg.freeze()
180 |
181 | for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]:
182 | Path(path).mkdir(exist_ok=True, parents=True)
183 |
184 | file_backup(cfg.record.file_path, cfg, train_script=os.path.basename(__file__))
185 |
186 | torch.manual_seed(1314)
187 | np.random.seed(1314)
188 |
189 | trainer = Trainer(cfg)
190 | trainer.train()
191 |
--------------------------------------------------------------------------------
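
The validation numbers in run_eval are the usual disparity metrics: mean end-point error and 1-pixel accuracy over valid pixels. A small NumPy sketch of the same computation on dummy tensors (shapes are illustrative):

import numpy as np

def disparity_metrics(flow_gt, flow_pred, valid):
    # flow_gt / flow_pred: (B, C, H, W), valid: (B, 1, H, W), mirroring run_eval
    epe = np.sqrt(((flow_gt - flow_pred) ** 2).sum(axis=1))    # per-pixel end-point error
    epe = epe.reshape(-1)[valid.reshape(-1) >= 0.5]
    return epe.mean(), (epe < 1).mean()                        # mean EPE, 1-pixel accuracy

flow_gt = np.zeros((1, 1, 4, 4), dtype=np.float32)
flow_pred = np.full((1, 1, 4, 4), 0.5, dtype=np.float32)
valid = np.ones((1, 1, 4, 4), dtype=np.float32)
print(disparity_metrics(flow_gt, flow_pred, valid))            # (0.5, 1.0)
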
/train_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import logging
4 |
5 | import numpy as np
6 | import cv2
7 | import os
8 | from pathlib import Path
9 | from tqdm import tqdm
10 | from datetime import datetime
11 |
12 | from lib.human_loader import StereoHumanDataset
13 | from lib.network import RtStereoHumanModel
14 | from config.stereo_human_config import ConfigStereoHuman as config
15 | from lib.train_recoder import Logger, file_backup
16 | from lib.GaussianRender import pts2render
17 | from lib.loss import l1_loss, ssim, psnr
18 |
19 | import torch
20 | import torch.optim as optim
21 | from torch.cuda.amp import GradScaler
22 | from torch.utils.data import DataLoader
23 | import warnings
24 | warnings.filterwarnings("ignore", category=UserWarning)
25 |
26 |
27 | class Trainer:
28 | def __init__(self, cfg_file):
29 | self.cfg = cfg_file
30 |
31 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True)
32 | self.train_set = StereoHumanDataset(self.cfg.dataset, phase='train')
33 | self.train_loader = DataLoader(self.train_set, batch_size=self.cfg.batch_size, shuffle=True,
34 | num_workers=self.cfg.batch_size*2, pin_memory=True)
35 | self.train_iterator = iter(self.train_loader)
36 | self.val_set = StereoHumanDataset(self.cfg.dataset, phase='val')
37 | self.val_loader = DataLoader(self.val_set, batch_size=2, shuffle=False, num_workers=4, pin_memory=True)
38 | self.len_val = int(len(self.val_loader) / self.val_set.val_boost) # real length of val set
39 | self.val_iterator = iter(self.val_loader)
40 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.wdecay, eps=1e-8)
41 | self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, self.cfg.lr, self.cfg.num_steps + 100,
42 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
43 |
44 |         self.logger = Logger(self.scheduler, self.cfg.record)
45 | self.total_steps = 0
46 |
47 | self.model.cuda()
48 | if self.cfg.restore_ckpt:
49 | self.load_ckpt(self.cfg.restore_ckpt)
50 | elif self.cfg.stage1_ckpt:
51 | logging.info(f"Using checkpoint from stage1")
52 | self.load_ckpt(self.cfg.stage1_ckpt, load_optimizer=False, strict=False)
53 | self.model.train()
54 | self.model.raft_stereo.freeze_bn()
55 | self.scaler = GradScaler(enabled=self.cfg.raft.mixed_precision)
56 |
57 | def train(self):
58 | for _ in tqdm(range(self.total_steps, self.cfg.num_steps)):
59 | self.optimizer.zero_grad()
60 | data = self.fetch_data(phase='train')
61 |
62 | # Raft Stereo + GS Regresser
63 | data, flow_loss, metrics = self.model(data, is_train=True)
64 | # Gaussian Render
65 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color)
66 |
67 | render_novel = data['novel_view']['img_pred']
68 | gt_novel = data['novel_view']['img'].cuda()
69 |
70 | Ll1 = l1_loss(render_novel, gt_novel)
71 | Lssim = 1.0 - ssim(render_novel, gt_novel)
72 | loss = 1.0 * flow_loss + 0.8 * Ll1 + 0.2 * Lssim
73 |
74 | if self.total_steps and self.total_steps % self.cfg.record.loss_freq == 0:
75 | self.logger.writer.add_scalar(f'lr', self.optimizer.param_groups[0]['lr'], self.total_steps)
76 |                 self.save_ckpt(save_path=Path('%s/%s_latest.pth' % (self.cfg.record.ckpt_path, self.cfg.name)), show_log=False)
77 | metrics.update({
78 | 'l1': Ll1.item(),
79 | 'ssim': Lssim.item(),
80 | })
81 | self.logger.push(metrics)
82 |
83 | self.scaler.scale(loss).backward()
84 | self.scaler.unscale_(self.optimizer)
85 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
86 |
87 | self.scaler.step(self.optimizer)
88 | self.scheduler.step()
89 | self.scaler.update()
90 |
91 | if self.total_steps and self.total_steps % self.cfg.record.eval_freq == 0:
92 | self.model.eval()
93 | self.run_eval()
94 | self.model.train()
95 | self.model.raft_stereo.freeze_bn()
96 |
97 | self.total_steps += 1
98 |
99 | print("FINISHED TRAINING")
100 | self.logger.close()
101 |         self.save_ckpt(save_path=Path('%s/%s_final.pth' % (self.cfg.record.ckpt_path, self.cfg.name)))
102 |
103 | def run_eval(self):
104 | logging.info(f"Doing validation ...")
105 | torch.cuda.empty_cache()
106 | epe_list, one_pix_list, psnr_list = [], [], []
107 | show_idx = np.random.choice(list(range(self.len_val)), 1)
108 | for idx in range(self.len_val):
109 | data = self.fetch_data(phase='val')
110 | with torch.no_grad():
111 | data, _, _ = self.model(data, is_train=False)
112 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color)
113 |
114 | render_novel = data['novel_view']['img_pred']
115 | gt_novel = data['novel_view']['img'].cuda()
116 | psnr_value = psnr(render_novel, gt_novel).mean().double()
117 | psnr_list.append(psnr_value.item())
118 |
119 | if idx == show_idx:
120 | tmp_novel = data['novel_view']['img_pred'][0].detach()
121 |                 tmp_novel = (tmp_novel * 255).clamp(0, 255)  # avoid uint8 wrap-around for bright pixels
122 | tmp_novel = tmp_novel.permute(1, 2, 0).cpu().numpy()
123 |                 tmp_img_name = '%s/%s.jpg' % (self.cfg.record.show_path, self.total_steps)
124 | cv2.imwrite(tmp_img_name, tmp_novel[:, :, ::-1].astype(np.uint8))
125 |
126 | for view in ['lmain', 'rmain']:
127 | valid = (data[view]['valid'] >= 0.5)
128 | epe = torch.sum((data[view]['flow'] - data[view]['flow_pred']) ** 2, dim=1).sqrt()
129 | epe = epe.view(-1)[valid.view(-1)]
130 | one_pix = (epe < 1)
131 | epe_list.append(epe.mean().item())
132 | one_pix_list.append(one_pix.float().mean().item())
133 |
134 | val_epe = np.round(np.mean(np.array(epe_list)), 4)
135 | val_one_pix = np.round(np.mean(np.array(one_pix_list)), 4)
136 | val_psnr = np.round(np.mean(np.array(psnr_list)), 4)
137 | logging.info(f"Validation Metrics ({self.total_steps}): epe {val_epe}, 1pix {val_one_pix}, psnr {val_psnr}")
138 | self.logger.write_dict({'val_epe': val_epe, 'val_1pix': val_one_pix, 'val_psnr': val_psnr}, write_step=self.total_steps)
139 | torch.cuda.empty_cache()
140 |
141 | def fetch_data(self, phase):
142 | if phase == 'train':
143 | try:
144 | data = next(self.train_iterator)
145 |             except StopIteration:  # loader exhausted, start a new pass
146 | self.train_iterator = iter(self.train_loader)
147 | data = next(self.train_iterator)
148 | elif phase == 'val':
149 | try:
150 | data = next(self.val_iterator)
151 |             except StopIteration:  # loader exhausted, start a new pass
152 | self.val_iterator = iter(self.val_loader)
153 | data = next(self.val_iterator)
154 |
155 | for view in ['lmain', 'rmain']:
156 | for item in data[view].keys():
157 | data[view][item] = data[view][item].cuda()
158 | return data
159 |
160 | def load_ckpt(self, load_path, load_optimizer=True, strict=True):
161 | assert os.path.exists(load_path)
162 | logging.info(f"Loading checkpoint from {load_path} ...")
163 | ckpt = torch.load(load_path, map_location='cuda')
164 | self.model.load_state_dict(ckpt['network'], strict=strict)
165 | logging.info(f"Parameter loading done")
166 | if load_optimizer:
167 | self.total_steps = ckpt['total_steps'] + 1
168 | self.logger.total_steps = self.total_steps
169 | self.optimizer.load_state_dict(ckpt['optimizer'])
170 | self.scheduler.load_state_dict(ckpt['scheduler'])
171 | logging.info(f"Optimizer loading done")
172 |
173 | def save_ckpt(self, save_path, show_log=True):
174 | if show_log:
175 | logging.info(f"Save checkpoint to {save_path} ...")
176 | torch.save({
177 | 'total_steps': self.total_steps,
178 | 'network': self.model.state_dict(),
179 | 'optimizer': self.optimizer.state_dict(),
180 | 'scheduler': self.scheduler.state_dict()
181 | }, save_path)
182 |
183 |
184 | if __name__ == '__main__':
185 | logging.basicConfig(level=logging.INFO,
186 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
187 |
188 | cfg = config()
189 | cfg.load("config/stage2.yaml")
190 | cfg = cfg.get_cfg()
191 |
192 | cfg.defrost()
193 | dt = datetime.today()
194 | cfg.exp_name = '%s_%s%s' % (cfg.name, str(dt.month).zfill(2), str(dt.day).zfill(2))
195 | cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name
196 | cfg.record.show_path = "experiments/%s/show" % cfg.exp_name
197 | cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name
198 | cfg.record.file_path = "experiments/%s/file" % cfg.exp_name
199 | cfg.freeze()
200 |
201 | for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]:
202 | Path(path).mkdir(exist_ok=True, parents=True)
203 |
204 | file_backup(cfg.record.file_path, cfg, train_script=os.path.basename(__file__))
205 |
206 | torch.manual_seed(1314)
207 | np.random.seed(1314)
208 |
209 | trainer = Trainer(cfg)
210 | trainer.train()
211 |
--------------------------------------------------------------------------------
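For reference, the stage-2 objective assembled in `train()` above weights the stereo matching loss against the photometric terms on the rendered novel view as `1.0 * flow_loss + 0.8 * L1 + 0.2 * (1 - SSIM)`. Below is a minimal sketch of that weighting, assuming the flow loss and the SSIM value are produced elsewhere (in the repository they come from `RtStereoHumanModel` and `lib.loss.ssim`; `F.l1_loss` stands in for `lib.loss.l1_loss` here):

```python
# Minimal sketch, not the repository's implementation: the stage-2 loss weighting.
import torch
import torch.nn.functional as F

def stage2_loss(render_novel, gt_novel, flow_loss, ssim_value):
    l1 = F.l1_loss(render_novel, gt_novel)   # mean absolute error on the rendered view
    d_ssim = 1.0 - ssim_value                # structural dissimilarity term
    return 1.0 * flow_loss + 0.8 * l1 + 0.2 * d_ssim

# Dummy usage with random tensors, just to show the call shape.
pred, gt = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
print(stage2_loss(pred, gt, flow_loss=torch.tensor(0.4), ssim_value=torch.tensor(0.9)).item())
```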