├── .gitignore ├── LICENSE ├── README.md ├── config ├── stage1.yaml ├── stage2.yaml └── stereo_human_config.py ├── core ├── __init__.py ├── corr.py ├── extractor.py ├── raft_stereo_human.py ├── update.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── frame_utils.py │ └── utils.py ├── environment.yml ├── gaussian_renderer └── __init__.py ├── lib ├── GaussianRender.py ├── TaichiRender.py ├── graphics_utils.py ├── gs_parm_network.py ├── human_loader.py ├── loss.py ├── network.py ├── train_recoder.py └── utils.py ├── prepare_data ├── MAKE_DATA.md ├── render_data.py └── taichi_three │ ├── __init__.py │ ├── common.py │ ├── geometry.py │ ├── light.py │ ├── loader.py │ ├── meshgen.py │ ├── model.py │ ├── raycast.py │ ├── scatter.py │ ├── scene.py │ ├── shading.py │ ├── transform.py │ └── version.py ├── test_real_data.py ├── test_view_interp.py ├── train_stage1.py └── train_stage2.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /experiments 3 | /test_out 4 | /interp_out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shunyuan Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # GPS-Gaussian: Generalizable Pixel-wise 3D Gaussian Splatting for Real-time Human Novel View Synthesis 4 | 5 | [Shunyuan Zheng](https://shunyuanzheng.github.io)†,1, [Boyao Zhou](https://yaourtb.github.io)2, [Ruizhi Shao](https://dsaurus.github.io/saurus)2, [Boning Liu](https://liuboning2.github.io)2, [Shengping Zhang](http://homepage.hit.edu.cn/zhangshengping)*,1,3, [Liqiang Nie](https://liqiangnie.github.io)1, [Yebin Liu](https://www.liuyebin.com)2 6 | 7 |

1Harbin Institute of Technology  2Tsinghua University  3Peng Cheng Laboratory 8 |
*Corresponding author  Work done during an internship at Tsinghua University

9 | 10 | ### [Project Page](https://shunyuanzheng.github.io/GPS-Gaussian) · [Video](https://youtu.be/HjnBAqjGIAo) · [Paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Zheng_GPS-Gaussian_Generalizable_Pixel-wise_3D_Gaussian_Splatting_for_Real-time_Human_Novel_CVPR_2024_paper.pdf) · [Supp.](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Zheng_GPS-Gaussian_Generalizable_Pixel-wise_CVPR_2024_supplemental.pdf) 11 | 12 |

13 | 14 | ## Introduction 15 | 16 | We propose GPS-Gaussian, a generalizable pixel-wise 3D Gaussian representation for synthesizing novel views of any unseen characters instantly, without any fine-tuning or optimization. 17 | 18 | https://github.com/ShunyuanZheng/GPS-Gaussian/assets/33752042/54a253ad-012a-448f-8303-168d80d3f594 19 | 20 | ## Installation 21 | 22 | To deploy and run GPS-Gaussian, run the following commands: 23 | ``` 24 | conda env create --file environment.yml 25 | conda activate gps_gaussian 26 | ``` 27 | Then, compile ```diff-gaussian-rasterization``` from the [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) repository: 28 | ``` 29 | git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive 30 | cd gaussian-splatting/ 31 | pip install -e submodules/diff-gaussian-rasterization 32 | cd .. 33 | ``` 34 | (Optional) [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo) provides a faster CUDA implementation of the correlation sampler that speeds up the model without impacting performance: 35 | ``` 36 | git clone https://github.com/princeton-vl/RAFT-Stereo.git 37 | cd RAFT-Stereo/sampler && python setup.py install && cd ../.. 38 | ``` 39 | If you compiled this CUDA implementation, set ```corr_implementation='reg_cuda'``` in [config/stereo_human_config.py](config/stereo_human_config.py#L33); otherwise, set ```corr_implementation='reg'```. 40 | 41 | ## Run on synthetic human dataset 42 | 43 | ### Dataset Preparation 44 | - We provide a rendered THuman2.0 dataset for training GPS-Gaussian in a 16-camera setting. Download ```render_data``` from [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) or [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd) and unzip it. Since we recommend rectifying the source images and computing the disparity offline, the saved files and the downloaded data require around 50GB of free storage space. 45 | - To train a more robust model, we recommend collecting more human scans for training (e.g. [Twindom](https://web.twindom.com), [Render People](https://renderpeople.com/), [2K2K](https://sanghunhan92.github.io/conference/2K2K/)). Then, render the training data to match the target scenario, including the number of cameras and the radius of the scene. We provide the rendering code to generate training data from human scans; see the [data documentation](prepare_data/MAKE_DATA.md) for more details. 46 | 47 | ### Training 48 | Note: on the first training run, we perform stereo rectification and compute the disparity offline; the processed data will be saved at ```render_data/rectified_local```. This pre-processing takes several hours but greatly speeds up the subsequent training. If you want to skip it, set ```use_processed_data=False``` in [stage1.yaml](config/stage1.yaml#L11) and [stage2.yaml](config/stage2.yaml#L15); a sketch of how these YAML files are merged with the default options is shown below, after the training steps. 49 | 50 | - Stage1: pretrain the depth prediction model. Set ```data_root``` in [stage1.yaml](config/stage1.yaml#L12) to the path of the unzipped ```render_data``` folder. 51 | ``` 52 | python train_stage1.py 53 | ``` 54 | 55 | - Stage2: train the full model. Set ```data_root``` in [stage2.yaml](config/stage2.yaml#L16) to the path of the unzipped ```render_data``` folder, and set the correct pretrained stage1 model path ```stage1_ckpt``` in [stage2.yaml](config/stage2.yaml#L3). 56 | ``` 57 | python train_stage2.py 58 | ``` 59 | - We provide the pretrained model ```GPS-GS_stage2_final.pth``` on [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) and [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd) for fast evaluation and testing. 60 |
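The YAML files above override only part of the configuration; any option they do not list (correlation sampler implementation, number of GRU layers, background color, and so on) keeps its default from [config/stereo_human_config.py](config/stereo_human_config.py). The sketch below shows how the merged config can be loaded and sanity-checked before launching a multi-hour run; it mirrors how ```ConfigStereoHuman``` is meant to be used, but it is an illustrative snippet rather than a copy of the training scripts.

```python
from config.stereo_human_config import ConfigStereoHuman

# Build the defaults and merge the stage-1 YAML on top of them (yacs merge_from_file).
config = ConfigStereoHuman()
config.load('config/stage1.yaml')
cfg = config.get_cfg()  # frozen copy of the merged configuration

# Cheap sanity checks before training starts.
print(cfg.dataset.data_root)            # should point to the unzipped render_data folder
print(cfg.dataset.use_processed_data)   # False skips the offline rectification cache
print(cfg.raft.corr_implementation)     # 'reg_cuda' requires the optional RAFT-Stereo sampler, 'reg' does not
```

The same check applies to ```config/stage2.yaml```, where ```stage1_ckpt``` must additionally point to the final stage-1 checkpoint.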
61 | ### Testing 62 | 63 | - Real-world data: download the test data ```real_data``` from [Baidu Netdisk](https://pan.baidu.com/s/1sX9m8wRDSQAI9d78wST7mw?pwd=rax4) or [OneDrive](https://hiteducn0-my.sharepoint.com/:f:/g/personal/sawyer0503_hit_edu_cn/EkE2GFd2saBCh_XkY3TsoV0BVTmK1UiTTKJDYje3U3vdkw?e=YazWdd). Then, run the following command to synthesize a fixed novel view between ```src_view``` 0 and 1; the position of the novel viewpoint between the source views is adjusted with a ```ratio``` ranging from 0 to 1. 64 | ``` 65 | python test_real_data.py \ 66 | --test_data_root 'PATH/TO/REAL_DATA' \ 67 | --ckpt_path 'PATH/TO/GPS-GS_stage2_final.pth' \ 68 | --src_view 0 1 \ 69 | --ratio=0.5 70 | ``` 71 | 72 | - Free-view rendering: run the following command to interpolate novel views between the source views; modify ```novel_view_nums``` to set the number of novel viewpoints. 73 | ``` 74 | python test_view_interp.py \ 75 | --test_data_root 'PATH/TO/RENDER_DATA/val' \ 76 | --ckpt_path 'PATH/TO/GPS-GS_stage2_final.pth' \ 77 | --novel_view_nums 5 78 | ``` 79 | 80 | # Citation 81 | 82 | If you find this code useful for your research, please consider citing: 83 | ```bibtex 84 | @inproceedings{zheng2024gpsgaussian, 85 | title={GPS-Gaussian: Generalizable Pixel-wise 3D Gaussian Splatting for Real-time Human Novel View Synthesis}, 86 | author={Zheng, Shunyuan and Zhou, Boyao and Shao, Ruizhi and Liu, Boning and Zhang, Shengping and Nie, Liqiang and Liu, Yebin}, 87 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 88 | year={2024} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /config/stage1.yaml: -------------------------------------------------------------------------------- 1 | name: 'GPS-GS_stage1' 2 | 3 | restore_ckpt: None 4 | lr: 0.0002 5 | wdecay: 1e-5 6 | batch_size: 6 7 | num_steps: 40000 8 | 9 | dataset: 10 | source_id: [0, 1] 11 | src_res: 1024 12 | use_processed_data: True 13 | data_root: 'PATH/TO/RENDER_DATA' 14 | 15 | raft: 16 | mixed_precision: False 17 | train_iters: 3 18 | val_iters: 3 19 | encoder_dims: [32, 48, 96] 20 | hidden_dims: [96, 96, 96] 21 | 22 | record: 23 | loss_freq: 2000 24 | eval_freq: 2000 25 | -------------------------------------------------------------------------------- /config/stage2.yaml: -------------------------------------------------------------------------------- 1 | name: 'GPS-GS_stage2' 2 | 3 | stage1_ckpt: 'PATH/TO/GPS-GS_stage1_final.pth' 4 | restore_ckpt: None 5 | lr: 0.0002 6 | wdecay: 1e-5 7 | batch_size: 2 8 | num_steps: 100000 9 | 10 | dataset: 11 | source_id: [0, 1] 12 | train_novel_id: [2, 3, 4] 13 | val_novel_id: [3] 14 | src_res: 1024 15 | use_hr_img: True 16 | use_processed_data: True 17 | data_root: 'PATH/TO/RENDER_DATA' 18 | 19 | raft: 20 | mixed_precision: True 21 | train_iters: 3 22 | val_iters: 3 23 | encoder_dims: [32, 48, 96] 24 | hidden_dims: [96, 96, 
96] 25 | 26 | gsnet: 27 | encoder_dims: [32, 48, 96] 28 | decoder_dims: [48, 64, 96] 29 | parm_head_dim: 32 30 | 31 | record: 32 | loss_freq: 5000 33 | eval_freq: 5000 34 | -------------------------------------------------------------------------------- /config/stereo_human_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | class ConfigStereoHuman: 5 | def __init__(self): 6 | self.cfg = CN() 7 | self.cfg.name = '' 8 | self.cfg.stage1_ckpt = None 9 | self.cfg.restore_ckpt = None 10 | self.cfg.lr = 0.0 11 | self.cfg.wdecay = 0.0 12 | self.cfg.batch_size = 0 13 | self.cfg.num_steps = 0 14 | 15 | self.cfg.dataset = CN() 16 | self.cfg.dataset.source_id = None 17 | self.cfg.dataset.train_novel_id = None 18 | self.cfg.dataset.val_novel_id = None 19 | self.cfg.dataset.src_res = None 20 | self.cfg.dataset.use_hr_img = None 21 | self.cfg.dataset.use_processed_data = None 22 | self.cfg.dataset.data_root = '' 23 | # gsussian render settings 24 | self.cfg.dataset.bg_color = [0, 0, 0] 25 | self.cfg.dataset.zfar = 100.0 26 | self.cfg.dataset.znear = 0.01 27 | self.cfg.dataset.trans = [0.0, 0.0, 0.0] 28 | self.cfg.dataset.scale = 1.0 29 | 30 | self.cfg.raft = CN() 31 | self.cfg.raft.mixed_precision = None 32 | self.cfg.raft.train_iters = 0 33 | self.cfg.raft.val_iters = 0 34 | self.cfg.raft.corr_implementation = 'reg_cuda' # or 'reg' 35 | self.cfg.raft.corr_levels = 4 36 | self.cfg.raft.corr_radius = 4 37 | self.cfg.raft.n_downsample = 3 38 | self.cfg.raft.n_gru_layers = 1 39 | self.cfg.raft.slow_fast_gru = None 40 | self.cfg.raft.encoder_dims = [64, 96, 128] 41 | self.cfg.raft.hidden_dims = [128]*3 42 | 43 | self.cfg.gsnet = CN() 44 | self.cfg.gsnet.encoder_dims = None 45 | self.cfg.gsnet.decoder_dims = None 46 | self.cfg.gsnet.parm_head_dim = None 47 | 48 | self.cfg.record = CN() 49 | self.cfg.record.ckpt_path = None 50 | self.cfg.record.show_path = None 51 | self.cfg.record.logs_path = None 52 | self.cfg.record.file_path = None 53 | self.cfg.record.loss_freq = 0 54 | self.cfg.record.eval_freq = 0 55 | 56 | def get_cfg(self): 57 | return self.cfg.clone() 58 | 59 | def load(self, config_file): 60 | self.cfg.defrost() 61 | self.cfg.merge_from_file(config_file) 62 | self.cfg.freeze() 63 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aipixel/GPS-Gaussian/0024776deee4824f270d4bb534a17ffd85f63cf2/core/__init__.py -------------------------------------------------------------------------------- /core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from core.utils.utils import bilinear_sampler 4 | 5 | try: 6 | import corr_sampler 7 | except: 8 | pass 9 | 10 | try: 11 | import alt_cuda_corr 12 | except: 13 | # alt_cuda_corr is not compiled 14 | pass 15 | 16 | 17 | class CorrSampler(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, volume, coords, radius): 20 | ctx.save_for_backward(volume,coords) 21 | ctx.radius = radius 22 | corr, = corr_sampler.forward(volume, coords, radius) 23 | return corr 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | volume, coords = ctx.saved_tensors 27 | grad_output = grad_output.contiguous() 28 | grad_volume, = corr_sampler.backward(volume, coords, grad_output, ctx.radius) 29 | return grad_volume, 
None, None 30 | 31 | class CorrBlockFast1D: 32 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 33 | self.num_levels = num_levels 34 | self.radius = radius 35 | self.corr_pyramid = [] 36 | # all pairs correlation 37 | corr = CorrBlockFast1D.corr(fmap1, fmap2) 38 | batch, h1, w1, dim, w2 = corr.shape 39 | corr = corr.reshape(batch*h1*w1, dim, 1, w2) 40 | for i in range(self.num_levels): 41 | self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2//2**i)) 42 | corr = F.avg_pool2d(corr, [1,2], stride=[1,2]) 43 | 44 | def __call__(self, coords): 45 | out_pyramid = [] 46 | bz, _, ht, wd = coords.shape 47 | coords = coords[:, [0]] 48 | for i in range(self.num_levels): 49 | corr = CorrSampler.apply(self.corr_pyramid[i].squeeze(3), coords/2**i, self.radius) 50 | out_pyramid.append(corr.view(bz, -1, ht, wd)) 51 | return torch.cat(out_pyramid, dim=1) 52 | 53 | @staticmethod 54 | def corr(fmap1, fmap2): 55 | B, D, H, W1 = fmap1.shape 56 | _, _, _, W2 = fmap2.shape 57 | fmap1 = fmap1.view(B, D, H, W1) 58 | fmap2 = fmap2.view(B, D, H, W2) 59 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) 60 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 61 | return corr / torch.sqrt(torch.tensor(D).float()) 62 | 63 | 64 | class PytorchAlternateCorrBlock1D: 65 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 66 | self.num_levels = num_levels 67 | self.radius = radius 68 | self.corr_pyramid = [] 69 | self.fmap1 = fmap1 70 | self.fmap2 = fmap2 71 | 72 | def corr(self, fmap1, fmap2, coords): 73 | B, D, H, W = fmap2.shape 74 | # map grid coordinates to [-1,1] 75 | xgrid, ygrid = coords.split([1,1], dim=-1) 76 | xgrid = 2*xgrid/(W-1) - 1 77 | ygrid = 2*ygrid/(H-1) - 1 78 | 79 | grid = torch.cat([xgrid, ygrid], dim=-1) 80 | output_corr = [] 81 | for grid_slice in grid.unbind(3): 82 | fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True) 83 | corr = torch.sum(fmapw_mini * fmap1, dim=1) 84 | output_corr.append(corr) 85 | corr = torch.stack(output_corr, dim=1).permute(0,2,3,1) 86 | 87 | return corr / torch.sqrt(torch.tensor(D).float()) 88 | 89 | def __call__(self, coords): 90 | r = self.radius 91 | coords = coords.permute(0, 2, 3, 1) 92 | batch, h1, w1, _ = coords.shape 93 | fmap1 = self.fmap1 94 | fmap2 = self.fmap2 95 | out_pyramid = [] 96 | for i in range(self.num_levels): 97 | dx = torch.zeros(1) 98 | dy = torch.linspace(-r, r, 2*r+1) 99 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 100 | centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone() 101 | centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i 102 | coords_lvl = centroid_lvl + delta.view(-1, 2) 103 | corr = self.corr(fmap1, fmap2, coords_lvl) 104 | fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2]) 105 | out_pyramid.append(corr) 106 | out = torch.cat(out_pyramid, dim=-1) 107 | return out.permute(0, 3, 1, 2).contiguous().float() 108 | 109 | 110 | class CorrBlock1D: 111 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 112 | self.num_levels = num_levels 113 | self.radius = radius 114 | self.corr_pyramid = [] 115 | 116 | # all pairs correlation 117 | corr = CorrBlock1D.corr(fmap1, fmap2) 118 | 119 | batch, h1, w1, _, w2 = corr.shape 120 | corr = corr.reshape(batch*h1*w1, 1, 1, w2) 121 | 122 | self.corr_pyramid.append(corr) 123 | for i in range(self.num_levels): 124 | corr = F.avg_pool2d(corr, [1,2], stride=[1,2]) 125 | self.corr_pyramid.append(corr) 126 | 127 | def __call__(self, coords): 128 | r = self.radius 129 | coords = coords[:, :1].permute(0, 2, 3, 1) 130 | batch, h1, w1, _ = 
coords.shape 131 | 132 | out_pyramid = [] 133 | for i in range(self.num_levels): 134 | corr = self.corr_pyramid[i] 135 | dx = torch.linspace(-r, r, 2*r+1) 136 | dx = dx.view(2*r+1, 1).to(coords.device) 137 | x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i 138 | y0 = torch.zeros_like(x0) 139 | 140 | coords_lvl = torch.cat([x0,y0], dim=-1) 141 | corr = bilinear_sampler(corr, coords_lvl) 142 | corr = corr.view(batch, h1, w1, -1) 143 | out_pyramid.append(corr) 144 | 145 | out = torch.cat(out_pyramid, dim=-1) 146 | return out.permute(0, 3, 1, 2).contiguous().float() 147 | 148 | @staticmethod 149 | def corr(fmap1, fmap2): 150 | B, D, H, W1 = fmap1.shape 151 | _, _, _, W2 = fmap2.shape 152 | fmap1 = fmap1.view(B, D, H, W1) 153 | fmap2 = fmap2.view(B, D, H, W2) 154 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) 155 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 156 | return corr / torch.sqrt(torch.tensor(D).float()) 157 | 158 | 159 | class AlternateCorrBlock: 160 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 161 | raise NotImplementedError 162 | self.num_levels = num_levels 163 | self.radius = radius 164 | 165 | self.pyramid = [(fmap1, fmap2)] 166 | for i in range(self.num_levels): 167 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 168 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 169 | self.pyramid.append((fmap1, fmap2)) 170 | 171 | def __call__(self, coords): 172 | coords = coords.permute(0, 2, 3, 1) 173 | B, H, W, _ = coords.shape 174 | dim = self.pyramid[0][0].shape[1] 175 | 176 | corr_list = [] 177 | for i in range(self.num_levels): 178 | r = self.radius 179 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 180 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 181 | 182 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 183 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 184 | corr_list.append(corr.squeeze(1)) 185 | 186 | corr = torch.stack(corr_list, dim=1) 187 | corr = corr.reshape(B, -1, H, W) 188 | return corr / torch.sqrt(torch.tensor(dim).float()) 189 | -------------------------------------------------------------------------------- /core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not (stride == 1 and in_planes == planes): 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not (stride == 1 and in_planes == planes): 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not (stride == 1 and in_planes == planes): 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = 
nn.Sequential() 37 | if not (stride == 1 and in_planes == planes): 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1 and in_planes == planes: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.conv1(y) 51 | y = self.norm1(y) 52 | y = self.relu(y) 53 | y = self.conv2(y) 54 | y = self.norm2(y) 55 | y = self.relu(y) 56 | 57 | if self.downsample is not None: 58 | x = self.downsample(x) 59 | 60 | return self.relu(x+y) 61 | 62 | 63 | class UnetExtractor(nn.Module): 64 | def __init__(self, in_channel=3, encoder_dim=[64, 96, 128], norm_fn='group'): 65 | super().__init__() 66 | self.in_ds = nn.Sequential( 67 | nn.Conv2d(in_channel, 32, kernel_size=5, stride=2, padding=2), 68 | nn.GroupNorm(num_groups=8, num_channels=32), 69 | nn.ReLU(inplace=True) 70 | ) 71 | 72 | self.res1 = nn.Sequential( 73 | ResidualBlock(32, encoder_dim[0], norm_fn=norm_fn), 74 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn) 75 | ) 76 | self.res2 = nn.Sequential( 77 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn), 78 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn) 79 | ) 80 | self.res3 = nn.Sequential( 81 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn), 82 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn), 83 | ) 84 | 85 | def forward(self, x): 86 | x = self.in_ds(x) 87 | x1 = self.res1(x) 88 | x2 = self.res2(x1) 89 | x3 = self.res3(x2) 90 | 91 | return x1, x2, x3 92 | 93 | 94 | class MultiBasicEncoder(nn.Module): 95 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]): 96 | super(MultiBasicEncoder, self).__init__() 97 | 98 | # output convolution for feature 99 | self.conv2 = nn.Sequential( 100 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 101 | nn.Conv2d(encoder_dim[2], encoder_dim[2]*2, 3, padding=1)) 102 | 103 | # output convolution for context 104 | output_list = [] 105 | for dim in output_dim: 106 | conv_out = nn.Sequential( 107 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 108 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1)) 109 | output_list.append(conv_out) 110 | 111 | self.outputs08 = nn.ModuleList(output_list) 112 | 113 | def forward(self, x): 114 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2) 115 | 116 | outputs08 = [f(x) for f in self.outputs08] 117 | return outputs08, feat1, feat2 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | data = torch.ones((1, 3, 1024, 1024)) 123 | 124 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128]) 125 | 126 | x1, x2, x3 = model(data) 127 | print(x1.shape, x2.shape, x3.shape) 128 | -------------------------------------------------------------------------------- /core/raft_stereo_human.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from core.update import BasicMultiUpdateBlock 5 | from core.extractor import MultiBasicEncoder 6 | from core.corr import CorrBlock1D, CorrBlockFast1D 7 | from core.utils.utils import coords_grid, downflow8 8 | from torch.cuda.amp import autocast as autocast 9 | 10 | 11 | class RAFTStereoHuman(nn.Module): 12 | def __init__(self, args): 13 | super().__init__() 14 | self.args = args 15 | 16 | context_dims = args.hidden_dims 17 | self.cnet = 
MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], encoder_dim=args.encoder_dims) 18 | self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)]) 19 | self.update_module = FlowUpdateModule(self.args) 20 | 21 | def freeze_bn(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.BatchNorm2d): 24 | m.eval() 25 | 26 | def forward(self, image_pair, iters=12, flow_init=None, test_mode=False): 27 | """ Estimate optical flow between pair of frames """ 28 | 29 | if flow_init is not None: 30 | flow_init = downflow8(flow_init) 31 | flow_init = torch.cat([flow_init, torch.zeros_like(flow_init)], dim=1) 32 | 33 | # run the context network 34 | with autocast(enabled=self.args.mixed_precision): 35 | *cnet_list, fmap1, fmap2 = self.cnet(image_pair) 36 | fmap12 = torch.cat((fmap1, fmap2), dim=0) 37 | fmap21 = torch.cat((fmap2, fmap1), dim=0) 38 | 39 | net_list = [torch.tanh(x[0]) for x in cnet_list] 40 | inp_list = [torch.relu(x[1]) for x in cnet_list] 41 | 42 | # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning 43 | inp_list = [list(conv(i).split(split_size=conv.out_channels // 3, dim=1)) for i, conv in zip(inp_list, self.context_zqr_convs)] 44 | 45 | # run update module 46 | flow_pred = self.update_module(fmap12, fmap21, net_list, inp_list, iters, flow_init, test_mode) 47 | 48 | if not test_mode: 49 | return flow_pred 50 | else: 51 | return flow_pred.split(dim=0, split_size=flow_pred.shape[0]//2) 52 | 53 | 54 | class FlowUpdateModule(nn.Module): 55 | def __init__(self, args): 56 | super().__init__() 57 | self.args = args 58 | self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims) 59 | 60 | def initialize_flow(self, img): 61 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 62 | N, _, H, W = img.shape 63 | 64 | coords0 = coords_grid(N, H, W).to(img.device) 65 | coords1 = coords_grid(N, H, W).to(img.device) 66 | 67 | return coords0, coords1 68 | 69 | def upsample_flow(self, flow, mask): 70 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 71 | N, D, H, W = flow.shape 72 | factor = 2 ** self.args.n_downsample 73 | mask = mask.view(N, 1, 9, factor, factor, H, W) 74 | mask = torch.softmax(mask, dim=2) 75 | 76 | up_flow = F.unfold(factor * flow, [3,3], padding=1) 77 | up_flow = up_flow.view(N, D, 9, 1, 1, H, W) 78 | 79 | up_flow = torch.sum(mask * up_flow, dim=2) 80 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 81 | return up_flow.reshape(N, D, factor*H, factor*W) 82 | 83 | def forward(self, fmap1, fmap2, net_list, inp_list, iters=12, flow_init=None, test_mode=False): 84 | if self.args.corr_implementation == "reg": # Default 85 | corr_block = CorrBlock1D 86 | fmap1, fmap2 = fmap1.float(), fmap2.float() 87 | elif self.args.corr_implementation == "reg_cuda": # Faster version of reg 88 | corr_block = CorrBlockFast1D 89 | corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels) 90 | 91 | coords0, coords1 = self.initialize_flow(net_list[0]) 92 | 93 | if flow_init is not None: 94 | coords1 = coords1 + flow_init 95 | 96 | flow_predictions = [] 97 | for itr in range(iters): 98 | coords1 = coords1.detach() 99 | corr = corr_fn(coords1) # index correlation volume 100 | flow = coords1 - coords0 101 | with autocast(enabled=self.args.mixed_precision): 102 | if self.args.n_gru_layers == 3 and 
self.args.slow_fast_gru: # Update low-res GRU 103 | net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False) 104 | if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru: # Update low-res GRU and mid-res GRU 105 | net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers==3, iter16=True, iter08=False, update=False) 106 | net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, corr, flow, iter32=self.args.n_gru_layers==3, iter16=self.args.n_gru_layers>=2) 107 | 108 | # in stereo mode, project flow onto epipolar 109 | delta_flow[:, 1] = 0.0 110 | 111 | # F(t+1) = F(t) + \Delta(t) 112 | coords1 = coords1 + delta_flow 113 | 114 | # We do not need to upsample or output intermediate results in test_mode 115 | if test_mode and itr < iters-1: 116 | continue 117 | 118 | # upsample predictions 119 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 120 | flow_up = flow_up[:, :1] 121 | 122 | flow_predictions.append(flow_up) 123 | 124 | if test_mode: 125 | return flow_up 126 | 127 | return flow_predictions 128 | -------------------------------------------------------------------------------- /core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from opt_einsum import contract 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=2): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim, input_dim, kernel_size=3): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 22 | 23 | def forward(self, h, cz, cr, cq, *x_list): 24 | x = torch.cat(x_list, dim=1) 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx) + cz) 28 | r = torch.sigmoid(self.convr(hx) + cr) 29 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq) 30 | 31 | h = (1-z) * h + z * q 32 | return h 33 | 34 | class SepConvGRU(nn.Module): 35 | def __init__(self, hidden_dim=128, input_dim=192+128): 36 | super(SepConvGRU, self).__init__() 37 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 40 | 41 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 44 | 45 | 46 | def forward(self, h, *x): 47 | # horizontal 48 | x = torch.cat(x, dim=1) 49 | hx = torch.cat([h, x], dim=1) 50 | z = torch.sigmoid(self.convz1(hx)) 51 | r = torch.sigmoid(self.convr1(hx)) 52 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 53 | h = (1-z) * h + z * q 54 | 55 | # vertical 56 | hx = torch.cat([h, x], 
dim=1) 57 | z = torch.sigmoid(self.convz2(hx)) 58 | r = torch.sigmoid(self.convr2(hx)) 59 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 60 | h = (1-z) * h + z * q 61 | 62 | return h 63 | 64 | class BasicMotionEncoder(nn.Module): 65 | def __init__(self, args): 66 | super(BasicMotionEncoder, self).__init__() 67 | self.args = args 68 | 69 | cor_planes = args.corr_levels * (2*args.corr_radius + 1) 70 | 71 | self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0) 72 | self.convc2 = nn.Conv2d(64, 64, 3, padding=1) 73 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 74 | self.convf2 = nn.Conv2d(64, 64, 3, padding=1) 75 | self.conv = nn.Conv2d(64+64, 128-2, 3, padding=1) 76 | 77 | def forward(self, flow, corr): 78 | cor = F.relu(self.convc1(corr)) 79 | cor = F.relu(self.convc2(cor)) 80 | flo = F.relu(self.convf1(flow)) 81 | flo = F.relu(self.convf2(flo)) 82 | 83 | cor_flo = torch.cat([cor, flo], dim=1) 84 | out = F.relu(self.conv(cor_flo)) 85 | return torch.cat([out, flow], dim=1) 86 | 87 | def pool2x(x): 88 | return F.avg_pool2d(x, 3, stride=2, padding=1) 89 | 90 | def pool4x(x): 91 | return F.avg_pool2d(x, 5, stride=4, padding=1) 92 | 93 | def interp(x, dest): 94 | interp_args = {'mode': 'bilinear', 'align_corners': True} 95 | return F.interpolate(x, dest.shape[2:], **interp_args) 96 | 97 | class BasicMultiUpdateBlock(nn.Module): 98 | def __init__(self, args, hidden_dims=[]): 99 | super().__init__() 100 | self.args = args 101 | self.encoder = BasicMotionEncoder(args) 102 | encoder_output_dim = 128 103 | 104 | self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1)) 105 | self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2]) 106 | self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1]) 107 | self.flow_head = FlowHead(hidden_dims[2], hidden_dim=256, output_dim=2) 108 | factor = 2**self.args.n_downsample 109 | 110 | self.mask = nn.Sequential( 111 | nn.Conv2d(hidden_dims[2], 256, 3, padding=1), 112 | nn.ReLU(inplace=True), 113 | nn.Conv2d(256, (factor**2)*9, 1, padding=0)) 114 | 115 | def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True): 116 | 117 | if iter32: 118 | net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1])) 119 | if iter16: 120 | if self.args.n_gru_layers > 2: 121 | net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1])) 122 | else: 123 | net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0])) 124 | if iter08: 125 | motion_features = self.encoder(flow, corr) 126 | if self.args.n_gru_layers > 1: 127 | net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) 128 | else: 129 | net[0] = self.gru08(net[0], *(inp[0]), motion_features) 130 | 131 | if not update: 132 | return net 133 | 134 | delta_flow = self.flow_head(net[0]) 135 | 136 | # scale mask to balence gradients 137 | mask = .25 * self.mask(net[0]) 138 | return net, mask, delta_flow 139 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aipixel/GPS-Gaussian/0024776deee4824f270d4bb534a17ffd85f63cf2/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import warnings 4 | import os 5 | 
import time 6 | from glob import glob 7 | from skimage import color, io 8 | from PIL import Image 9 | 10 | import cv2 11 | cv2.setNumThreads(0) 12 | cv2.ocl.setUseOpenCL(False) 13 | 14 | import torch 15 | from torchvision.transforms import ColorJitter, functional, Compose 16 | import torch.nn.functional as F 17 | 18 | def get_middlebury_images(): 19 | root = "datasets/Middlebury/MiddEval3" 20 | with open(os.path.join(root, "official_train.txt"), 'r') as f: 21 | lines = f.read().splitlines() 22 | return sorted([os.path.join(root, 'trainingQ', f'{name}/im0.png') for name in lines]) 23 | 24 | def get_eth3d_images(): 25 | return sorted(glob('datasets/ETH3D/two_view_training/*/im0.png')) 26 | 27 | def get_kitti_images(): 28 | return sorted(glob('datasets/KITTI/training/image_2/*_10.png')) 29 | 30 | def transfer_color(image, style_mean, style_stddev): 31 | reference_image_lab = color.rgb2lab(image) 32 | reference_stddev = np.std(reference_image_lab, axis=(0,1), keepdims=True)# + 1 33 | reference_mean = np.mean(reference_image_lab, axis=(0,1), keepdims=True) 34 | 35 | reference_image_lab = reference_image_lab - reference_mean 36 | lamb = style_stddev/reference_stddev 37 | style_image_lab = lamb * reference_image_lab 38 | output_image_lab = style_image_lab + style_mean 39 | l, a, b = np.split(output_image_lab, 3, axis=2) 40 | l = l.clip(0, 100) 41 | output_image_lab = np.concatenate((l,a,b), axis=2) 42 | with warnings.catch_warnings(): 43 | warnings.simplefilter("ignore", category=UserWarning) 44 | output_image_rgb = color.lab2rgb(output_image_lab) * 255 45 | return output_image_rgb 46 | 47 | class AdjustGamma(object): 48 | 49 | def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0): 50 | self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = gamma_min, gamma_max, gain_min, gain_max 51 | 52 | def __call__(self, sample): 53 | gain = random.uniform(self.gain_min, self.gain_max) 54 | gamma = random.uniform(self.gamma_min, self.gamma_max) 55 | return functional.adjust_gamma(sample, gamma, gain) 56 | 57 | def __repr__(self): 58 | return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})" 59 | 60 | class FlowAugmentor: 61 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, yjitter=False, saturation_range=[0.6,1.4], gamma=[1,1,1,1]): 62 | 63 | # spatial augmentation params 64 | self.crop_size = crop_size 65 | self.min_scale = min_scale 66 | self.max_scale = max_scale 67 | self.spatial_aug_prob = 1.0 68 | self.stretch_prob = 0.8 69 | self.max_stretch = 0.2 70 | 71 | # flip augmentation params 72 | self.yjitter = yjitter 73 | self.do_flip = do_flip 74 | self.h_flip_prob = 0.5 75 | self.v_flip_prob = 0.1 76 | 77 | # photometric augmentation params 78 | self.photo_aug = Compose([ColorJitter(brightness=0.4, contrast=0.4, saturation=saturation_range, hue=0.5/3.14), AdjustGamma(*gamma)]) 79 | self.asymmetric_color_aug_prob = 0.2 80 | self.eraser_aug_prob = 0.5 81 | 82 | def color_transform(self, img1, img2): 83 | """ Photometric augmentation """ 84 | 85 | # asymmetric 86 | if np.random.rand() < self.asymmetric_color_aug_prob: 87 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 88 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 89 | 90 | # symmetric 91 | else: 92 | image_stack = np.concatenate([img1, img2], axis=0) 93 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 94 | img1, img2 = np.split(image_stack, 2, axis=0) 95 | 96 | return 
img1, img2 97 | 98 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 99 | """ Occlusion augmentation """ 100 | 101 | ht, wd = img1.shape[:2] 102 | if np.random.rand() < self.eraser_aug_prob: 103 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 104 | for _ in range(np.random.randint(1, 3)): 105 | x0 = np.random.randint(0, wd) 106 | y0 = np.random.randint(0, ht) 107 | dx = np.random.randint(bounds[0], bounds[1]) 108 | dy = np.random.randint(bounds[0], bounds[1]) 109 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 110 | 111 | return img1, img2 112 | 113 | def spatial_transform(self, img1, img2, flow): 114 | # randomly sample scale 115 | ht, wd = img1.shape[:2] 116 | min_scale = np.maximum( 117 | (self.crop_size[0] + 8) / float(ht), 118 | (self.crop_size[1] + 8) / float(wd)) 119 | 120 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 121 | scale_x = scale 122 | scale_y = scale 123 | if np.random.rand() < self.stretch_prob: 124 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 125 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 126 | 127 | scale_x = np.clip(scale_x, min_scale, None) 128 | scale_y = np.clip(scale_y, min_scale, None) 129 | 130 | if np.random.rand() < self.spatial_aug_prob: 131 | # rescale the images 132 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 133 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 134 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 135 | flow = flow * [scale_x, scale_y] 136 | 137 | if self.do_flip: 138 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip 139 | img1 = img1[:, ::-1] 140 | img2 = img2[:, ::-1] 141 | flow = flow[:, ::-1] * [-1.0, 1.0] 142 | 143 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo 144 | tmp = img1[:, ::-1] 145 | img1 = img2[:, ::-1] 146 | img2 = tmp 147 | 148 | if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip 149 | img1 = img1[::-1, :] 150 | img2 = img2[::-1, :] 151 | flow = flow[::-1, :] * [1.0, -1.0] 152 | 153 | if self.yjitter: 154 | y0 = np.random.randint(2, img1.shape[0] - self.crop_size[0] - 2) 155 | x0 = np.random.randint(2, img1.shape[1] - self.crop_size[1] - 2) 156 | 157 | y1 = y0 + np.random.randint(-2, 2 + 1) 158 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 159 | img2 = img2[y1:y1+self.crop_size[0], x0:x0+self.crop_size[1]] 160 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 161 | 162 | else: 163 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 164 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 165 | 166 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 167 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 168 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 169 | 170 | return img1, img2, flow 171 | 172 | 173 | def __call__(self, img1, img2, flow): 174 | img1, img2 = self.color_transform(img1, img2) 175 | img1, img2 = self.eraser_transform(img1, img2) 176 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 177 | 178 | img1 = np.ascontiguousarray(img1) 179 | img2 = np.ascontiguousarray(img2) 180 | flow = np.ascontiguousarray(flow) 181 | 182 | return img1, img2, flow 183 | 184 | class SparseFlowAugmentor: 185 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, yjitter=False, saturation_range=[0.7,1.3], 
gamma=[1,1,1,1]): 186 | # spatial augmentation params 187 | self.crop_size = crop_size 188 | self.min_scale = min_scale 189 | self.max_scale = max_scale 190 | self.spatial_aug_prob = 0.8 191 | self.stretch_prob = 0.8 192 | self.max_stretch = 0.2 193 | 194 | # flip augmentation params 195 | self.do_flip = do_flip 196 | self.h_flip_prob = 0.5 197 | self.v_flip_prob = 0.1 198 | 199 | # photometric augmentation params 200 | self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)]) 201 | self.asymmetric_color_aug_prob = 0.2 202 | self.eraser_aug_prob = 0.5 203 | 204 | def color_transform(self, img1, img2): 205 | image_stack = np.concatenate([img1, img2], axis=0) 206 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 207 | img1, img2 = np.split(image_stack, 2, axis=0) 208 | return img1, img2 209 | 210 | def eraser_transform(self, img1, img2): 211 | ht, wd = img1.shape[:2] 212 | if np.random.rand() < self.eraser_aug_prob: 213 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 214 | for _ in range(np.random.randint(1, 3)): 215 | x0 = np.random.randint(0, wd) 216 | y0 = np.random.randint(0, ht) 217 | dx = np.random.randint(50, 100) 218 | dy = np.random.randint(50, 100) 219 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 220 | 221 | return img1, img2 222 | 223 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 224 | ht, wd = flow.shape[:2] 225 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 226 | coords = np.stack(coords, axis=-1) 227 | 228 | coords = coords.reshape(-1, 2).astype(np.float32) 229 | flow = flow.reshape(-1, 2).astype(np.float32) 230 | valid = valid.reshape(-1).astype(np.float32) 231 | 232 | coords0 = coords[valid>=1] 233 | flow0 = flow[valid>=1] 234 | 235 | ht1 = int(round(ht * fy)) 236 | wd1 = int(round(wd * fx)) 237 | 238 | coords1 = coords0 * [fx, fy] 239 | flow1 = flow0 * [fx, fy] 240 | 241 | xx = np.round(coords1[:,0]).astype(np.int32) 242 | yy = np.round(coords1[:,1]).astype(np.int32) 243 | 244 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 245 | xx = xx[v] 246 | yy = yy[v] 247 | flow1 = flow1[v] 248 | 249 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 250 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 251 | 252 | flow_img[yy, xx] = flow1 253 | valid_img[yy, xx] = 1 254 | 255 | return flow_img, valid_img 256 | 257 | def spatial_transform(self, img1, img2, flow, valid): 258 | # randomly sample scale 259 | 260 | ht, wd = img1.shape[:2] 261 | min_scale = np.maximum( 262 | (self.crop_size[0] + 1) / float(ht), 263 | (self.crop_size[1] + 1) / float(wd)) 264 | 265 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 266 | scale_x = np.clip(scale, min_scale, None) 267 | scale_y = np.clip(scale, min_scale, None) 268 | 269 | if np.random.rand() < self.spatial_aug_prob: 270 | # rescale the images 271 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 272 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 273 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 274 | 275 | if self.do_flip: 276 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip 277 | img1 = img1[:, ::-1] 278 | img2 = img2[:, ::-1] 279 | flow = flow[:, ::-1] * [-1.0, 1.0] 280 | 281 | if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo 282 | tmp = img1[:, ::-1] 283 | img1 = img2[:, ::-1] 284 | img2 = tmp 285 | 286 | if 
np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip 287 | img1 = img1[::-1, :] 288 | img2 = img2[::-1, :] 289 | flow = flow[::-1, :] * [1.0, -1.0] 290 | 291 | margin_y = 20 292 | margin_x = 50 293 | 294 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 295 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 296 | 297 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 298 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 299 | 300 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 301 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 302 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 303 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 304 | return img1, img2, flow, valid 305 | 306 | 307 | def __call__(self, img1, img2, flow, valid): 308 | img1, img2 = self.color_transform(img1, img2) 309 | img1, img2 = self.eraser_transform(img1, img2) 310 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 311 | 312 | img1 = np.ascontiguousarray(img1) 313 | img2 = np.ascontiguousarray(img2) 314 | flow = np.ascontiguousarray(flow) 315 | valid = np.ascontiguousarray(valid) 316 | 317 | return img1, img2, flow, valid 318 | -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import json 6 | import imageio 7 | import cv2 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | def readFlow(fn): 14 | """ Read .flo file in Middlebury format""" 15 | # Code adapted from: 16 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 17 | 18 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 19 | # print 'fn = %s'%(fn) 20 | with open(fn, 'rb') as f: 21 | magic = np.fromfile(f, np.float32, count=1) 22 | if 202021.25 != magic: 23 | print('Magic number incorrect. 
Invalid .flo file') 24 | return None 25 | else: 26 | w = np.fromfile(f, np.int32, count=1) 27 | h = np.fromfile(f, np.int32, count=1) 28 | # print 'Reading %d x %d flo file\n' % (w, h) 29 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 30 | # Reshape data into 3D array (columns, rows, bands) 31 | # The reshape here is for visualization, the original code is (w,h,2) 32 | return np.resize(data, (int(h), int(w), 2)) 33 | 34 | def readPFM(file): 35 | file = open(file, 'rb') 36 | 37 | color = None 38 | width = None 39 | height = None 40 | scale = None 41 | endian = None 42 | 43 | header = file.readline().rstrip() 44 | if header == b'PF': 45 | color = True 46 | elif header == b'Pf': 47 | color = False 48 | else: 49 | raise Exception('Not a PFM file.') 50 | 51 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 52 | if dim_match: 53 | width, height = map(int, dim_match.groups()) 54 | else: 55 | raise Exception('Malformed PFM header.') 56 | 57 | scale = float(file.readline().rstrip()) 58 | if scale < 0: # little-endian 59 | endian = '<' 60 | scale = -scale 61 | else: 62 | endian = '>' # big-endian 63 | 64 | data = np.fromfile(file, endian + 'f') 65 | shape = (height, width, 3) if color else (height, width) 66 | 67 | data = np.reshape(data, shape) 68 | data = np.flipud(data) 69 | return data 70 | 71 | def writePFM(file, array): 72 | import os 73 | assert type(file) is str and type(array) is np.ndarray and \ 74 | os.path.splitext(file)[1] == ".pfm" 75 | with open(file, 'wb') as f: 76 | H, W = array.shape 77 | headers = ["Pf\n", f"{W} {H}\n", "-1\n"] 78 | for header in headers: 79 | f.write(str.encode(header)) 80 | array = np.flip(array, axis=0).astype(np.float32) 81 | f.write(array.tobytes()) 82 | 83 | 84 | 85 | def writeFlow(filename,uv,v=None): 86 | """ Write optical flow to file. 87 | 88 | If v is None, uv is assumed to contain both u and v channels, 89 | stacked in depth. 90 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
91 | """ 92 | nBands = 2 93 | 94 | if v is None: 95 | assert(uv.ndim == 3) 96 | assert(uv.shape[2] == 2) 97 | u = uv[:,:,0] 98 | v = uv[:,:,1] 99 | else: 100 | u = uv 101 | 102 | assert(u.shape == v.shape) 103 | height,width = u.shape 104 | f = open(filename,'wb') 105 | # write the header 106 | f.write(TAG_CHAR) 107 | np.array(width).astype(np.int32).tofile(f) 108 | np.array(height).astype(np.int32).tofile(f) 109 | # arrange into matrix form 110 | tmp = np.zeros((height, width*nBands)) 111 | tmp[:,np.arange(width)*2] = u 112 | tmp[:,np.arange(width)*2 + 1] = v 113 | tmp.astype(np.float32).tofile(f) 114 | f.close() 115 | 116 | 117 | def readFlowKITTI(filename): 118 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 119 | flow = flow[:,:,::-1].astype(np.float32) 120 | flow, valid = flow[:, :, :2], flow[:, :, 2] 121 | flow = (flow - 2**15) / 64.0 122 | return flow, valid 123 | 124 | def readDispKITTI(filename): 125 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 126 | valid = disp > 0.0 127 | return disp, valid 128 | 129 | # Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py 130 | def readDispSintelStereo(file_name): 131 | a = np.array(Image.open(file_name)) 132 | d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3) 133 | disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0] 134 | mask = np.array(Image.open(file_name.replace('disparities', 'occlusions'))) 135 | valid = ((mask == 0) & (disp > 0)) 136 | return disp, valid 137 | 138 | # Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt 139 | def readDispFallingThings(file_name): 140 | a = np.array(Image.open(file_name)) 141 | with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f: 142 | intrinsics = json.load(f) 143 | fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx'] 144 | disp = (fx * 6.0 * 100) / a.astype(np.float32) 145 | valid = disp > 0 146 | return disp, valid 147 | 148 | # Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md 149 | def readDispTartanAir(file_name): 150 | depth = np.load(file_name) 151 | disp = 80.0 / depth 152 | valid = disp > 0 153 | return disp, valid 154 | 155 | 156 | def readDispMiddlebury(file_name): 157 | if basename(file_name) == 'disp0GT.pfm': 158 | disp = readPFM(file_name).astype(np.float32) 159 | assert len(disp.shape) == 2 160 | nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png') 161 | assert exists(nocc_pix) 162 | nocc_pix = imageio.imread(nocc_pix) == 255 163 | assert np.any(nocc_pix) 164 | return disp, nocc_pix 165 | elif basename(file_name) == 'disp0.pfm': 166 | disp = readPFM(file_name).astype(np.float32) 167 | valid = disp < 1e3 168 | return disp, valid 169 | 170 | def writeFlowKITTI(filename, uv): 171 | uv = 64.0 * uv + 2**15 172 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 173 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 174 | cv2.imwrite(filename, uv[..., ::-1]) 175 | 176 | 177 | def read_gen(file_name, pil=False): 178 | ext = splitext(file_name)[-1] 179 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 180 | return Image.open(file_name) 181 | elif ext == '.bin' or ext == '.raw': 182 | return np.load(file_name) 183 | elif ext == '.flo': 184 | return readFlow(file_name).astype(np.float32) 185 | elif ext == '.pfm': 186 | flow = readPFM(file_name).astype(np.float32) 187 | if len(flow.shape) == 2: 188 | return flow 189 | else: 190 | return 
flow[:, :, :-1] 191 | return [] -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel', divis_by=8): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by 12 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | assert all((x.ndim == 4) for x in inputs) 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self, x): 23 | assert x.ndim == 4 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | xgrid, ygrid = coords.split([1,1], dim=-1) 63 | xgrid = 2*xgrid/(W-1) - 1 64 | if H > 1: 65 | ygrid = 2*ygrid/(H-1) - 1 66 | 67 | grid = torch.cat([xgrid, ygrid], dim=-1) 68 | img = F.grid_sample(img, grid, align_corners=True) 69 | 70 | if mask: 71 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 72 | return img, mask.float() 73 | 74 | return img 75 | 76 | 77 | def coords_grid(batch, ht, wd): 78 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 79 | coords = torch.stack(coords[::-1], dim=0).float() 80 | return coords[None].repeat(batch, 1, 1, 1) 81 | 82 | 83 | def upflow8(flow, mode='bilinear'): 84 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 85 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 86 | 87 | def downflow8(flow, mode='bilinear'): 88 | new_size = (int(flow.shape[2]/8), int(flow.shape[3]/8)) 89 | return 0.125 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 90 | 91 | def gauss_blur(input, N=5, std=1): 92 | B, D, H, W = input.shape 93 | x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2) 94 | unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2)) 95 | weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4) 96 | weights = weights.view(1,1,N,N).to(input) 97 | output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2) 98 | 
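# reshape the per-channel blur results back to the original (B, D, H, W) layout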
return output.view(B, D, H, W) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gps_gaussian 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.10.13 7 | - pip=23.3.1 8 | - pytorch=2.0.1 9 | - torchvision=0.15.2 10 | - torchaudio=2.0.2 11 | - pytorch-cuda=11.8 12 | - tqdm 13 | - tensorboard 14 | - scipy 15 | - pip: 16 | - opencv-python 17 | - taichi==1.5.0 18 | - yacs==0.1.8 19 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | 16 | 17 | def render(data, idx, pts_xyz, pts_rgb, rotations, scales, opacity, bg_color): 18 | """ 19 | Render the scene. 20 | 21 | Background tensor (bg_color) must be on GPU! 22 | """ 23 | bg_color = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(data['novel_view']['FovX'][idx] * 0.5) 34 | tanfovy = math.tan(data['novel_view']['FovY'][idx] * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(data['novel_view']['height'][idx]), 38 | image_width=int(data['novel_view']['width'][idx]), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=1.0, 43 | viewmatrix=data['novel_view']['world_view_transform'][idx], 44 | projmatrix=data['novel_view']['full_proj_transform'][idx], 45 | sh_degree=3, 46 | campos=data['novel_view']['camera_center'][idx], 47 | prefiltered=False, 48 | debug=False 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 54 | rendered_image, _ = rasterizer( 55 | means3D=pts_xyz, 56 | means2D=screenspace_points, 57 | shs=None, 58 | colors_precomp=pts_rgb, 59 | opacities=opacity, 60 | scales=scales, 61 | rotations=rotations, 62 | cov3D_precomp=None) 63 | 64 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 65 | # They will be excluded from value updates used in the splitting criteria. 
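# Note on the two comments above: they are carried over from the original 3D Gaussian
# Splatting renderer (Inria). In this repo the radii returned by the rasterizer are
# discarded ("rendered_image, _ = rasterizer(...)" above), since GPS-Gaussian regresses
# the Gaussian parameters per frame and performs no densification or splitting.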
66 | 67 | return rendered_image 68 | -------------------------------------------------------------------------------- /lib/GaussianRender.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from gaussian_renderer import render 4 | 5 | 6 | def pts2render(data, bg_color): 7 | bs = data['lmain']['img'].shape[0] 8 | render_novel_list = [] 9 | for i in range(bs): 10 | xyz_i_valid = [] 11 | rgb_i_valid = [] 12 | rot_i_valid = [] 13 | scale_i_valid = [] 14 | opacity_i_valid = [] 15 | for view in ['lmain', 'rmain']: 16 | valid_i = data[view]['pts_valid'][i, :] 17 | xyz_i = data[view]['xyz'][i, :, :] 18 | rgb_i = data[view]['img'][i, :, :, :].permute(1, 2, 0).view(-1, 3) 19 | rot_i = data[view]['rot_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 4) 20 | scale_i = data[view]['scale_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 3) 21 | opacity_i = data[view]['opacity_maps'][i, :, :, :].permute(1, 2, 0).view(-1, 1) 22 | 23 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3)) 24 | rgb_i_valid.append(rgb_i[valid_i].view(-1, 3)) 25 | rot_i_valid.append(rot_i[valid_i].view(-1, 4)) 26 | scale_i_valid.append(scale_i[valid_i].view(-1, 3)) 27 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1)) 28 | 29 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0) 30 | pts_rgb_i = torch.concat(rgb_i_valid, dim=0) 31 | pts_rgb_i = pts_rgb_i * 0.5 + 0.5 32 | rot_i = torch.concat(rot_i_valid, dim=0) 33 | scale_i = torch.concat(scale_i_valid, dim=0) 34 | opacity_i = torch.concat(opacity_i_valid, dim=0) 35 | 36 | render_novel_i = render(data, i, pts_xyz_i, pts_rgb_i, rot_i, scale_i, opacity_i, bg_color=bg_color) 37 | render_novel_list.append(render_novel_i.unsqueeze(0)) 38 | 39 | data['novel_view']['img_pred'] = torch.concat(render_novel_list, dim=0) 40 | return data 41 | -------------------------------------------------------------------------------- /lib/TaichiRender.py: -------------------------------------------------------------------------------- 1 | 2 | import taichi as ti 3 | from lib.utils import * 4 | ti.init(ti.cuda) 5 | 6 | 7 | @ti.data_oriented 8 | class TaichiRenderBatch: 9 | def __init__(self, bs, res): 10 | self.res = res 11 | self.coord = ti.Vector.field(n=1, dtype=ti.i32, shape=(bs, res * res)) 12 | 13 | @ti.kernel 14 | def render_respective_color(self, pts: ti.types.ndarray(), pts_mask: ti.types.ndarray(), 15 | render_depth: ti.types.ndarray(), render_color: ti.types.ndarray()): 16 | for B, i in self.coord: 17 | if pts_mask[B, i, 0] < 0.5: 18 | continue 19 | IX, IY = ti.cast(pts[B, i, 0], ti.i32), ti.cast(pts[B, i, 1], ti.i32) 20 | IX = ti.min(self.res - 1, ti.max(IX, 0)) 21 | IY = ti.min(self.res - 1, ti.max(IY, 0)) 22 | if pts[B, i, 2] >= ti.atomic_max(render_depth[B, 0, IY, IX], pts[B, i, 2]): 23 | for k in ti.static(range(3)): 24 | render_color[B, k, IY, IX] = pts[B, i, k + 3] 25 | 26 | def flow2render(self, data): 27 | novel_view_calib = torch.matmul(data['novel_view']['intr'], data['novel_view']['extr']) 28 | B = novel_view_calib.shape[0] 29 | 30 | taichi_pts_list = [] 31 | taichi_mask_list = [] 32 | 33 | for view in ['lmain', 'rmain']: 34 | data_select = data[view] 35 | depth_pred = flow2depth(data_select).clone() 36 | valid = depth_pred != 0 37 | 38 | pts = depth2pc(depth_pred, data_select['extr'], data_select['intr']) 39 | valid = valid.view(B, -1, 1).squeeze(2) 40 | pts_valid = torch.zeros_like(pts) 41 | pts_valid[valid] = pts[valid] 42 | 43 | pts_valid = perspective(pts_valid, novel_view_calib) 44 | pts_valid[:, :, 2:] = 1.0 / (pts_valid[:, :, 
2:] + 1e-8) 45 | 46 | img_valid = torch.zeros_like(pts_valid) 47 | img_valid[valid] = data_select['img'].permute(0, 2, 3, 1).view(B, -1, 3)[valid] 48 | taichi_pts = torch.cat((pts_valid, img_valid), dim=2) 49 | taichi_mask = valid.view(B, -1, 1).float() 50 | taichi_pts_list.append(taichi_pts) 51 | taichi_mask_list.append(taichi_mask) 52 | 53 | render_depth = torch.zeros((B, 1, self.res, self.res), device=pts.device).float() 54 | min_value = -1 55 | render_color = min_value + torch.zeros((B, 3, self.res, self.res), device=pts.device).float() 56 | for i in range(2): 57 | self.render_respective_color(taichi_pts_list[i], taichi_mask_list[i], render_depth, render_color) 58 | data['novel_view']['img_pred'] = render_color 59 | 60 | return data 61 | -------------------------------------------------------------------------------- /lib/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | 16 | 17 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 18 | Rt = np.zeros((4, 4)) 19 | Rt[:3, :3] = R.transpose() 20 | Rt[:3, 3] = t 21 | Rt[3, 3] = 1.0 22 | 23 | C2W = np.linalg.inv(Rt) 24 | cam_center = C2W[:3, 3] 25 | cam_center = (cam_center + translate) * scale 26 | C2W[:3, 3] = cam_center 27 | Rt = np.linalg.inv(C2W) 28 | return np.float32(Rt) 29 | 30 | 31 | def getProjectionMatrix(znear, zfar, K, h, w): 32 | near_fx = znear / K[0, 0] 33 | near_fy = znear / K[1, 1] 34 | left = - (w - K[0, 2]) * near_fx 35 | right = K[0, 2] * near_fx 36 | bottom = (K[1, 2] - h) * near_fy 37 | top = K[1, 2] * near_fy 38 | 39 | P = torch.zeros(4, 4) 40 | z_sign = 1.0 41 | P[0, 0] = 2.0 * znear / (right - left) 42 | P[1, 1] = 2.0 * znear / (top - bottom) 43 | P[0, 2] = (right + left) / (right - left) 44 | P[1, 2] = (top + bottom) / (top - bottom) 45 | P[3, 2] = z_sign 46 | P[2, 2] = z_sign * zfar / (zfar - znear) 47 | P[2, 3] = -(zfar * znear) / (zfar - znear) 48 | return P 49 | 50 | 51 | def focal2fov(focal, pixels): 52 | return 2*math.atan(pixels/(2*focal)) 53 | -------------------------------------------------------------------------------- /lib/gs_parm_network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from core.extractor import UnetExtractor, ResidualBlock 5 | 6 | 7 | class GSRegresser(nn.Module): 8 | def __init__(self, cfg, rgb_dim=3, depth_dim=1, norm_fn='group'): 9 | super().__init__() 10 | self.rgb_dims = cfg.raft.encoder_dims 11 | self.depth_dims = cfg.gsnet.encoder_dims 12 | self.decoder_dims = cfg.gsnet.decoder_dims 13 | self.head_dim = cfg.gsnet.parm_head_dim 14 | self.depth_encoder = UnetExtractor(in_channel=depth_dim, encoder_dim=self.depth_dims) 15 | 16 | self.decoder3 = nn.Sequential( 17 | ResidualBlock(self.rgb_dims[2]+self.depth_dims[2], self.decoder_dims[2], norm_fn=norm_fn), 18 | ResidualBlock(self.decoder_dims[2], self.decoder_dims[2], norm_fn=norm_fn) 19 | ) 20 | 21 | self.decoder2 = nn.Sequential( 22 | ResidualBlock(self.rgb_dims[1]+self.depth_dims[1]+self.decoder_dims[2], self.decoder_dims[1], norm_fn=norm_fn), 23 | ResidualBlock(self.decoder_dims[1], 
self.decoder_dims[1], norm_fn=norm_fn) 24 | ) 25 | 26 | self.decoder1 = nn.Sequential( 27 | ResidualBlock(self.rgb_dims[0]+self.depth_dims[0]+self.decoder_dims[1], self.decoder_dims[0], norm_fn=norm_fn), 28 | ResidualBlock(self.decoder_dims[0], self.decoder_dims[0], norm_fn=norm_fn) 29 | ) 30 | self.up = nn.Upsample(scale_factor=2, mode="bilinear") 31 | self.out_conv = nn.Conv2d(self.decoder_dims[0]+rgb_dim+depth_dim, self.head_dim, kernel_size=3, padding=1) 32 | self.out_relu = nn.ReLU(inplace=True) 33 | 34 | self.rot_head = nn.Sequential( 35 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(self.head_dim, 4, kernel_size=1), 38 | ) 39 | self.scale_head = nn.Sequential( 40 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(self.head_dim, 3, kernel_size=1), 43 | nn.Softplus(beta=100) 44 | ) 45 | self.opacity_head = nn.Sequential( 46 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(self.head_dim, 1, kernel_size=1), 49 | nn.Sigmoid() 50 | ) 51 | 52 | def forward(self, img, depth, img_feat): 53 | img_feat1, img_feat2, img_feat3 = img_feat 54 | depth_feat1, depth_feat2, depth_feat3 = self.depth_encoder(depth) 55 | 56 | feat3 = torch.concat([img_feat3, depth_feat3], dim=1) 57 | feat2 = torch.concat([img_feat2, depth_feat2], dim=1) 58 | feat1 = torch.concat([img_feat1, depth_feat1], dim=1) 59 | 60 | up3 = self.decoder3(feat3) 61 | up3 = self.up(up3) 62 | up2 = self.decoder2(torch.cat([up3, feat2], dim=1)) 63 | up2 = self.up(up2) 64 | up1 = self.decoder1(torch.cat([up2, feat1], dim=1)) 65 | 66 | up1 = self.up(up1) 67 | out = torch.cat([up1, img, depth], dim=1) 68 | out = self.out_conv(out) 69 | out = self.out_relu(out) 70 | 71 | # rot head 72 | rot_out = self.rot_head(out) 73 | rot_out = torch.nn.functional.normalize(rot_out, dim=1) 74 | 75 | # scale head 76 | scale_out = torch.clamp_max(self.scale_head(out), 0.01) 77 | 78 | # opacity head 79 | opacity_out = self.opacity_head(out) 80 | 81 | return rot_out, scale_out, opacity_out 82 | -------------------------------------------------------------------------------- /lib/human_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | import cv2 7 | import torch 8 | from lib.graphics_utils import getWorld2View2, getProjectionMatrix, focal2fov 9 | from pathlib import Path 10 | import logging 11 | import json 12 | from tqdm import tqdm 13 | 14 | 15 | def save_np_to_json(parm, save_name): 16 | for key in parm.keys(): 17 | parm[key] = parm[key].tolist() 18 | with open(save_name, 'w') as file: 19 | json.dump(parm, file, indent=1) 20 | 21 | 22 | def load_json_to_np(parm_name): 23 | with open(parm_name, 'r') as f: 24 | parm = json.load(f) 25 | for key in parm.keys(): 26 | parm[key] = np.array(parm[key]) 27 | return parm 28 | 29 | 30 | def depth2pts(depth, extrinsic, intrinsic): 31 | # depth H W extrinsic 3x4 intrinsic 3x3 pts map H W 3 32 | rot = extrinsic[:3, :3] 33 | trans = extrinsic[:3, 3:] 34 | S, S = depth.shape 35 | 36 | y, x = torch.meshgrid(torch.linspace(0.5, S-0.5, S, device=depth.device), 37 | torch.linspace(0.5, S-0.5, S, device=depth.device)) 38 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1) # H W 3 39 | 40 | pts_2d[..., 2] = 1.0 / (depth + 1e-8) 41 | pts_2d[..., 0] -= intrinsic[0, 2] 42 | pts_2d[..., 1] -= 
intrinsic[1, 2] 43 | pts_2d_xy = pts_2d[..., :2] * pts_2d[..., 2:] 44 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1) 45 | 46 | pts_2d[..., 0] /= intrinsic[0, 0] 47 | pts_2d[..., 1] /= intrinsic[1, 1] 48 | pts_2d = pts_2d.reshape(-1, 3).T 49 | pts = rot.T @ pts_2d - rot.T @ trans 50 | return pts.T.view(S, S, 3) 51 | 52 | 53 | def pts2depth(ptsmap, extrinsic, intrinsic): 54 | S, S, _ = ptsmap.shape 55 | pts = ptsmap.view(-1, 3).T 56 | calib = intrinsic @ extrinsic 57 | pts = calib[:3, :3] @ pts 58 | pts = pts + calib[:3, 3:4] 59 | pts[:2, :] /= (pts[2:, :] + 1e-8) 60 | depth = 1.0 / (pts[2, :].view(S, S) + 1e-8) 61 | return depth 62 | 63 | 64 | def stereo_pts2flow(pts0, pts1, rectify0, rectify1, Tf_x): 65 | new_extr0, new_intr0, rectify_mat0_x, rectify_mat0_y = rectify0 66 | new_extr1, new_intr1, rectify_mat1_x, rectify_mat1_y = rectify1 67 | new_depth0 = pts2depth(torch.FloatTensor(pts0), torch.FloatTensor(new_extr0), torch.FloatTensor(new_intr0)) 68 | new_depth1 = pts2depth(torch.FloatTensor(pts1), torch.FloatTensor(new_extr1), torch.FloatTensor(new_intr1)) 69 | new_depth0 = new_depth0.detach().numpy() 70 | new_depth1 = new_depth1.detach().numpy() 71 | new_depth0 = cv2.remap(new_depth0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR) 72 | new_depth1 = cv2.remap(new_depth1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR) 73 | 74 | offset0 = new_intr1[0, 2] - new_intr0[0, 2] 75 | disparity0 = -new_depth0 * Tf_x 76 | flow0 = offset0 - disparity0 77 | 78 | offset1 = new_intr0[0, 2] - new_intr1[0, 2] 79 | disparity1 = -new_depth1 * (-Tf_x) 80 | flow1 = offset1 - disparity1 81 | 82 | flow0[new_depth0 < 0.05] = 0 83 | flow1[new_depth1 < 0.05] = 0 84 | 85 | return flow0, flow1 86 | 87 | 88 | def read_img(name): 89 | img = np.array(Image.open(name)) 90 | return img 91 | 92 | 93 | def read_depth(name): 94 | return cv2.imread(name, cv2.IMREAD_UNCHANGED).astype(np.float32) / 2.0 ** 15 95 | 96 | 97 | class StereoHumanDataset(Dataset): 98 | def __init__(self, opt, phase='train'): 99 | self.opt = opt 100 | self.use_processed_data = opt.use_processed_data 101 | self.phase = phase 102 | if self.phase == 'train': 103 | self.data_root = os.path.join(opt.data_root, 'train') 104 | elif self.phase == 'val': 105 | self.data_root = os.path.join(opt.data_root, 'val') 106 | elif self.phase == 'test': 107 | self.data_root = opt.test_data_root 108 | 109 | self.img_path = os.path.join(self.data_root, 'img/%s/%d.jpg') 110 | self.img_hr_path = os.path.join(self.data_root, 'img/%s/%d_hr.jpg') 111 | self.mask_path = os.path.join(self.data_root, 'mask/%s/%d.png') 112 | self.depth_path = os.path.join(self.data_root, 'depth/%s/%d.png') 113 | self.intr_path = os.path.join(self.data_root, 'parm/%s/%d_intrinsic.npy') 114 | self.extr_path = os.path.join(self.data_root, 'parm/%s/%d_extrinsic.npy') 115 | self.sample_list = sorted(list(os.listdir(os.path.join(self.data_root, 'img')))) 116 | 117 | if self.use_processed_data: 118 | self.local_data_root = os.path.join(opt.data_root, 'rectified_local', self.phase) 119 | self.local_img_path = os.path.join(self.local_data_root, 'img/%s/%d.jpg') 120 | self.local_mask_path = os.path.join(self.local_data_root, 'mask/%s/%d.png') 121 | self.local_flow_path = os.path.join(self.local_data_root, 'flow/%s/%d.npy') 122 | self.local_valid_path = os.path.join(self.local_data_root, 'valid/%s/%d.png') 123 | self.local_parm_path = os.path.join(self.local_data_root, 'parm/%s/%d_%d.json') 124 | 125 | if os.path.exists(self.local_data_root): 126 | assert 
len(os.listdir(os.path.join(self.local_data_root, 'img'))) == len(self.sample_list) 127 | logging.info(f"Using local data in {self.local_data_root} ...") 128 | else: 129 | self.save_local_stereo_data() 130 | 131 | def save_local_stereo_data(self): 132 | logging.info(f"Generating data to {self.local_data_root} ...") 133 | for sample_name in tqdm(self.sample_list): 134 | view0_data = self.load_single_view(sample_name, self.opt.source_id[0], hr_img=False, 135 | require_mask=True, require_pts=True) 136 | view1_data = self.load_single_view(sample_name, self.opt.source_id[1], hr_img=False, 137 | require_mask=True, require_pts=True) 138 | lmain_stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data) 139 | 140 | for sub_dir in ['/img/', '/mask/', '/flow/', '/valid/', '/parm/']: 141 | Path(self.local_data_root + sub_dir + str(sample_name)).mkdir(exist_ok=True, parents=True) 142 | 143 | img0_save_name = self.local_img_path % (sample_name, self.opt.source_id[0]) 144 | mask0_save_name = self.local_mask_path % (sample_name, self.opt.source_id[0]) 145 | img1_save_name = self.local_img_path % (sample_name, self.opt.source_id[1]) 146 | mask1_save_name = self.local_mask_path % (sample_name, self.opt.source_id[1]) 147 | flow0_save_name = self.local_flow_path % (sample_name, self.opt.source_id[0]) 148 | valid0_save_name = self.local_valid_path % (sample_name, self.opt.source_id[0]) 149 | flow1_save_name = self.local_flow_path % (sample_name, self.opt.source_id[1]) 150 | valid1_save_name = self.local_valid_path % (sample_name, self.opt.source_id[1]) 151 | parm_save_name = self.local_parm_path % (sample_name, self.opt.source_id[0], self.opt.source_id[1]) 152 | 153 | Image.fromarray(lmain_stereo_np['img0']).save(img0_save_name, quality=95) 154 | Image.fromarray(lmain_stereo_np['mask0']).save(mask0_save_name) 155 | Image.fromarray(lmain_stereo_np['img1']).save(img1_save_name, quality=95) 156 | Image.fromarray(lmain_stereo_np['mask1']).save(mask1_save_name) 157 | np.save(flow0_save_name, lmain_stereo_np['flow0'].astype(np.float16)) 158 | Image.fromarray(lmain_stereo_np['valid0']).save(valid0_save_name) 159 | np.save(flow1_save_name, lmain_stereo_np['flow1'].astype(np.float16)) 160 | Image.fromarray(lmain_stereo_np['valid1']).save(valid1_save_name) 161 | save_np_to_json(lmain_stereo_np['camera'], parm_save_name) 162 | 163 | logging.info("Generating data Done!") 164 | 165 | def load_local_stereo_data(self, sample_name): 166 | img0_name = self.local_img_path % (sample_name, self.opt.source_id[0]) 167 | mask0_name = self.local_mask_path % (sample_name, self.opt.source_id[0]) 168 | img1_name = self.local_img_path % (sample_name, self.opt.source_id[1]) 169 | mask1_name = self.local_mask_path % (sample_name, self.opt.source_id[1]) 170 | flow0_name = self.local_flow_path % (sample_name, self.opt.source_id[0]) 171 | flow1_name = self.local_flow_path % (sample_name, self.opt.source_id[1]) 172 | valid0_name = self.local_valid_path % (sample_name, self.opt.source_id[0]) 173 | valid1_name = self.local_valid_path % (sample_name, self.opt.source_id[1]) 174 | parm_name = self.local_parm_path % (sample_name, self.opt.source_id[0], self.opt.source_id[1]) 175 | 176 | stereo_data = { 177 | 'img0': read_img(img0_name), 178 | 'mask0': read_img(mask0_name), 179 | 'img1': read_img(img1_name), 180 | 'mask1': read_img(mask1_name), 181 | 'camera': load_json_to_np(parm_name), 182 | 'flow0': np.load(flow0_name), 183 | 'valid0': read_img(valid0_name), 184 | 'flow1': np.load(flow1_name), 185 | 'valid1': 
read_img(valid1_name) 186 | } 187 | 188 | return stereo_data 189 | 190 | def load_single_view(self, sample_name, source_id, hr_img=False, require_mask=True, require_pts=True): 191 | img_name = self.img_path % (sample_name, source_id) 192 | image_hr_name = self.img_hr_path % (sample_name, source_id) 193 | mask_name = self.mask_path % (sample_name, source_id) 194 | depth_name = self.depth_path % (sample_name, source_id) 195 | intr_name = self.intr_path % (sample_name, source_id) 196 | extr_name = self.extr_path % (sample_name, source_id) 197 | 198 | intr, extr = np.load(intr_name), np.load(extr_name) 199 | mask, pts = None, None 200 | if hr_img: 201 | img = read_img(image_hr_name) 202 | intr[:2] *= 2 203 | else: 204 | img = read_img(img_name) 205 | if require_mask: 206 | mask = read_img(mask_name) 207 | if require_pts and os.path.exists(depth_name): 208 | depth = read_depth(depth_name) 209 | pts = depth2pts(torch.FloatTensor(depth), torch.FloatTensor(extr), torch.FloatTensor(intr)) 210 | 211 | return img, mask, intr, extr, pts 212 | 213 | def get_novel_view_tensor(self, sample_name, view_id): 214 | img, _, intr, extr, _ = self.load_single_view(sample_name, view_id, hr_img=self.opt.use_hr_img, 215 | require_mask=False, require_pts=False) 216 | height, width = img.shape[:2] 217 | img = torch.from_numpy(img).permute(2, 0, 1) 218 | img = img / 255.0 219 | 220 | R = np.array(extr[:3, :3], np.float32).reshape(3, 3).transpose(1, 0) 221 | T = np.array(extr[:3, 3], np.float32) 222 | 223 | FovX = focal2fov(intr[0, 0], width) 224 | FovY = focal2fov(intr[1, 1], height) 225 | projection_matrix = getProjectionMatrix(znear=self.opt.znear, zfar=self.opt.zfar, K=intr, h=height, w=width).transpose(0, 1) 226 | world_view_transform = torch.tensor(getWorld2View2(R, T, np.array(self.opt.trans), self.opt.scale)).transpose(0, 1) 227 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0) 228 | camera_center = world_view_transform.inverse()[3, :3] 229 | 230 | novel_view_data = { 231 | 'view_id': torch.IntTensor([view_id]), 232 | 'img': img, 233 | 'extr': torch.FloatTensor(extr), 234 | 'FovX': FovX, 235 | 'FovY': FovY, 236 | 'width': width, 237 | 'height': height, 238 | 'world_view_transform': world_view_transform, 239 | 'full_proj_transform': full_proj_transform, 240 | 'camera_center': camera_center 241 | } 242 | 243 | return novel_view_data 244 | 245 | def get_rectified_stereo_data(self, main_view_data, ref_view_data): 246 | img0, mask0, intr0, extr0, pts0 = main_view_data 247 | img1, mask1, intr1, extr1, pts1 = ref_view_data 248 | 249 | H, W = self.opt.src_res, self.opt.src_res 250 | r0, t0 = extr0[:3, :3], extr0[:3, 3:] 251 | r1, t1 = extr1[:3, :3], extr1[:3, 3:] 252 | inv_r0 = r0.T 253 | inv_t0 = - r0.T @ t0 254 | E0 = np.eye(4) 255 | E0[:3, :3], E0[:3, 3:] = inv_r0, inv_t0 256 | E1 = np.eye(4) 257 | E1[:3, :3], E1[:3, 3:] = r1, t1 258 | E = E1 @ E0 259 | R, T = E[:3, :3], E[:3, 3] 260 | dist0, dist1 = np.zeros(4), np.zeros(4) 261 | 262 | R0, R1, P0, P1, _, _, _ = cv2.stereoRectify(intr0, dist0, intr1, dist1, (W, H), R, T, flags=0) 263 | 264 | new_extr0 = R0 @ extr0 265 | new_intr0 = P0[:3, :3] 266 | new_extr1 = R1 @ extr1 267 | new_intr1 = P1[:3, :3] 268 | Tf_x = np.array(P1[0, 3]) 269 | 270 | camera = { 271 | 'intr0': new_intr0, 272 | 'intr1': new_intr1, 273 | 'extr0': new_extr0, 274 | 'extr1': new_extr1, 275 | 'Tf_x': Tf_x 276 | } 277 | 278 | rectify_mat0_x, rectify_mat0_y = cv2.initUndistortRectifyMap(intr0, dist0, R0, P0, (W, H), cv2.CV_32FC1) 279 | 
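# cv2.initUndistortRectifyMap builds, for every pixel of the rectified image, the sampling
# coordinates in the original view; cv2.remap below then warps the source image/mask into
# the rectified stereo pair defined by R0/P0 and R1/P1.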
new_img0 = cv2.remap(img0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR) 280 | new_mask0 = cv2.remap(mask0, rectify_mat0_x, rectify_mat0_y, cv2.INTER_LINEAR) 281 | rectify_mat1_x, rectify_mat1_y = cv2.initUndistortRectifyMap(intr1, dist1, R1, P1, (W, H), cv2.CV_32FC1) 282 | new_img1 = cv2.remap(img1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR) 283 | new_mask1 = cv2.remap(mask1, rectify_mat1_x, rectify_mat1_y, cv2.INTER_LINEAR) 284 | rectify0 = new_extr0, new_intr0, rectify_mat0_x, rectify_mat0_y 285 | rectify1 = new_extr1, new_intr1, rectify_mat1_x, rectify_mat1_y 286 | 287 | stereo_data = { 288 | 'img0': new_img0, 289 | 'mask0': new_mask0, 290 | 'img1': new_img1, 291 | 'mask1': new_mask1, 292 | 'camera': camera 293 | } 294 | 295 | if pts0 is not None: 296 | flow0, flow1 = stereo_pts2flow(pts0, pts1, rectify0, rectify1, Tf_x) 297 | 298 | kernel = np.ones((3, 3), dtype=np.uint8) 299 | flow_eroded, valid_eroded = [], [] 300 | for (flow, new_mask) in [(flow0, new_mask0), (flow1, new_mask1)]: 301 | valid = (new_mask.copy()[:, :, 0] / 255.0).astype(np.float32) 302 | valid = cv2.erode(valid, kernel, 1) 303 | valid[valid >= 0.66] = 1.0 304 | valid[valid < 0.66] = 0.0 305 | flow *= valid 306 | valid *= 255.0 307 | flow_eroded.append(flow) 308 | valid_eroded.append(valid) 309 | 310 | stereo_data.update({ 311 | 'flow0': flow_eroded[0], 312 | 'valid0': valid_eroded[0].astype(np.uint8), 313 | 'flow1': flow_eroded[1], 314 | 'valid1': valid_eroded[1].astype(np.uint8) 315 | }) 316 | 317 | return stereo_data 318 | 319 | def stereo_to_dict_tensor(self, stereo_data, subject_name): 320 | img_tensor, mask_tensor = [], [] 321 | for (img_view, mask_view) in [('img0', 'mask0'), ('img1', 'mask1')]: 322 | img = torch.from_numpy(stereo_data[img_view]).permute(2, 0, 1) 323 | img = 2 * (img / 255.0) - 1.0 324 | mask = torch.from_numpy(stereo_data[mask_view]).permute(2, 0, 1).float() 325 | mask = mask / 255.0 326 | 327 | img = img * mask 328 | mask[mask < 0.5] = 0.0 329 | mask[mask >= 0.5] = 1.0 330 | img_tensor.append(img) 331 | mask_tensor.append(mask) 332 | 333 | lmain_data = { 334 | 'img': img_tensor[0], 335 | 'mask': mask_tensor[0], 336 | 'intr': torch.FloatTensor(stereo_data['camera']['intr0']), 337 | 'ref_intr': torch.FloatTensor(stereo_data['camera']['intr1']), 338 | 'extr': torch.FloatTensor(stereo_data['camera']['extr0']), 339 | 'Tf_x': torch.FloatTensor(stereo_data['camera']['Tf_x']) 340 | } 341 | 342 | rmain_data = { 343 | 'img': img_tensor[1], 344 | 'mask': mask_tensor[1], 345 | 'intr': torch.FloatTensor(stereo_data['camera']['intr1']), 346 | 'ref_intr': torch.FloatTensor(stereo_data['camera']['intr0']), 347 | 'extr': torch.FloatTensor(stereo_data['camera']['extr1']), 348 | 'Tf_x': -torch.FloatTensor(stereo_data['camera']['Tf_x']) 349 | } 350 | 351 | if 'flow0' in stereo_data: 352 | flow_tensor, valid_tensor = [], [] 353 | for (flow_view, valid_view) in [('flow0', 'valid0'), ('flow1', 'valid1')]: 354 | flow = torch.from_numpy(stereo_data[flow_view]) 355 | flow = torch.unsqueeze(flow, dim=0) 356 | flow_tensor.append(flow) 357 | 358 | valid = torch.from_numpy(stereo_data[valid_view]) 359 | valid = torch.unsqueeze(valid, dim=0) 360 | valid = valid / 255.0 361 | valid_tensor.append(valid) 362 | 363 | lmain_data['flow'], lmain_data['valid'] = flow_tensor[0], valid_tensor[0] 364 | rmain_data['flow'], rmain_data['valid'] = flow_tensor[1], valid_tensor[1] 365 | 366 | return {'name': subject_name, 'lmain': lmain_data, 'rmain': rmain_data} 367 | 368 | def get_item(self, index, novel_id=None): 369 | 
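# Returns a dict with 'name', 'lmain' and 'rmain' entries (rectified image, mask,
# intr/ref_intr/extr, Tf_x, plus 'flow'/'valid' ground truth when depth is available)
# and, when novel_id is given, a 'novel_view' entry built by get_novel_view_tensor
# for one randomly chosen target camera.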
sample_id = index % len(self.sample_list) 370 | sample_name = self.sample_list[sample_id] 371 | 372 | if self.use_processed_data: 373 | stereo_np = self.load_local_stereo_data(sample_name) 374 | else: 375 | view0_data = self.load_single_view(sample_name, self.opt.source_id[0], hr_img=False, 376 | require_mask=True, require_pts=True) 377 | view1_data = self.load_single_view(sample_name, self.opt.source_id[1], hr_img=False, 378 | require_mask=True, require_pts=True) 379 | stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data) 380 | dict_tensor = self.stereo_to_dict_tensor(stereo_np, sample_name) 381 | 382 | if novel_id: 383 | novel_id = np.random.choice(novel_id) 384 | dict_tensor.update({ 385 | 'novel_view': self.get_novel_view_tensor(sample_name, novel_id) 386 | }) 387 | 388 | return dict_tensor 389 | 390 | def get_test_item(self, index, source_id): 391 | sample_id = index % len(self.sample_list) 392 | sample_name = self.sample_list[sample_id] 393 | 394 | if self.use_processed_data: 395 | logging.error('test data loader not support processed data') 396 | 397 | view0_data = self.load_single_view(sample_name, source_id[0], hr_img=False, require_mask=True, require_pts=False) 398 | view1_data = self.load_single_view(sample_name, source_id[1], hr_img=False, require_mask=True, require_pts=False) 399 | lmain_intr_ori, lmain_extr_ori = view0_data[2], view0_data[3] 400 | rmain_intr_ori, rmain_extr_ori = view1_data[2], view1_data[3] 401 | stereo_np = self.get_rectified_stereo_data(main_view_data=view0_data, ref_view_data=view1_data) 402 | dict_tensor = self.stereo_to_dict_tensor(stereo_np, sample_name) 403 | 404 | dict_tensor['lmain']['intr_ori'] = torch.FloatTensor(lmain_intr_ori) 405 | dict_tensor['rmain']['intr_ori'] = torch.FloatTensor(rmain_intr_ori) 406 | dict_tensor['lmain']['extr_ori'] = torch.FloatTensor(lmain_extr_ori) 407 | dict_tensor['rmain']['extr_ori'] = torch.FloatTensor(rmain_extr_ori) 408 | 409 | img_len = self.opt.src_res * 2 if self.opt.use_hr_img else self.opt.src_res 410 | novel_dict = { 411 | 'height': torch.IntTensor([img_len]), 412 | 'width': torch.IntTensor([img_len]) 413 | } 414 | 415 | dict_tensor.update({ 416 | 'novel_view': novel_dict 417 | }) 418 | 419 | return dict_tensor 420 | 421 | def __getitem__(self, index): 422 | if self.phase == 'train': 423 | return self.get_item(index, novel_id=self.opt.train_novel_id) 424 | elif self.phase == 'val': 425 | return self.get_item(index, novel_id=self.opt.val_novel_id) 426 | 427 | def __len__(self): 428 | self.train_boost = 50 429 | self.val_boost = 200 430 | if self.phase == 'train': 431 | return len(self.sample_list) * self.train_boost 432 | elif self.phase == 'val': 433 | return len(self.sample_list) * self.val_boost 434 | else: 435 | return len(self.sample_list) 436 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from math import exp 6 | 7 | 8 | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9): 9 | """ Loss function defined over sequence of flow predictions """ 10 | 11 | n_predictions = len(flow_preds) 12 | flow_loss = 0.0 13 | 14 | valid = (valid >= 0.5) 15 | assert not torch.isinf(flow_gt[valid.bool()]).any() 16 | 17 | for i in range(n_predictions): 18 | # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations 19 
| adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1)) 20 | i_weight = adjusted_loss_gamma**(n_predictions - i - 1) 21 | i_loss = (flow_preds[i] - flow_gt).abs() 22 | flow_loss += i_weight * i_loss[valid.bool()].mean() 23 | 24 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 25 | epe = epe.view(-1)[valid.view(-1)] 26 | 27 | metrics = { 28 | 'train_epe': epe.mean().item(), 29 | 'train_1px': (epe < 1).float().mean().item(), 30 | 'train_3px': (epe < 3).float().mean().item() 31 | } 32 | 33 | return flow_loss, metrics 34 | 35 | 36 | def l1_loss(network_output, gt): 37 | return torch.abs((network_output - gt)).mean() 38 | 39 | 40 | def gaussian(window_size, sigma): 41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 42 | return gauss / gauss.sum() 43 | 44 | 45 | def create_window(window_size, channel): 46 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 47 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 48 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 49 | return window 50 | 51 | 52 | def ssim(img1, img2, window_size=11, size_average=True): 53 | channel = img1.size(-3) 54 | window = create_window(window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | return _ssim(img1, img2, window, window_size, channel, size_average) 61 | 62 | 63 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 64 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 65 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 66 | 67 | mu1_sq = mu1.pow(2) 68 | mu2_sq = mu2.pow(2) 69 | mu1_mu2 = mu1 * mu2 70 | 71 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 72 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 73 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 74 | 75 | C1 = 0.01 ** 2 76 | C2 = 0.03 ** 2 77 | 78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 79 | 80 | if size_average: 81 | return ssim_map.mean() 82 | else: 83 | return ssim_map.mean(1).mean(1).mean(1) 84 | 85 | 86 | def psnr(img1, img2): 87 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 88 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 89 | -------------------------------------------------------------------------------- /lib/network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from core.raft_stereo_human import RAFTStereoHuman 5 | from core.extractor import UnetExtractor 6 | from lib.gs_parm_network import GSRegresser 7 | from lib.loss import sequence_loss 8 | from lib.utils import flow2depth, depth2pc 9 | from torch.cuda.amp import autocast as autocast 10 | 11 | 12 | class RtStereoHumanModel(nn.Module): 13 | def __init__(self, cfg, with_gs_render=False): 14 | super().__init__() 15 | self.cfg = cfg 16 | self.with_gs_render = with_gs_render 17 | self.train_iters = self.cfg.raft.train_iters 18 | self.val_iters = self.cfg.raft.val_iters 19 | 20 | self.img_encoder = UnetExtractor(in_channel=3, encoder_dim=self.cfg.raft.encoder_dims) 21 | self.raft_stereo = RAFTStereoHuman(self.cfg.raft) 22 | if self.with_gs_render: 23 | self.gs_parm_regresser 
= GSRegresser(self.cfg, rgb_dim=3, depth_dim=1) 24 | 25 | def forward(self, data, is_train=True): 26 | bs = data['lmain']['img'].shape[0] 27 | 28 | image = torch.cat([data['lmain']['img'], data['rmain']['img']], dim=0) 29 | flow = torch.cat([data['lmain']['flow'], data['rmain']['flow']], dim=0) if is_train else None 30 | valid = torch.cat([data['lmain']['valid'], data['rmain']['valid']], dim=0) if is_train else None 31 | 32 | with autocast(enabled=self.cfg.raft.mixed_precision): 33 | img_feat = self.img_encoder(image) 34 | 35 | if is_train: 36 | flow_predictions = self.raft_stereo(img_feat[2], iters=self.train_iters) 37 | flow_loss, metrics = sequence_loss(flow_predictions, flow, valid) 38 | flow_pred_lmain, flow_pred_rmain = torch.split(flow_predictions[-1], [bs, bs]) 39 | 40 | if not self.with_gs_render: 41 | data['lmain']['flow_pred'] = flow_pred_lmain.detach() 42 | data['rmain']['flow_pred'] = flow_pred_rmain.detach() 43 | return data, flow_loss, metrics 44 | 45 | data['lmain']['flow_pred'] = flow_pred_lmain 46 | data['rmain']['flow_pred'] = flow_pred_rmain 47 | data = self.flow2gsparms(image, img_feat, data, bs) 48 | 49 | return data, flow_loss, metrics 50 | 51 | else: 52 | flow_up = self.raft_stereo(img_feat[2], iters=self.val_iters, test_mode=True) 53 | flow_loss, metrics = None, None 54 | 55 | data['lmain']['flow_pred'] = flow_up[0] 56 | data['rmain']['flow_pred'] = flow_up[1] 57 | 58 | if not self.with_gs_render: 59 | return data, flow_loss, metrics 60 | data = self.flow2gsparms(image, img_feat, data, bs) 61 | 62 | return data, flow_loss, metrics 63 | 64 | def flow2gsparms(self, lr_img, lr_img_feat, data, bs): 65 | for view in ['lmain', 'rmain']: 66 | data[view]['depth'] = flow2depth(data[view]) 67 | data[view]['xyz'] = depth2pc(data[view]['depth'], data[view]['extr'], data[view]['intr']).view(bs, -1, 3) 68 | valid = data[view]['depth'] != 0.0 69 | data[view]['pts_valid'] = valid.view(bs, -1) 70 | 71 | # regress gaussian parms 72 | lr_depth = torch.concat([data['lmain']['depth'], data['rmain']['depth']], dim=0) 73 | rot_maps, scale_maps, opacity_maps = self.gs_parm_regresser(lr_img, lr_depth, lr_img_feat) 74 | 75 | data['lmain']['rot_maps'], data['rmain']['rot_maps'] = torch.split(rot_maps, [bs, bs]) 76 | data['lmain']['scale_maps'], data['rmain']['scale_maps'] = torch.split(scale_maps, [bs, bs]) 77 | data['lmain']['opacity_maps'], data['rmain']['opacity_maps'] = torch.split(opacity_maps, [bs, bs]) 78 | 79 | return data 80 | 81 | -------------------------------------------------------------------------------- /lib/train_recoder.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import shutil 5 | import logging 6 | from pathlib import Path 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | def file_backup(exp_path, cfg, train_script): 11 | shutil.copy(train_script, exp_path) 12 | shutil.copytree('core', os.path.join(exp_path, 'core'), dirs_exist_ok=True) 13 | shutil.copytree('config', os.path.join(exp_path, 'config'), dirs_exist_ok=True) 14 | shutil.copytree('gaussian_renderer', os.path.join(exp_path, 'gaussian_renderer'), dirs_exist_ok=True) 15 | for sub_dir in ['lib']: 16 | files = os.listdir(sub_dir) 17 | for file in files: 18 | Path(os.path.join(exp_path, sub_dir)).mkdir(exist_ok=True, parents=True) 19 | if file[-3:] == '.py': 20 | shutil.copy(os.path.join(sub_dir, file), os.path.join(exp_path, sub_dir)) 21 | 22 | json_file_name = exp_path + '/cfg.json' 23 | with open(json_file_name, 'w') as 
json_file: 24 | json.dump(cfg, json_file, indent=2) 25 | 26 | 27 | class Logger: 28 | def __init__(self, scheduler, cfg): 29 | self.scheduler = scheduler 30 | self.sum_freq = cfg.loss_freq 31 | self.log_dir = cfg.logs_path 32 | self.total_steps = 0 33 | self.running_loss = {} 34 | self.writer = SummaryWriter(log_dir=self.log_dir) 35 | 36 | def _print_training_status(self): 37 | metrics_data = [self.running_loss[k] / self.sum_freq for k in sorted(self.running_loss.keys())] 38 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps, self.scheduler.get_last_lr()[0]) 39 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) 40 | 41 | # print the training status 42 | logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}") 43 | 44 | if self.writer is None: 45 | self.writer = SummaryWriter(log_dir=self.log_dir) 46 | 47 | for k in self.running_loss: 48 | self.writer.add_scalar(k, self.running_loss[k] / self.sum_freq, self.total_steps) 49 | self.running_loss[k] = 0.0 50 | 51 | def push(self, metrics): 52 | for key in metrics: 53 | if key not in self.running_loss: 54 | self.running_loss[key] = 0.0 55 | 56 | self.running_loss[key] += metrics[key] 57 | 58 | if self.total_steps and self.total_steps % self.sum_freq == 0: 59 | self._print_training_status() 60 | self.running_loss = {} 61 | 62 | self.total_steps += 1 63 | 64 | def write_dict(self, results, write_step): 65 | if self.writer is None: 66 | self.writer = SummaryWriter(log_dir=self.log_dir) 67 | 68 | for key in results: 69 | self.writer.add_scalar(key, results[key], write_step) 70 | 71 | def close(self): 72 | self.writer.close() 73 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as Rot 5 | from scipy.spatial.transform import Slerp 6 | from lib.graphics_utils import getWorld2View2, getProjectionMatrix, focal2fov 7 | 8 | 9 | def get_novel_calib(data, opt, ratio=0.5, intr_key='intr', extr_key='extr'): 10 | bs = data['lmain'][intr_key].shape[0] 11 | fovx_list, fovy_list, world_view_transform_list, full_proj_transform_list, camera_center_list = [], [], [], [], [] 12 | for i in range(bs): 13 | intr0 = data['lmain'][intr_key][i, ...].cpu().numpy() 14 | intr1 = data['rmain'][intr_key][i, ...].cpu().numpy() 15 | extr0 = data['lmain'][extr_key][i, ...].cpu().numpy() 16 | extr1 = data['rmain'][extr_key][i, ...].cpu().numpy() 17 | 18 | rot0 = extr0[:3, :3] 19 | rot1 = extr1[:3, :3] 20 | rots = Rot.from_matrix(np.stack([rot0, rot1])) 21 | key_times = [0, 1] 22 | slerp = Slerp(key_times, rots) 23 | rot = slerp(ratio) 24 | npose = np.diag([1.0, 1.0, 1.0, 1.0]) 25 | npose = npose.astype(np.float32) 26 | npose[:3, :3] = rot.as_matrix() 27 | npose[:3, 3] = ((1.0 - ratio) * extr0 + ratio * extr1)[:3, 3] 28 | extr_new = npose[:3, :] 29 | intr_new = ((1.0 - ratio) * intr0 + ratio * intr1) 30 | 31 | if opt.use_hr_img: 32 | intr_new[:2] *= 2 33 | width, height = data['novel_view']['width'][i], data['novel_view']['height'][i] 34 | R = np.array(extr_new[:3, :3], np.float32).reshape(3, 3).transpose(1, 0) 35 | T = np.array(extr_new[:3, 3], np.float32) 36 | 37 | FovX = focal2fov(intr_new[0, 0], width) 38 | FovY = focal2fov(intr_new[1, 1], height) 39 | projection_matrix = getProjectionMatrix(znear=opt.znear, zfar=opt.zfar, K=intr_new, h=height, w=width).transpose(0, 1) 40 | world_view_transform = 
torch.tensor(getWorld2View2(R, T, np.array(opt.trans), opt.scale)).transpose(0, 1) 41 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0) 42 | camera_center = world_view_transform.inverse()[3, :3] 43 | 44 | fovx_list.append(FovX) 45 | fovy_list.append(FovY) 46 | world_view_transform_list.append(world_view_transform.unsqueeze(0)) 47 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0)) 48 | camera_center_list.append(camera_center.unsqueeze(0)) 49 | 50 | data['novel_view']['FovX'] = torch.FloatTensor(np.array(fovx_list)).cuda() 51 | data['novel_view']['FovY'] = torch.FloatTensor(np.array(fovy_list)).cuda() 52 | data['novel_view']['world_view_transform'] = torch.concat(world_view_transform_list).cuda() 53 | data['novel_view']['full_proj_transform'] = torch.concat(full_proj_transform_list).cuda() 54 | data['novel_view']['camera_center'] = torch.concat(camera_center_list).cuda() 55 | return data 56 | 57 | 58 | def get_novel_calib_for_show(data, ratio=0.5, intr_key='intr', extr_key='extr'): 59 | bs = data['lmain'][intr_key].shape[0] 60 | intr_list, extr_list = [], [] 61 | data['novel_view'] = {} 62 | for i in range(bs): 63 | intr0 = data['lmain'][intr_key][i, ...].cpu().numpy() 64 | intr1 = data['rmain'][intr_key][i, ...].cpu().numpy() 65 | extr0 = data['lmain'][extr_key][i, ...].cpu().numpy() 66 | extr1 = data['rmain'][extr_key][i, ...].cpu().numpy() 67 | 68 | rot0 = extr0[:3, :3] 69 | rot1 = extr1[:3, :3] 70 | rots = Rot.from_matrix(np.stack([rot0, rot1])) 71 | key_times = [0, 1] 72 | slerp = Slerp(key_times, rots) 73 | rot = slerp(ratio) 74 | npose = np.diag([1.0, 1.0, 1.0, 1.0]) 75 | npose = npose.astype(np.float32) 76 | npose[:3, :3] = rot.as_matrix() 77 | npose[:3, 3] = ((1.0 - ratio) * extr0 + ratio * extr1)[:3, 3] 78 | extr_new = npose[:3, :] 79 | 80 | intr_new = ((1.0 - ratio) * intr0 + ratio * intr1) 81 | intr_list.append(intr_new) 82 | extr_list.append(extr_new) 83 | data['novel_view']['intr'] = torch.FloatTensor(np.array(intr_list)).cuda() 84 | data['novel_view']['extr'] = torch.FloatTensor(np.array(extr_list)).cuda() 85 | return data 86 | 87 | 88 | def depth2pc(depth, extrinsic, intrinsic): 89 | B, C, S, S = depth.shape 90 | depth = depth[:, 0, :, :] 91 | rot = extrinsic[:, :3, :3] 92 | trans = extrinsic[:, :3, 3:] 93 | 94 | y, x = torch.meshgrid(torch.linspace(0.5, S-0.5, S, device=depth.device), torch.linspace(0.5, S-0.5, S, device=depth.device)) 95 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # B S S 3 96 | 97 | pts_2d[..., 2] = 1.0 / (depth + 1e-8) 98 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2] 99 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2] 100 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:] 101 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1) 102 | 103 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None] 104 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None] 105 | 106 | pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1) 107 | rot_t = rot.permute(0, 2, 1) 108 | pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans) 109 | 110 | return pts.permute(0, 2, 1) 111 | 112 | 113 | def flow2depth(data): 114 | offset = data['ref_intr'][:, 0, 2] - data['intr'][:, 0, 2] 115 | offset = torch.broadcast_to(offset[:, None, None, None], data['flow_pred'].shape) 116 | disparity = offset - data['flow_pred'] 117 | depth = -disparity / data['Tf_x'][:, None, None, None] 118 | depth *= data['mask'][:, :1, :, :] 119 | 120 | return depth 121 | 122 | 
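The flow-to-depth conversion above is the inverse of the ground-truth flow construction in `stereo_pts2flow` (lib/human_loader.py): `flow = offset - disparity` with `disparity = -depth * Tf_x`, so `depth = -(offset - flow) / Tf_x`. A minimal round-trip sketch, with made-up scalar values (placeholders, not taken from the repo), illustrating that the two functions invert each other:
```
import numpy as np

# Hypothetical per-pixel values, only to check the algebra of the two functions.
depth_gt = np.array([2.5, 3.0, 4.2])   # "depth" as stored by pts2depth / flow2depth
Tf_x = -180.0                          # P1[0, 3] from cv2.stereoRectify (placeholder)
offset = 12.0                          # cx_ref - cx_main after rectification

disparity = -depth_gt * Tf_x           # as in stereo_pts2flow
flow = offset - disparity              # ground-truth flow saved for training

depth_rec = -(offset - flow) / Tf_x    # as in flow2depth
assert np.allclose(depth_rec, depth_gt)
```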
| def perspective(pts, calibs): 123 | pts = pts.permute(0, 2, 1) 124 | pts = torch.bmm(calibs[:, :3, :3], pts) 125 | pts = pts + calibs[:, :3, 3:4] 126 | pts[:, :2, :] /= pts[:, 2:, :] 127 | pts = pts.permute(0, 2, 1) 128 | return pts 129 | -------------------------------------------------------------------------------- /prepare_data/MAKE_DATA.md: -------------------------------------------------------------------------------- 1 | # Data Documentation 2 | 3 | We provide a script for rendering training data from human scans; many thanks to [Ruizhi Shao](https://dsaurus.github.io/saurus/) for sharing this code. Take [THuman2.0](https://github.com/ytrock/THuman2.0-Dataset) as an example. 5 | - Download the THuman2.0 scan data from [This Link](https://github.com/ytrock/THuman2.0-Dataset) and the SMPL-X fitting parameters from [This Link](https://drive.google.com/file/d/1rnkGomScq3yxyM9auA-oHW6m_OJ5mlGL/view?usp=sharing). Then split the THuman2.0 scans into a train set and a validation set. We use the SMPL-X parameters to normalize the orientation of the human scans, but this step is not essential. Comment out L133-L140 in [render_data.py](render_data.py#L133-L140) if you do not need to normalize the orientation of the human scans in THuman2.0. 6 | ``` 7 | ./Thuman2.0 8 | ├── THuman2.0_Smpl_X_Paras/ 9 | ├── train/ 10 | │ ├── 0004/ 11 | │ │ ├── 0004.obj 12 | │ │ ├── material0.jpeg 13 | │ │ └── material0.mtl 14 | │ ├── 0005 15 | │ ├── 0007 16 | │ └── ... 17 | └──val 18 | ├── 0000 19 | ├── 0001 20 | ├── 0002 21 | └── ... 22 | ``` 23 | 24 | - Set the correct ```thuman_root``` and ```save_root``` in [render_data.py](render_data.py#L214-L215). Adjust ```cam_nums```, ```scene_radius``` and the camera parameters so that the rendered training data resembles your target real-world scenario. 25 | -------------------------------------------------------------------------------- /prepare_data/render_data.py: -------------------------------------------------------------------------------- 1 | import taichi_three as t3 2 | import numpy as np 3 | from taichi_three.transform import * 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import os 7 | import cv2 8 | import pickle 9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 10 | 11 | 12 | def save(pid, data_id, vid, save_path, extr, intr, depth, img, mask, img_hr=None): 13 | img_save_path = os.path.join(save_path, 'img', data_id + '_' + '%03d' % pid) 14 | depth_save_path = os.path.join(save_path, 'depth', data_id + '_' + '%03d' % pid) 15 | mask_save_path = os.path.join(save_path, 'mask', data_id + '_' + '%03d' % pid) 16 | parm_save_path = os.path.join(save_path, 'parm', data_id + '_' + '%03d' % pid) 17 | Path(img_save_path).mkdir(exist_ok=True, parents=True) 18 | Path(parm_save_path).mkdir(exist_ok=True, parents=True) 19 | Path(mask_save_path).mkdir(exist_ok=True, parents=True) 20 | Path(depth_save_path).mkdir(exist_ok=True, parents=True) 21 | 22 | depth = depth * 2.0 ** 15 23 | cv2.imwrite(os.path.join(depth_save_path, '{}.png'.format(vid)), depth.astype(np.uint16)) 24 | img = (np.clip(img, 0, 1) * 255.0 + 0.5).astype(np.uint8)[:, :, ::-1] 25 | mask = (np.clip(mask, 0, 1) * 255.0 + 0.5).astype(np.uint8) 26 | cv2.imwrite(os.path.join(img_save_path, '{}.jpg'.format(vid)), img) 27 | if img_hr is not None: 28 | img_hr = (np.clip(img_hr, 0, 1) * 255.0 + 0.5).astype(np.uint8)[:, :, ::-1] 29 | cv2.imwrite(os.path.join(img_save_path, '{}_hr.jpg'.format(vid)), img_hr) 30 | cv2.imwrite(os.path.join(mask_save_path, '{}.png'.format(vid)), mask) 31 | np.save(os.path.join(parm_save_path,
'{}_intrinsic.npy'.format(vid)), intr) 32 | np.save(os.path.join(parm_save_path, '{}_extrinsic.npy'.format(vid)), extr) 33 | 34 | 35 | class StaticRenderer: 36 | def __init__(self, src_res): 37 | ti.init(arch=ti.cuda, device_memory_fraction=0.8) 38 | self.scene = t3.Scene() 39 | self.N = 10 40 | self.src_res = src_res 41 | self.hr_res = (src_res[0] * 2, src_res[1] * 2) 42 | 43 | def change_all(self): 44 | save_obj = [] 45 | save_tex = [] 46 | for model in self.scene.models: 47 | save_obj.append(model.init_obj) 48 | save_tex.append(model.init_tex) 49 | ti.init(arch=ti.cuda, device_memory_fraction=0.8) 50 | print('init') 51 | self.scene = t3.Scene() 52 | for i in range(len(save_obj)): 53 | model = t3.StaticModel(self.N, obj=save_obj[i], tex=save_tex[i]) 54 | self.scene.add_model(model) 55 | 56 | def check_update(self, obj): 57 | temp_n = self.N 58 | self.N = max(obj['vi'].shape[0], self.N) 59 | self.N = max(obj['f'].shape[0], self.N) 60 | if not (obj['vt'] is None): 61 | self.N = max(obj['vt'].shape[0], self.N) 62 | 63 | if self.N > temp_n: 64 | self.N *= 2 65 | self.change_all() 66 | self.camera_light() 67 | 68 | def add_model(self, obj, tex=None): 69 | self.check_update(obj) 70 | model = t3.StaticModel(self.N, obj=obj, tex=tex) 71 | self.scene.add_model(model) 72 | 73 | def modify_model(self, index, obj, tex=None): 74 | self.check_update(obj) 75 | self.scene.models[index].init_obj = obj 76 | self.scene.models[index].init_tex = tex 77 | self.scene.models[index]._init() 78 | 79 | def camera_light(self): 80 | camera = t3.Camera(res=self.src_res) 81 | self.scene.add_camera(camera) 82 | 83 | camera_hr = t3.Camera(res=self.hr_res) 84 | self.scene.add_camera(camera_hr) 85 | 86 | light_dir = np.array([0, 0, 1]) 87 | light_list = [] 88 | for l in range(6): 89 | rotate = np.matmul(rotationX(math.radians(np.random.uniform(-30, 30))), 90 | rotationY(math.radians(360 // 6 * l))) 91 | dir = [*np.matmul(rotate, light_dir)] 92 | light = t3.Light(dir, color=[1.0, 1.0, 1.0]) 93 | light_list.append(light) 94 | lights = t3.Lights(light_list) 95 | self.scene.add_lights(lights) 96 | 97 | 98 | def render_data(renderer, data_path, phase, data_id, save_path, cam_nums, res, dis=1.0, is_thuman=False): 99 | obj_path = os.path.join(data_path, phase, data_id, '%s.obj' % data_id) 100 | texture_path = data_path 101 | img_path = os.path.join(texture_path, phase, data_id, 'material0.jpeg') 102 | texture = cv2.imread(img_path)[:, :, ::-1] 103 | texture = np.ascontiguousarray(texture) 104 | texture = texture.swapaxes(0, 1)[:, ::-1, :] 105 | obj = t3.readobj(obj_path, scale=1) 106 | 107 | # height normalization 108 | vy_max = np.max(obj['vi'][:, 1]) 109 | vy_min = np.min(obj['vi'][:, 1]) 110 | human_height = 1.80 + np.random.uniform(-0.05, 0.05, 1) 111 | obj['vi'][:, :3] = obj['vi'][:, :3] / (vy_max - vy_min) * human_height 112 | obj['vi'][:, 1] -= np.min(obj['vi'][:, 1]) 113 | look_at_center = np.array([0, 0.85, 0]) 114 | base_cam_pitch = -8 115 | 116 | # randomly move the scan 117 | move_range = 0.1 if human_height < 1.80 else 0.05 118 | delta_x = np.max(obj['vi'][:, 0]) - np.min(obj['vi'][:, 0]) 119 | delta_z = np.max(obj['vi'][:, 2]) - np.min(obj['vi'][:, 2]) 120 | if delta_x > 1.0 or delta_z > 1.0: 121 | move_range = 0.01 122 | obj['vi'][:, 0] += np.random.uniform(-move_range, move_range, 1) 123 | obj['vi'][:, 2] += np.random.uniform(-move_range, move_range, 1) 124 | 125 | if len(renderer.scene.models) >= 1: 126 | renderer.modify_model(0, obj, texture) 127 | else: 128 | renderer.add_model(obj, texture) 129 | 130 | 
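# Camera layout used below: the cam_nums source-view pairs are spaced degree_interval apart
# around the subject. angle_base is drawn roughly within +/- degree_interval/2 of 0 (and, for
# THuman2.0, shifted by the SMPL-X global yaw to normalise the facing direction); for each pid,
# vid 0 and 1 are the two neighbouring source views, and vid 2-4 are three novel target views
# between them, additionally rendered at 2x resolution and saved as '{vid}_hr.jpg'.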
degree_interval = 360 / cam_nums 131 | angle_list1 = list(range(360-int(degree_interval//2), 360)) 132 | angle_list2 = list(range(0, 0+int(degree_interval//2))) 133 | angle_list = angle_list1 + angle_list2 134 | angle_base = np.random.choice(angle_list, 1)[0] 135 | if is_thuman: 136 | # thuman needs a normalization of orientation 137 | smpl_path = os.path.join(data_path, 'THuman2.0_Smpl_X_Paras', data_id, 'smplx_param.pkl') 138 | with open(smpl_path, 'rb') as f: 139 | smpl_para = pickle.load(f) 140 | 141 | y_orient = smpl_para['global_orient'][0][1] 142 | angle_base += (y_orient*180.0/np.pi) 143 | 144 | for pid in range(cam_nums): 145 | angle = angle_base + pid * degree_interval 146 | 147 | def render(dis, angle, look_at_center, p, renderer, render_hr=False): 148 | ori_vec = np.array([0, 0, dis]) 149 | rotate = np.matmul(rotationY(math.radians(angle)), rotationX(math.radians(p))) 150 | fwd = np.matmul(rotate, ori_vec) 151 | cam_pos = look_at_center + fwd 152 | 153 | x_min = 0 154 | y_min = -25 155 | cx = res[0] * 0.5 156 | cy = res[1] * 0.5 157 | fx = res[0] * 0.8 158 | fy = res[1] * 0.8 159 | _cx = cx - x_min 160 | _cy = cy - y_min 161 | renderer.scene.cameras[0].set_intrinsic(fx, fy, _cx, _cy) 162 | renderer.scene.cameras[0].set(pos=cam_pos, target=look_at_center) 163 | renderer.scene.cameras[0]._init() 164 | 165 | if render_hr: 166 | fx = res[0] * 0.8 * 2 167 | fy = res[1] * 0.8 * 2 168 | _cx = (res[0] * 0.5 - x_min) * 2 169 | _cy = (res[1] * 0.5 - y_min) * 2 170 | renderer.scene.cameras[1].set_intrinsic(fx, fy, _cx, _cy) 171 | renderer.scene.cameras[1].set(pos=cam_pos, target=look_at_center) 172 | renderer.scene.cameras[1]._init() 173 | 174 | renderer.scene.render() 175 | camera = renderer.scene.cameras[0] 176 | camera_hr = renderer.scene.cameras[1] 177 | extrinsic = camera.export_extrinsic() 178 | intrinsic = camera.export_intrinsic() 179 | depth_map = camera.zbuf.to_numpy().swapaxes(0, 1) 180 | img = camera.img.to_numpy().swapaxes(0, 1) 181 | img_hr = camera_hr.img.to_numpy().swapaxes(0, 1) 182 | mask = camera.mask.to_numpy().swapaxes(0, 1) 183 | return extrinsic, intrinsic, depth_map, img, mask, img_hr 184 | 185 | renderer.scene.render() 186 | camera = renderer.scene.cameras[0] 187 | extrinsic = camera.export_extrinsic() 188 | intrinsic = camera.export_intrinsic() 189 | depth_map = camera.zbuf.to_numpy().swapaxes(0, 1) 190 | img = camera.img.to_numpy().swapaxes(0, 1) 191 | mask = camera.mask.to_numpy().swapaxes(0, 1) 192 | return extrinsic, intrinsic, depth_map, img, mask 193 | 194 | 195 | extr, intr, depth, img, mask = render(dis, angle, look_at_center, base_cam_pitch, renderer) 196 | save(pid, data_id, 0, save_path, extr, intr, depth, img, mask) 197 | extr, intr, depth, img, mask = render(dis, (angle+degree_interval) % 360, look_at_center, base_cam_pitch, renderer) 198 | save(pid, data_id, 1, save_path, extr, intr, depth, img, mask) 199 | 200 | # three novel viewpoints between source views 201 | angle1 = (angle + (np.random.uniform() * degree_interval / 2)) % 360 202 | angle2 = (angle + degree_interval / 2) % 360 203 | angle3 = (angle + degree_interval - (np.random.uniform() * degree_interval / 2)) % 360 204 | 205 | extr, intr, depth, img, mask, img_hr = render(dis, angle1, look_at_center, base_cam_pitch, renderer, render_hr=True) 206 | save(pid, data_id, 2, save_path, extr, intr, depth, img, mask, img_hr) 207 | extr, intr, depth, img, mask, img_hr = render(dis, angle2, look_at_center, base_cam_pitch, renderer, render_hr=True) 208 | save(pid, data_id, 3, save_path, extr, intr, 
depth, img, mask, img_hr) 209 | extr, intr, depth, img, mask, img_hr = render(dis, angle3, look_at_center, base_cam_pitch, renderer, render_hr=True) 210 | save(pid, data_id, 4, save_path, extr, intr, depth, img, mask, img_hr) 211 | 212 | 213 | if __name__ == '__main__': 214 | cam_nums = 16 215 | scene_radius = 2.0 216 | res = (1024, 1024) 217 | thuman_root = 'PATH/TO/THuman2.0' 218 | save_root = 'PATH/TO/SAVE/RENDERED/DATA' 219 | 220 | np.random.seed(1314) 221 | renderer = StaticRenderer(src_res=res) 222 | 223 | for phase in ['train', 'val']: 224 | thuman_list = sorted(os.listdir(os.path.join(thuman_root, phase))) 225 | save_path = os.path.join(save_root, phase) 226 | 227 | for data_id in tqdm(thuman_list): 228 | render_data(renderer, thuman_root, phase, data_id, save_path, cam_nums, res, dis=scene_radius, is_thuman=True) 229 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import version as __version__ 2 | from .scene import * 3 | from .model import * 4 | from .scatter import * 5 | from .geometry import * 6 | from .loader import * 7 | from .light import * 8 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/common.py: -------------------------------------------------------------------------------- 1 | class AutoInit: 2 | def init(self): 3 | if not hasattr(self, '_AutoInit_had_init'): 4 | self._init() 5 | self._AutoInit_had_init = True 6 | 7 | def _init(self): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/geometry.py: -------------------------------------------------------------------------------- 1 | import math 2 | import taichi as ti 3 | import taichi.math as ts 4 | 5 | 6 | @ti.func 7 | def render_triangle(model, camera, face, lights): 8 | scene = model.scene 9 | _1 = ti.static(min(1, model.faces.m - 1)) 10 | _2 = ti.static(min(2, model.faces.m - 1)) 11 | ia, ib, ic = model.vi[face[0, 0]], model.vi[face[1, 0]], model.vi[face[2, 0]] 12 | ca, cb, cc = ia, ib, ic 13 | if model.type[None] >= 1: 14 | ca, cb, cc = model.vc[face[0, 0]], model.vc[face[1, 0]], model.vc[face[2, 0]] 15 | ta, tb, tc = model.vt[face[0, _1]], model.vt[face[1, _1]], model.vt[face[2, _1]] 16 | na, nb, nc = model.vn[face[0, _2]], model.vn[face[1, _2]], model.vn[face[2, _2]] 17 | a = camera.untrans_pos(ia) 18 | b = camera.untrans_pos(ib) 19 | c = camera.untrans_pos(ic) 20 | 21 | # NOTE: the normal computation indicates that # a front-facing face should 22 | # be COUNTER-CLOCKWISE, i.e., glFrontFace(GL_CCW); 23 | # this is to be compatible with obj model loading. 
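# Rasterization outline for render_triangle: the face normal decides front/back facing
# (honouring model.reverse), lighting is accumulated via scene.opt.render_func, the three
# vertices are projected with camera.uncook, and every pixel in the screen-space bounding
# box is weighted with area-method barycentric coordinates; a per-pixel atomic_max z-buffer
# resolves visibility before the colour, mask and normal maps are written.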
24 | normal = (a - b).cross(a - c).normalized() 25 | pos = (a + b + c) / 3 26 | view_pos = (a + b + c) / 3 27 | shading = (view_pos.normalized().dot(normal) - 0.5)*2.0 28 | if ti.static(camera.type == camera.ORTHO): 29 | view_pos = ti.Vector([0.0, 0.0, 1.0], ti.f32) 30 | reverse_v = 1 31 | if ti.static(model.reverse == True): 32 | reverse_v = -1 33 | if ts.dot(view_pos, normal)*reverse_v <= 0: 34 | # shading 35 | color = ti.Vector([0.0, 0.0, 0.0], ti.f32) 36 | for i in range(lights.n): 37 | light_dir = lights.light_dirs[i] 38 | # print(light_dir) 39 | light_dir = camera.untrans_dir(light_dir) 40 | light_color = lights.light_colors[i] 41 | color += scene.opt.render_func(pos, normal, ti.Vector([0.0, 0.0, 0.0], ti.f32), light_dir, light_color) 42 | color = scene.opt.pre_process(color) 43 | A = camera.uncook(a) 44 | B = camera.uncook(b) 45 | C = camera.uncook(c) 46 | scr_norm = 1 / (A - C).cross(B - A) 47 | B_A = (B - A) * scr_norm 48 | C_B = (C - B) * scr_norm 49 | A_C = (A - C) * scr_norm 50 | 51 | W = 1 52 | # screen space bounding box 53 | M, N = int(ti.floor(min(A, B, C) - W)), int(ti.ceil(max(A, B, C) + W)) 54 | M.x, N.x = min(max(M.x, 0), camera.img.shape[0]), min(max(N.x, 0), camera.img.shape[1]) 55 | M.y, N.y = min(max(M.y, 0), camera.img.shape[0]), min(max(N.y, 0), camera.img.shape[1]) 56 | for X in ti.grouped(ti.ndrange((M.x, N.x), (M.y, N.y))): 57 | # barycentric coordinates using the area method 58 | X_A = X - A 59 | w_C = B_A.cross(X_A) 60 | w_B = A_C.cross(X_A) 61 | w_A = 1 - w_C - w_B 62 | # draw 63 | in_screen = w_A >= 0 and w_B >= 0 and w_C >= 0 and 0 < X[0] < camera.img.shape[0] and 0 < X[1] < camera.img.shape[1] 64 | if not in_screen: 65 | continue 66 | zindex = 0.0 67 | if ti.static(model.reverse == True): 68 | zindex = (a.z * w_A + b.z * w_B + c.z * w_C) 69 | else: 70 | zindex = 1 / (a.z * w_A + b.z * w_B + c.z * w_C) 71 | if zindex < ti.atomic_max(camera.zbuf[X], zindex): 72 | continue 73 | 74 | coor = (ta * w_A + tb * w_B + tc * w_C) 75 | if model.type[None] == 1: 76 | camera.img[X] = (ca * w_A + cb * w_B + cc * w_C) 77 | elif model.type[None] == 2: 78 | camera.img[X] = color * (ca * w_A + cb * w_B + cc * w_C) 79 | else: 80 | camera.img[X] = color * model.texSample(coor) 81 | camera.mask[X] = ti.Vector([1.0, 1.0, 1.0], ti.f32) 82 | camera.normal_map[X] = -normal / 2.0 + 0.5 83 | # camera.normal_map[X] = ts.vec3(shading / 2.0 + 0.5) 84 | 85 | @ti.func 86 | def render_particle(model, camera, vertex, radius): 87 | scene = model.scene 88 | L2W = model.L2W 89 | a = camera.untrans_pos(L2W @ vertex) 90 | A = camera.uncook(a) 91 | 92 | M = int(ti.floor(A - radius)) 93 | N = int(ti.ceil(A + radius)) 94 | 95 | for X in ti.grouped(ti.ndrange((M.x, N.x), (M.y, N.y))): 96 | if X.x < 0 or X.x >= camera.res[0] or X.y < 0 or X.y >= camera.res[1]: 97 | continue 98 | if (X - A).norm_sqr() > radius**2: 99 | continue 100 | 101 | camera.img[X] = ti.Vector([1.0, 1.0, 1.0], ti.f32) 102 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/light.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | from .common import * 3 | import math 4 | 5 | ''' 6 | The base light class represents a directional light. 
7 | ''' 8 | @ti.data_oriented 9 | class Light(AutoInit): 10 | 11 | def __init__(self, dir=None, color=None): 12 | dir = dir or [0, 0, 1] 13 | norm = math.sqrt(sum(x ** 2 for x in dir)) 14 | dir = [x / norm for x in dir] 15 | 16 | self.dir_py = [-x for x in dir] 17 | self.color_py = color or [1, 1, 1] 18 | 19 | self.dir = ti.Vector.field(3, ti.float32, ()) 20 | self.color = ti.Vector.field(3, ti.float32, ()) 21 | # store the current light direction in the view space 22 | # so that we don't have to compute it for each vertex 23 | self.viewdir = ti.Vector.field(3, ti.float32, ()) 24 | 25 | def set(self, dir=[0, 0, 1], color=[1, 1, 1]): 26 | norm = math.sqrt(sum(x**2 for x in dir)) 27 | dir = [x / norm for x in dir] 28 | self.dir_py = dir 29 | self.color = color 30 | 31 | def _init(self): 32 | self.dir[None] = self.dir_py 33 | self.color[None] = self.color_py 34 | 35 | @ti.func 36 | def intensity(self, pos): 37 | return 1 38 | 39 | @ti.func 40 | def get_color(self): 41 | return self.color[None] 42 | 43 | @ti.func 44 | def get_dir(self): 45 | return self.viewdir 46 | 47 | @ti.func 48 | def set_view(self, camera): 49 | self.viewdir[None] = camera.untrans_dir(self.dir[None]) 50 | 51 | 52 | @ti.data_oriented 53 | class Lights(AutoInit): 54 | 55 | def __init__(self, light_list): 56 | n = len(light_list) 57 | self.n = n 58 | self.light_list = light_list 59 | self.light_dirs = ti.Vector.field(3, ti.float32, n) 60 | self.light_colors = ti.Vector.field(3, ti.float32, n) 61 | 62 | def init_data(self): 63 | index = 0 64 | for light in self.light_list: 65 | self.light_dirs[index] = self.light_list[index].dir_py 66 | self.light_colors[index] = self.light_list[index].color_py 67 | index += 1 68 | 69 | class PointLight(Light): 70 | pass 71 | 72 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | 5 | def _append(faces, indices): 6 | if len(indices) == 4: 7 | faces.append([indices[0], indices[1], indices[2]]) 8 | faces.append([indices[2], indices[3], indices[0]]) 9 | elif len(indices) == 3: 10 | faces.append(indices) 11 | else: 12 | assert False, len(indices) 13 | 14 | 15 | def readobj(path, scale=1): 16 | vi = [] 17 | vt = [] 18 | vn = [] 19 | faces = [] 20 | 21 | with open(path, 'r') as myfile: 22 | lines = myfile.readlines() 23 | 24 | # cache vertices 25 | for line in lines: 26 | try: 27 | type, fields = line.split(maxsplit=1) 28 | fields = [float(_) for _ in fields.split()] 29 | except ValueError: 30 | continue 31 | 32 | if type == 'v': 33 | vi.append(fields) 34 | elif type == 'vt': 35 | vt.append(fields) 36 | elif type == 'vn': 37 | vn.append(fields) 38 | 39 | # cache faces 40 | for line in lines: 41 | try: 42 | type, fields = line.split(maxsplit=1) 43 | fields = fields.split() 44 | except ValueError: 45 | continue 46 | 47 | # line looks like 'f 5/1/1 1/2/1 4/3/1' 48 | # or 'f 314/380/494 382/400/494 388/550/494 506/551/494' for quads 49 | if type != 'f': 50 | continue 51 | 52 | # a field should look like '5/1/1' 53 | # for vertex/vertex UV coords/vertex Normal (indexes number in the list) 54 | # the index in 'f 5/1/1 1/2/1 4/3/1' STARTS AT 1 !!! 
55 | 56 | indices = [[int(_) - 1 if _ != '' else 0 for _ in field.split('/')] for field in fields] 57 | 58 | if len(indices) == 4: 59 | faces.append([indices[0], indices[1], indices[2]]) 60 | faces.append([indices[2], indices[3], indices[0]]) 61 | elif len(indices) == 3: 62 | faces.append(indices) 63 | else: 64 | assert False, len(indices) 65 | 66 | ret = {} 67 | ret['vi'] = None if len(vi) == 0 else np.array(vi).astype(np.float32) * scale 68 | ret['vt'] = None if len(vt) == 0 else np.array(vt).astype(np.float32) 69 | ret['vn'] = None if len(vn) == 0 else np.array(vn).astype(np.float32) 70 | ret['f'] = None if len(faces) == 0 else np.array(faces).astype(np.int32) 71 | return ret 72 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/meshgen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import taichi as ti 3 | import taichi_glsl as tl 4 | import math 5 | 6 | 7 | def _pre(x): 8 | if not isinstance(x, ti.Matrix): 9 | x = ti.Vector(x) 10 | return x 11 | 12 | 13 | def _ser(foo): 14 | def wrapped(self, *args, **kwargs): 15 | foo(self, *args, **kwargs) 16 | return self 17 | 18 | return wrapped 19 | 20 | 21 | def _mparg(foo): 22 | def wrapped(self, *args): 23 | if len(args) > 1: 24 | return [foo(self, x) for x in args] 25 | else: 26 | return foo(self, args[0]) 27 | 28 | return wrapped 29 | 30 | 31 | class MeshGen: 32 | def __init__(self): 33 | self.v = [] 34 | self.f = [] 35 | 36 | @_ser 37 | def quad(self, a, b, c, d): 38 | a, b, c, d = self.add_v(a, b, c, d) 39 | self.add_f([a, b, c], [c, d, a]) 40 | 41 | @_ser 42 | def cube(self, a, b): 43 | aaa = self.add_v(tl.mix(a, b, tl.D.yyy)) 44 | baa = self.add_v(tl.mix(a, b, tl.D.xyy)) 45 | aba = self.add_v(tl.mix(a, b, tl.D.yxy)) 46 | aab = self.add_v(tl.mix(a, b, tl.D.yyx)) 47 | bba = self.add_v(tl.mix(a, b, tl.D.xxy)) 48 | abb = self.add_v(tl.mix(a, b, tl.D.yxx)) 49 | bab = self.add_v(tl.mix(a, b, tl.D.xyx)) 50 | bbb = self.add_v(tl.mix(a, b, tl.D.xxx)) 51 | 52 | self.add_f4([aaa, aba, bba, baa]) # back 53 | self.add_f4([aab, bab, bbb, abb]) # front 54 | self.add_f4([aaa, aab, abb, aba]) # left 55 | self.add_f4([baa, bba, bbb, bab]) # right 56 | self.add_f4([aaa, baa, bab, aab]) # bottom 57 | self.add_f4([aba, abb, bbb, bba]) # top 58 | 59 | @_ser 60 | def cylinder(self, bottom, top, dir1, dir2, N): 61 | bottom = _pre(bottom) 62 | top = _pre(top) 63 | dir1 = _pre(dir1) 64 | dir2 = _pre(dir2) 65 | 66 | B, T = [], [] 67 | for i in range(N): 68 | disp = tl.mat(dir1.entries, dir2.entries).T() @ tl.vecAngle(tl.math.tau * i / N) 69 | B.append(self.add_v(bottom + disp)) 70 | T.append(self.add_v(top + disp)) 71 | 72 | BC = self.add_v(bottom) 73 | TC = self.add_v(top) 74 | 75 | for i in range(N): 76 | j = (i + 1) % N 77 | self.add_f4([B[i], B[j], T[j], T[i]]) 78 | 79 | for i in range(N): 80 | j = (i + 1) % N 81 | self.add_f([B[j], B[i], BC]) 82 | self.add_f([T[i], T[j], TC]) 83 | 84 | @_ser 85 | def tri(self, a, b, c): 86 | a, b, c = self.add_v(a, b, c) 87 | self.add_f([a, b, c]) 88 | 89 | @_mparg 90 | def add_v(self, v): 91 | if isinstance(v, ti.Matrix): 92 | v = v.entries 93 | ret = len(self.v) 94 | self.v.append(v) 95 | return ret 96 | 97 | @_mparg 98 | def add_f(self, f): 99 | ret = len(self.f) 100 | self.f.append(f) 101 | return ret 102 | 103 | @_mparg 104 | def add_f4(self, f): 105 | a, b, c, d = f 106 | return self.add_f([a, b, c], [c, d, a]) 107 | 108 | 109 | def __getitem__(self, key): 110 | if key == 'v': 111 | return 
np.array(self.v) 112 | if key == 'f': 113 | return np.array(self.f) 114 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import taichi as ti 3 | import taichi.math as ts 4 | from .geometry import * 5 | from .transform import * 6 | from .common import * 7 | import math 8 | 9 | 10 | @ti.func 11 | def sample(field: ti.template(), P): 12 | ''' 13 | Sampling a field with indices clampped into the field shape. 14 | :parameter field: (Tensor) 15 | Specify the field to sample. 16 | :parameter P: (Vector) 17 | Specify the index in field. 18 | :return: 19 | The return value is calcuated as:: 20 | P = clamp(P, 0, vec(*field.shape) - 1) 21 | return field[int(P)] 22 | ''' 23 | shape = ti.Vector(field.shape) 24 | P = ts.clamp(P, 0, shape - 1) 25 | return field[int(P)] 26 | 27 | @ti.func 28 | def bilerp(field: ti.template(), P): 29 | ''' 30 | Bilinear sampling an 2D field with a real index. 31 | :parameter field: (2D Tensor) 32 | Specify the field to sample. 33 | :parameter P: (2D Vector of float) 34 | Specify the index in field. 35 | :note: 36 | If one of the element to be accessed is out of `field.shape`, then 37 | `bilerp` will automatically do a clamp for you, see :func:`sample`. 38 | :return: 39 | The return value is calcuated as:: 40 | I = int(P) 41 | x = fract(P) 42 | y = 1 - x 43 | return (sample(field, I + D.xx) * x.x * x.y + 44 | sample(field, I + D.xy) * x.x * y.y + 45 | sample(field, I + D.yy) * y.x * y.y + 46 | sample(field, I + D.yx) * y.x * x.y) 47 | .. where D = vec(1, 0, -1) 48 | ''' 49 | I = int(P) 50 | x = ts.fract(P) 51 | y = 1 - x 52 | D = ts.vec3(1, 0, -1) 53 | return (sample(field, I + D.xx) * x.x * x.y + 54 | sample(field, I + D.xy) * x.x * y.y + 55 | sample(field, I + D.yy) * y.x * y.y + 56 | sample(field, I + D.yx) * y.x * x.y) 57 | 58 | @ti.data_oriented 59 | class Model(AutoInit): 60 | TEX = 0 61 | COLOR = 1 62 | 63 | def __init__(self, f_n=None, f_m=None, 64 | vi_n=None, vt_n=None, vn_n=None, tex_n=None, col_n=None, 65 | obj=None, tex=None): 66 | 67 | self.faces = None 68 | self.vi = None 69 | self.vt = None 70 | self.vn = None 71 | self.tex = None 72 | self.type = ti.field(dtype=ti.int32, shape=()) 73 | self.reverse = False 74 | 75 | if obj is not None: 76 | f_n = None if obj['f'] is None else obj['f'].shape[0] 77 | vi_n = None if obj['vi'] is None else obj['vi'].shape[0] 78 | vt_n = None if obj['vt'] is None else obj['vt'].shape[0] 79 | vn_n = None if obj['vn'] is None else obj['vn'].shape[0] 80 | 81 | if tex is not None: 82 | tex_n = tex.shape[:2] 83 | 84 | if f_m is None: 85 | f_m = 1 86 | if vt_n is not None: 87 | f_m = 2 88 | if vn_n is not None: 89 | f_m = 3 90 | 91 | if vi_n is None: 92 | vi_n = 1 93 | if vt_n is None: 94 | vt_n = 1 95 | if vn_n is None: 96 | vn_n = 1 97 | if col_n is None: 98 | col_n = 1 99 | 100 | if f_n is not None: 101 | self.faces = ti.Matrix.field(3, f_m, ti.i32, f_n) 102 | if vi_n is not None: 103 | self.vi = ti.Vector.field(3, ti.f32, vi_n) 104 | if vt_n is not None: 105 | self.vt = ti.Vector.field(2, ti.f32, vt_n) 106 | if vn_n is not None: 107 | self.vn = ti.Vector.field(3, ti.f32, vn_n) 108 | if tex_n is not None: 109 | self.tex = ti.Vector.field(3, ti.f32, tex_n) 110 | if col_n is not None: 111 | self.vc = ti.Vector.field(3, ti.f32, col_n) 112 | 113 | if obj is not None: 114 | self.init_obj = obj 115 | if tex is not None: 116 | self.init_tex = tex 117 | 118 | def 
from_obj(self, obj): 119 | if obj['f'] is not None: 120 | self.faces.from_numpy(obj['f']) 121 | if obj['vi'] is not None: 122 | self.vi.from_numpy(obj['vi']) 123 | if obj['vt'] is not None: 124 | self.vt.from_numpy(obj['vt']) 125 | if obj['vn'] is not None: 126 | self.vn.from_numpy(obj['vn']) 127 | 128 | def _init(self): 129 | self.type[None] = 0 130 | if hasattr(self, 'init_obj'): 131 | self.from_obj(self.init_obj) 132 | if hasattr(self, 'init_tex'): 133 | self.tex.from_numpy(self.init_tex.astype(np.float32) / 255) 134 | 135 | @ti.func 136 | def render(self, camera): 137 | for i in ti.grouped(self.faces): 138 | render_triangle(self, camera, self.faces[i]) 139 | 140 | @ti.func 141 | def texSample(self, coor): 142 | if ti.static(self.tex is not None): 143 | return ts.bilerp(self.tex, coor * ts.vec(*self.tex.shape)) 144 | else: 145 | return 1 146 | 147 | 148 | @ti.data_oriented 149 | class StaticModel(AutoInit): 150 | TEX = 0 151 | COLOR = 1 152 | 153 | def __init__(self, N, f_m=None, col_n=None, 154 | obj=None, tex=None): 155 | self.faces = None 156 | self.vi = None 157 | self.vt = None 158 | self.vn = None 159 | self.tex = None 160 | # 0 origin 1 pure color 2 shader color 161 | self.type = ti.field(dtype=ti.int32, shape=()) 162 | self.f_n = ti.field(dtype=ti.int32, shape=()) 163 | self.reverse = False 164 | self.N = N 165 | 166 | if obj is not None: 167 | f_n = None if obj['f'] is None else obj['f'].shape[0] 168 | vi_n = None if obj['vi'] is None else obj['vi'].shape[0] 169 | vt_n = None if obj['vt'] is None else obj['vt'].shape[0] 170 | vn_n = None if obj['vn'] is None else obj['vn'].shape[0] 171 | 172 | if not (tex is None): 173 | tex_n = tex.shape[:2] 174 | else: 175 | tex_n = None 176 | 177 | if f_m is None: 178 | f_m = 1 179 | if vt_n is not None: 180 | f_m = 2 181 | if vn_n is not None: 182 | f_m = 3 183 | 184 | if vi_n is None: 185 | vi_n = 1 186 | if vt_n is None: 187 | vt_n = 1 188 | if vn_n is None: 189 | vn_n = 1 190 | if col_n is None: 191 | col_n = 1 192 | 193 | if f_n is not None: 194 | self.faces = ti.Matrix.field(3, f_m, ti.i32, N) 195 | if vi_n is not None: 196 | self.vi = ti.Vector.field(3, ti.f32, N) 197 | if vt_n is not None: 198 | self.vt = ti.Vector.field(2, ti.f32, N) 199 | if vn_n is not None: 200 | self.vn = ti.Vector.field(3, ti.f32, N) 201 | if not (tex_n is None): 202 | self.tex = ti.Vector.field(3, ti.f32, tex_n) 203 | if col_n is not None: 204 | self.vc = ti.Vector.field(3, ti.f32, N) 205 | 206 | if obj is not None: 207 | self.init_obj = obj 208 | if tex is not None: 209 | self.init_tex = tex 210 | 211 | def modify_color(self, color): 212 | s_color = np.zeros((self.N, 3)).astype(np.float32) 213 | s_color[:color.shape[0]] = color 214 | self.vc.from_numpy(s_color) 215 | 216 | def from_obj(self, obj): 217 | N = self.N 218 | if obj['f'] is not None: 219 | s_faces = np.zeros((N, obj['f'].shape[1], obj['f'].shape[2])).astype(int) 220 | s_faces[:obj['f'].shape[0]] = obj['f'] 221 | self.f_n[None] = obj['f'].shape[0] 222 | self.faces.from_numpy(s_faces) 223 | if obj['vi'] is not None: 224 | s_vi = np.zeros((N, 3)).astype(np.float32) 225 | s_vi[:obj['vi'].shape[0]] = obj['vi'][:, :3] 226 | self.vi.from_numpy(s_vi) 227 | if obj['vt'] is not None: 228 | s_vt = np.zeros((N, 2)).astype(np.float32) 229 | s_vt[:obj['vt'].shape[0]] = obj['vt'] 230 | self.vt.from_numpy(s_vt) 231 | if obj['vn'] is not None: 232 | s_vn = np.zeros((N, 3)).astype(np.float32) 233 | s_vn[:obj['vn'].shape[0]] = obj['vn'] 234 | self.vn.from_numpy(s_vn) 235 | 236 | def _init(self): 237 | self.type[None] 
= 0 238 | if hasattr(self, 'init_obj'): 239 | self.from_obj(self.init_obj) 240 | if hasattr(self, 'init_tex') and (self.init_tex is not None): 241 | self.tex.from_numpy(self.init_tex.astype(np.float32) / 255) 242 | 243 | @ti.func 244 | def render(self, camera, lights): 245 | for i in ti.grouped(self.faces): 246 | render_triangle(self, camera, self.faces[i], lights) 247 | 248 | @ti.func 249 | def texSample(self, coor): 250 | if ti.static(self.tex is not None): 251 | return bilerp(self.tex, coor * ts.vec2(*self.tex.shape)) 252 | else: 253 | return 1 254 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/raycast.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import taichi.math as ts 3 | from .scene import * 4 | import math 5 | 6 | 7 | EPS = 1e-3 8 | INF = 1e3 9 | 10 | 11 | @ti.data_oriented 12 | class ObjectRT(ts.TaichiClass): 13 | @ti.func 14 | def calc_sdf(self, p): 15 | ret = INF 16 | for I in ti.grouped(ti.ndrange(*self.pos.shape())): 17 | ret = min(ret, self.make_one(I).do_calc_sdf(p)) 18 | return ret 19 | 20 | @ti.func 21 | def intersect(self, orig, dir): 22 | ret, normal = INF, ts.vec3(0.0) 23 | for I in ti.grouped(ti.ndrange(*self.pos.shape())): 24 | t, n = self.make_one(I).do_intersect(orig, dir) 25 | if t < ret: 26 | ret, normal = t, n 27 | return ret, normal 28 | 29 | def do_calc_sdf(self, p): 30 | raise NotImplementedError 31 | 32 | def do_intersect(self, orig, dir): 33 | raise NotImplementedError 34 | 35 | 36 | @ti.data_oriented 37 | class Ball(ObjectRT): 38 | def __init__(self, pos, radius): 39 | self.pos = pos 40 | self.radius = radius 41 | 42 | @ti.func 43 | def make_one(self, I): 44 | return Ball(self.pos[I], self.radius[I]) 45 | 46 | @ti.func 47 | def do_calc_sdf(self, p): 48 | return ts.distance(self.pos, p) - self.radius 49 | 50 | @ti.func 51 | def do_intersect(self, orig, dir): 52 | op = self.pos - orig 53 | b = op.dot(dir) 54 | det = b ** 2 - op.norm_sqr() + self.radius ** 2 55 | ret = INF 56 | if det > 0.0: 57 | det = ti.sqrt(det) 58 | t = b - det 59 | if t > EPS: 60 | ret = t 61 | else: 62 | t = b + det 63 | if t > EPS: 64 | ret = t 65 | return ret, ts.normalize(dir * ret - op) 66 | 67 | 68 | @ti.data_oriented 69 | class SceneRTBase(Scene): 70 | def __init__(self): 71 | super(SceneRTBase, self).__init__() 72 | self.balls = [] 73 | 74 | def trace(self, pos, dir): 75 | raise NotImplementedError 76 | 77 | @ti.func 78 | def color_at(self, coor, camera): 79 | orig, dir = camera.generate(coor) 80 | 81 | pos, normal = self.trace(orig, dir) 82 | light_dir = self.light_dir[None] 83 | 84 | color = self.opt.render_func(pos, normal, dir, light_dir) 85 | color = self.opt.pre_process(color) 86 | return color 87 | 88 | @ti.kernel 89 | def _render(self): 90 | if ti.static(len(self.cameras)): 91 | for camera in ti.static(self.cameras): 92 | for I in ti.grouped(camera.img): 93 | coor = self.cook_coor(I, camera) 94 | color = self.color_at(coor, camera) 95 | camera.img[I] = color 96 | 97 | def add_ball(self, pos, radius): 98 | b = Ball(pos, radius) 99 | self.balls.append(b) 100 | 101 | 102 | @ti.data_oriented 103 | class SceneRT(SceneRTBase): 104 | @ti.func 105 | def intersect(self, orig, dir): 106 | ret, normal = INF, ts.vec3(0.0) 107 | for b in ti.static(self.balls): 108 | t, n = b.intersect(orig, dir) 109 | if t < ret: 110 | ret, normal = t, n 111 | return ret, normal 112 | 113 | @ti.func 114 | def trace(self, orig, dir): 115 | depth, normal = 
self.intersect(orig, dir) 116 | pos = orig + dir * depth 117 | return pos, normal 118 | 119 | 120 | @ti.data_oriented 121 | class SceneSDF(SceneRTBase): 122 | @ti.func 123 | def calc_sdf(self, p): 124 | ret = INF 125 | for b in ti.static(self.balls): 126 | ret = min(ret, b.calc_sdf(p)) 127 | return ret 128 | 129 | @ti.func 130 | def calc_grad(self, p): 131 | return ts.vec( 132 | self.calc_sdf(p + ts.vec(EPS, 0, 0)), 133 | self.calc_sdf(p + ts.vec(0, EPS, 0)), 134 | self.calc_sdf(p + ts.vec(0, 0, EPS))) 135 | 136 | @ti.func 137 | def trace(self, orig, dir): 138 | pos = orig 139 | color = ts.vec3(0.0) 140 | normal = ts.vec3(0.0) 141 | for s in range(100): 142 | t = self.calc_sdf(pos) 143 | if t <= 0: 144 | normal = ts.normalize(self.calc_grad(pos) - t) 145 | break 146 | pos += dir * t 147 | return pos, normal 148 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/scatter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import taichi as ti 3 | import taichi.math as ts 4 | from .geometry import * 5 | from .transform import * 6 | from .common import * 7 | import math 8 | 9 | 10 | @ti.data_oriented 11 | class ScatterModel(AutoInit): 12 | def __init__(self, num=None, radius=2): 13 | self.L2W = Affine.field(()) 14 | 15 | self.num = num 16 | self.radius = radius 17 | 18 | if num is not None: 19 | self.particles = ti.Vector.field(3, ti.i32, num) 20 | 21 | def _init(self): 22 | self.L2W.init() 23 | 24 | @ti.func 25 | def render(self, camera): 26 | for i in ti.grouped(self.particles): 27 | render_particle(self, camera, self.particles[i], self.radius) 28 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/scene.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | from .transform import * 3 | from .shading import * 4 | from .light import * 5 | 6 | 7 | @ti.data_oriented 8 | class Scene(AutoInit): 9 | def __init__(self): 10 | self.lights = [] 11 | self.cameras = [] 12 | self.opt = Shading() 13 | self.models = [] 14 | 15 | @ti.func 16 | def cook_coor(self, I, camera): 17 | scale = ti.static(2 / min(*camera.img.shape())) 18 | coor = (I - ts.vec2(*camera.img.shape()) / 2) * scale 19 | return coor 20 | 21 | @ti.func 22 | def uncook_coor(self, coor, camera): 23 | scale = ti.static(min(*camera.img.shape()) / 2) 24 | I = coor.xy * scale + ts.vec2(*camera.img.shape()) / 2 25 | return I 26 | 27 | def add_model(self, model): 28 | model.scene = self 29 | self.models.append(model) 30 | 31 | def add_camera(self, camera): 32 | camera.scene = self 33 | self.cameras.append(camera) 34 | 35 | def add_lights(self, lights): 36 | lights.scene = self 37 | self.lights = lights 38 | 39 | def _init(self): 40 | for camera in self.cameras: 41 | camera.init() 42 | for model in self.models: 43 | model.init() 44 | self.lights.init_data() 45 | 46 | 47 | @ti.kernel 48 | def _single_render(self, num : ti.template()): 49 | if ti.static(len(self.cameras)): 50 | for camera in ti.static(self.cameras): 51 | camera.clear_buffer() 52 | self.models[num].render(camera, self.lights) 53 | 54 | def single_render(self, num): 55 | self.lights.init_data() 56 | for camera in self.cameras: 57 | camera.init() 58 | self.models[num].init() 59 | self._single_render(num) 60 | 61 | def render(self): 62 | self.init() 63 | self._render() 64 | 65 | @ti.kernel 66 | def _render(self): 67 | if ti.static(len(self.cameras)): 68 | 
for camera in ti.static(self.cameras): 69 | camera.clear_buffer() 70 | # sets up light directions 71 | if ti.static(len(self.models)): 72 | for model in ti.static(self.models): 73 | model.render(camera, self.lights) 74 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/shading.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import taichi.math as tm 3 | from .transform import * 4 | import math 5 | 6 | class Shading: 7 | def __init__(self, **kwargs): 8 | self.is_normal_map = False 9 | self.lambert = 0.58 10 | self.half_lambert = 0.04 11 | self.blinn_phong = 0.3 12 | self.phong = 0.0 13 | self.shineness = 10 14 | self.__dict__.update(kwargs) 15 | 16 | @ti.func 17 | def render_func(self, pos, normal, dir, light_dir, light_color): 18 | color = ti.Vector([0.0, 0.0, 0.0], ti.f32) 19 | shineness = self.shineness 20 | half_lambert = normal.dot(light_dir) * 0.5 + 0.5 21 | lambert = max(0, normal.dot(light_dir)) 22 | blinn_phong = normal.dot(tm.mix(light_dir, -dir, 0.5)) 23 | blinn_phong = pow(max(blinn_phong, 0), shineness) 24 | refl_dir = tm.reflect(light_dir, normal) 25 | phong = -tm.dot(normal, refl_dir) 26 | phong = pow(max(phong, 0), shineness) 27 | 28 | strength = 0.0 29 | if ti.static(self.lambert != 0.0): 30 | strength += lambert * self.lambert 31 | if ti.static(self.half_lambert != 0.0): 32 | strength += half_lambert * self.half_lambert 33 | if ti.static(self.blinn_phong != 0.0): 34 | strength += blinn_phong * self.blinn_phong 35 | if ti.static(self.phong != 0.0): 36 | strength += phong * self.phong 37 | color = tm.vec3(strength) 38 | 39 | if ti.static(self.is_normal_map): 40 | color = normal * 0.5 + 0.5 41 | return color * light_color 42 | 43 | @ti.func 44 | def pre_process(self, color): 45 | blue = tm.vec3(0.00, 0.01, 0.05) 46 | orange = tm.vec3(1.19, 1.04, 0.98) 47 | return ti.sqrt(ts.mix(blue, orange, color)) 48 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/transform.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import taichi.math as ts 3 | from .common import * 4 | import math 5 | 6 | 7 | def rotationX(angle): 8 | return [ 9 | [1, 0, 0], 10 | [0, math.cos(angle), -math.sin(angle)], 11 | [0, math.sin(angle), math.cos(angle)], 12 | ] 13 | 14 | def rotationY(angle): 15 | return [ 16 | [ math.cos(angle), 0, math.sin(angle)], 17 | [ 0, 1, 0], 18 | [-math.sin(angle), 0, math.cos(angle)], 19 | ] 20 | 21 | def rotationZ(angle): 22 | return [ 23 | [math.cos(angle), -math.sin(angle), 0], 24 | [math.sin(angle), math.cos(angle), 0], 25 | [ 0, 0, 1], 26 | ] 27 | 28 | 29 | @ti.data_oriented 30 | class Affine(AutoInit): 31 | @property 32 | def matrix(self): 33 | return self.entries[0] 34 | 35 | @property 36 | def offset(self): 37 | return self.entries[1] 38 | 39 | @classmethod 40 | def _field(cls, shape=None): 41 | return ti.Matrix.field(3, 3, ti.f32, shape), ti.Vector.field(3, ti.f32, shape) 42 | 43 | @ti.func 44 | def loadIdentity(self): 45 | self.matrix = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] 46 | self.offset = [0, 0, 0] 47 | 48 | @ti.kernel 49 | def _init(self): 50 | self.loadIdentity() 51 | 52 | @ti.func 53 | def __matmul__(self, other): 54 | return self.matrix @ other + self.offset 55 | 56 | @ti.func 57 | def inverse(self): 58 | # TODO: incorrect: 59 | return Affine(self.matrix.inverse(), -self.offset) 60 | 61 | def loadOrtho(self, fwd=[0, 0, 1], 
up=[0, 1, 0]): 62 | # fwd = target - pos 63 | # fwd = fwd.normalized() 64 | fwd_len = math.sqrt(sum(x**2 for x in fwd)) 65 | fwd = [x / fwd_len for x in fwd] 66 | # right = fwd.cross(up) 67 | right = [ 68 | fwd[2] * up[1] - fwd[1] * up[2], 69 | fwd[0] * up[2] - fwd[2] * up[0], 70 | fwd[1] * up[0] - fwd[0] * up[1], 71 | ] 72 | # right = right.normalized() 73 | right_len = math.sqrt(sum(x**2 for x in right)) 74 | right = [x / right_len for x in right] 75 | # up = right.cross(fwd) 76 | up = [ 77 | right[2] * fwd[1] - right[1] * fwd[2], 78 | right[0] * fwd[2] - right[2] * fwd[0], 79 | right[1] * fwd[0] - right[0] * fwd[1], 80 | ] 81 | 82 | # trans = ti.Matrix.cols([right, up, fwd]) 83 | trans = [right, up, fwd] 84 | trans = [[trans[i][j] for i in range(3)] for j in range(3)] 85 | self.matrix[None] = trans 86 | 87 | def from_mouse(self, mpos): 88 | if isinstance(mpos, ti.GUI): 89 | if mpos.is_pressed(ti.GUI.LMB): 90 | mpos = mpos.get_cursor_pos() 91 | else: 92 | mpos = (0, 0) 93 | a, t = mpos 94 | if a != 0 or t != 0: 95 | a, t = a * math.tau - math.pi, t * math.pi - math.pi / 2 96 | c = math.cos(t) 97 | self.loadOrtho(fwd=[c * math.sin(a), math.sin(t), c * math.cos(a)]) 98 | 99 | 100 | @ti.data_oriented 101 | class Camera(AutoInit): 102 | ORTHO = 'Orthogonal' 103 | TAN_FOV = 'Tangent Perspective' # rectilinear perspective 104 | COS_FOV = 'Cosine Perspective' # curvilinear perspective, see en.wikipedia.org/wiki/Curvilinear_perspective 105 | 106 | def __init__(self, res=None, fx=None, fy=None, cx=None, cy=None, 107 | pos=[0, 0, -2], target=[0, 0, 0], up=[0, -1, 0], fov=30): 108 | self.res = res or (512, 512) 109 | self.img = ti.Vector.field(3, ti.f32, self.res) 110 | self.zbuf = ti.field(ti.f32, self.res) 111 | self.mask = ti.Vector.field(3, ti.f32, self.res) 112 | self.normal_map = ti.Vector.field(3, ti.f32, self.res) 113 | self.trans = ti.Matrix.field(3, 3, ti.f32, ()) 114 | self.pos = ti.Vector.field(3, ti.f32, ()) 115 | self.target = ti.Vector.field(3, ti.f32, ()) 116 | self.intrinsic = ti.Matrix.field(3, 3, ti.f32, ()) 117 | self.type = self.TAN_FOV 118 | self.fov = math.radians(fov) 119 | 120 | self.cx = cx or self.res[0] // 2 121 | self.cy = cy or self.res[1] // 2 122 | self.fx = fx or self.cx / math.tan(self.fov) 123 | self.fy = fy or self.cy / math.tan(self.fov) 124 | # python scope camera transformations 125 | self.pos_py = pos 126 | self.target_py = target 127 | self.trans_py = None 128 | self.up_py = up 129 | self.set(init=True) 130 | # mouse position for camera control 131 | self.mpos = (0, 0) 132 | 133 | def set_intrinsic(self, fx=None, fy=None, cx=None, cy=None): 134 | # see http://ais.informatik.uni-freiburg.de/teaching/ws09/robotics2/pdfs/rob2-08-camera-calibration.pdf 135 | self.fx = fx or self.fx 136 | self.fy = fy or self.fy 137 | self.cx = self.cx if cx is None else cx 138 | self.cy = self.cy if cy is None else cy 139 | 140 | ''' 141 | NOTE: taichi_three uses a LEFT HANDED coordinate system. 
142 | that is, the +Z axis points FROM the camera TOWARDS the scene, 143 | with X, Y being device coordinates 144 | ''' 145 | def set(self, pos=None, target=None, up=None, init=False): 146 | pos = self.pos_py if pos is None else pos 147 | target = self.target_py if target is None else target 148 | up = self.up_py if up is None else up 149 | # fwd = target - pos 150 | fwd = [target[i] - pos[i] for i in range(3)] 151 | # fwd = fwd.normalized() 152 | fwd_len = math.sqrt(sum(x**2 for x in fwd)) 153 | fwd = [x / fwd_len for x in fwd] 154 | # right = fwd.cross(up) 155 | # print(fwd, up) 156 | # left = [ 157 | # fwd[2] * up[1] - fwd[1] * up[2], 158 | # fwd[0] * up[2] - fwd[2] * up[0], 159 | # fwd[1] * up[0] - fwd[0] * up[1], 160 | # ] 161 | right = [ 162 | fwd[2] * up[1] - fwd[1] * up[2], 163 | fwd[0] * up[2] - fwd[2] * up[0], 164 | fwd[1] * up[0] - fwd[0] * up[1], 165 | ] 166 | right_len = math.sqrt(sum(x**2 for x in right)) 167 | right = [x / right_len for x in right] 168 | # up = right.cross(fwd) 169 | up = [ 170 | right[2] * fwd[1] - right[1] * fwd[2], 171 | right[0] * fwd[2] - right[2] * fwd[0], 172 | right[1] * fwd[0] - right[0] * fwd[1], 173 | ] 174 | # up = [ 175 | # -left[2] * fwd[1] + left[1] * fwd[2], 176 | # -left[0] * fwd[2] + left[2] * fwd[0], 177 | # -left[1] * fwd[0] + left[0] * fwd[1], 178 | # ] 179 | # trans = ti.Matrix.cols([right, up, fwd]) 180 | trans = [right, up, fwd] 181 | # print(right, up, fwd) 182 | self.trans_py = [[trans[i][j] for i in range(3)] for j in range(3)] 183 | # self.trans_py = [[trans[i][j] for j in range(3)] for i in range(3)] 184 | self.pos_py = pos 185 | self.target_py = target 186 | if not init: 187 | self.pos[None] = self.pos_py 188 | self.trans[None] = self.trans_py 189 | self.target[None] = self.target_py 190 | 191 | def set_extrinsic(self, trans, pose): 192 | trans = [[trans[i][j] for j in range(3)] for i in range(3)] 193 | self.trans_py = trans 194 | self.pos_py = pose 195 | 196 | def _init(self): 197 | self.pos[None] = self.pos_py 198 | self.trans[None] = self.trans_py 199 | self.target[None] = self.target_py 200 | self.intrinsic[None][0, 0] = self.fx 201 | self.intrinsic[None][0, 2] = self.cx 202 | self.intrinsic[None][1, 1] = self.fy 203 | self.intrinsic[None][1, 2] = self.cy 204 | self.intrinsic[None][2, 2] = 1.0 205 | 206 | @ti.func 207 | def clear_buffer(self): 208 | for I in ti.grouped(self.img): 209 | self.img[I] = ti.Vector([0.0, 0.0, 0.0], ti.f32) 210 | self.zbuf[I] = 0.0 211 | self.mask[I] = ti.Vector([0.0, 0.0, 0.0], ti.f32) 212 | self.normal_map[I] = ti.Vector([0.5, 0.5, 0.5], ti.f32) 213 | 214 | def from_mouse(self, gui): 215 | is_alter_move = gui.is_pressed(ti.GUI.CTRL) 216 | if gui.is_pressed(ti.GUI.LMB): 217 | mpos = gui.get_cursor_pos() 218 | if self.mpos != (0, 0): 219 | self.orbit((mpos[0] - self.mpos[0], mpos[1] - self.mpos[1]), 220 | pov=is_alter_move) 221 | self.mpos = mpos 222 | elif gui.is_pressed(ti.GUI.RMB): 223 | mpos = gui.get_cursor_pos() 224 | if self.mpos != (0, 0): 225 | self.zoom_by_mouse(mpos, (mpos[0] - self.mpos[0], mpos[1] - self.mpos[1]), 226 | dolly=is_alter_move) 227 | self.mpos = mpos 228 | elif gui.is_pressed(ti.GUI.MMB): 229 | mpos = gui.get_cursor_pos() 230 | if self.mpos != (0, 0): 231 | self.pan((mpos[0] - self.mpos[0], mpos[1] - self.mpos[1])) 232 | self.mpos = mpos 233 | else: 234 | if gui.event and gui.event.key == ti.GUI.WHEEL: 235 | # one mouse wheel unit is (0, 120) 236 | self.zoom(-gui.event.delta[1] / 1200, 237 | dolly=is_alter_move) 238 | gui.event = None 239 | mpos = (0, 0) 240 | self.mpos = 
mpos 241 | 242 | 243 | def orbit(self, delta, sensitivity=5, pov=False): 244 | ds, dt = delta 245 | if ds != 0 or dt != 0: 246 | dis = math.sqrt(sum((self.target_py[i] - self.pos_py[i]) ** 2 for i in range(3))) 247 | fov = self.fov 248 | ds, dt = ds * fov * sensitivity, dt * fov * sensitivity 249 | newdir = ti.Vector([ds, dt, 1], ti.f32).normalized() 250 | newdir = [sum(self.trans[None][i, j] * newdir[j] for j in range(3))\ 251 | for i in range(3)] 252 | if pov: 253 | newtarget = [self.pos_py[i] + dis * newdir[i] for i in range(3)] 254 | self.set(target=newtarget) 255 | else: 256 | newpos = [self.target_py[i] - dis * newdir[i] for i in range(3)] 257 | self.set(pos=newpos) 258 | 259 | def zoom_by_mouse(self, pos, delta, sensitivity=3, dolly=False): 260 | ds, dt = delta 261 | if ds != 0 or dt != 0: 262 | z = math.sqrt(ds ** 2 + dt ** 2) * sensitivity 263 | if (pos[0] - 0.5) * ds + (pos[1] - 0.5) * dt > 0: 264 | z *= -1 265 | self.zoom(z, dolly) 266 | 267 | def zoom(self, z, dolly=False): 268 | newpos = [(1 + z) * self.pos_py[i] - z * self.target_py[i] for i in range(3)] 269 | if dolly: 270 | newtarget = [z * self.pos_py[i] + (1 - z) * self.target_py[i] for i in range(3)] 271 | self.set(pos=newpos, target=newtarget) 272 | else: 273 | self.set(pos=newpos) 274 | 275 | def pan(self, delta, sensitivity=3): 276 | ds, dt = delta 277 | if ds != 0 or dt != 0: 278 | dis = math.sqrt(sum((self.target_py[i] - self.pos_py[i]) ** 2 for i in range(3))) 279 | fov = self.fov 280 | ds, dt = ds * fov * sensitivity, dt * fov * sensitivity 281 | newdir = ti.Vector([-ds, -dt, 1], ti.f32).normalized() 282 | newdir = [sum(self.trans[None][i, j] * newdir[j] for j in range(3))\ 283 | for i in range(3)] 284 | newtarget = [self.pos_py[i] + dis * newdir[i] for i in range(3)] 285 | newpos = [self.pos_py[i] + newtarget[i] - self.target_py[i] for i in range(3)] 286 | self.set(pos=newpos, target=newtarget) 287 | 288 | @ti.func 289 | def trans_pos(self, pos): 290 | return self.trans[None] @ pos + self.pos[None] 291 | 292 | @ti.func 293 | def trans_dir(self, pos): 294 | return self.trans[None] @ pos 295 | 296 | @ti.func 297 | def untrans_pos(self, pos): 298 | return self.trans[None].inverse() @ (pos - self.pos[None]) 299 | 300 | @ti.func 301 | def untrans_dir(self, pos): 302 | return self.trans[None].inverse() @ pos 303 | 304 | @ti.func 305 | def uncook(self, pos): 306 | if ti.static(self.type == self.ORTHO): 307 | pos[0] *= self.intrinsic[None][0, 0] 308 | pos[1] *= self.intrinsic[None][1, 1] 309 | pos[0] += self.intrinsic[None][0, 2] 310 | pos[1] += self.intrinsic[None][1, 2] 311 | elif ti.static(self.type == self.TAN_FOV): 312 | pos = self.intrinsic[None] @ pos 313 | pos[0] /= abs(pos[2]) 314 | pos[1] /= abs(pos[2]) 315 | else: 316 | raise NotImplementedError("Curvilinear projection matrix not implemented!") 317 | return ti.Vector([pos[0], pos[1]], ti.f32) 318 | 319 | def export_intrinsic(self): 320 | import numpy as np 321 | intrinsic = np.zeros((3, 3)) 322 | intrinsic[0, 0] = self.fx 323 | intrinsic[1, 1] = self.fy 324 | intrinsic[0, 2] = self.cx 325 | intrinsic[1, 2] = self.cy 326 | intrinsic[2, 2] = 1 327 | return intrinsic 328 | 329 | def export_extrinsic(self): 330 | import numpy as np 331 | trans = np.array(self.trans_py) 332 | pos = np.array(self.pos_py) 333 | extrinsic = np.zeros((3, 4)) 334 | 335 | trans = np.transpose(trans) 336 | for i in range(3): 337 | for j in range(3): 338 | extrinsic[i][j] = trans[i, j] 339 | pos = -trans @ pos 340 | for i in range(3): 341 | extrinsic[i][3] = pos[i] 342 | return extrinsic 
343 | 344 | @ti.func 345 | def generate(self, coor): 346 | fov = ti.static(self.fov) 347 | tan_fov = ti.static(math.tan(fov)) 348 | 349 | orig = ti.Vector([0.0, 0.0, 0.0], ti.f32) 350 | dir = ti.Vector([0.0, 0.0, 1.0], ti.f32) 351 | 352 | if ti.static(self.type == self.ORTHO): 353 | orig = ti.Vector([coor[0], coor[1], 0.0], ti.f32) 354 | elif ti.static(self.type == self.TAN_FOV): 355 | uv = coor * fov 356 | dir = ti.Vector([uv[0], uv[1], 1], ti.f32).normalized() 357 | elif ti.static(self.type == self.COS_FOV): 358 | uv = coor * fov 359 | dir = ti.Vector([ti.sin(uv), ti.cos(uv.norm())], ti.f32).normalized() 360 | 361 | orig = self.trans_pos(orig) 362 | dir = self.trans_dir(dir) 363 | 364 | return orig, dir 365 | -------------------------------------------------------------------------------- /prepare_data/taichi_three/version.py: -------------------------------------------------------------------------------- 1 | version = (0, 0, 4) 2 | taiglsl_version = (0, 0, 9) 3 | taichi_version = (0, 6, 28) 4 | 5 | print(f'[Tai3D] version {".".join(map(str, version))}') 6 | print(f'[Tai3D] Inputs are welcomed at https://github.com/taichi-dev/taichi_three') 7 | print(f'[Tai3D] Camera control hints: LMB to orbit, MMB to move, RMB to scale') 8 | -------------------------------------------------------------------------------- /test_real_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import argparse 4 | import logging 5 | import numpy as np 6 | import cv2 7 | import os 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | 11 | from lib.human_loader import StereoHumanDataset 12 | from lib.network import RtStereoHumanModel 13 | from config.stereo_human_config import ConfigStereoHuman as config 14 | from lib.utils import get_novel_calib 15 | from lib.GaussianRender import pts2render 16 | 17 | import torch 18 | import warnings 19 | warnings.filterwarnings("ignore", category=UserWarning) 20 | 21 | 22 | class StereoHumanRender: 23 | def __init__(self, cfg_file, phase): 24 | self.cfg = cfg_file 25 | self.bs = self.cfg.batch_size 26 | 27 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True) 28 | self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase) 29 | self.model.cuda() 30 | if self.cfg.restore_ckpt: 31 | self.load_ckpt(self.cfg.restore_ckpt) 32 | self.model.eval() 33 | 34 | def infer_seqence(self, view_select, ratio=0.5): 35 | total_frames = len(os.listdir(os.path.join(self.cfg.dataset.test_data_root, 'img'))) 36 | for idx in tqdm(range(total_frames)): 37 | item = self.dataset.get_test_item(idx, source_id=view_select) 38 | data = self.fetch_data(item) 39 | data = get_novel_calib(data, self.cfg.dataset, ratio=ratio, intr_key='intr_ori', extr_key='extr_ori') 40 | with torch.no_grad(): 41 | data, _, _ = self.model(data, is_train=False) 42 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color) 43 | 44 | render_novel = self.tensor2np(data['novel_view']['img_pred']) 45 | cv2.imwrite(self.cfg.test_out_path + '/%s_novel.jpg' % (data['name']), render_novel) 46 | 47 | def tensor2np(self, img_tensor): 48 | img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy() 49 | img_np = img_np * 255 50 | img_np = img_np[:, :, ::-1].astype(np.uint8) 51 | return img_np 52 | 53 | def fetch_data(self, data): 54 | for view in ['lmain', 'rmain']: 55 | for item in data[view].keys(): 56 | data[view][item] = data[view][item].cuda().unsqueeze(0) 57 | return data 58 | 59 | def load_ckpt(self, 
load_path): 60 | assert os.path.exists(load_path) 61 | logging.info(f"Loading checkpoint from {load_path} ...") 62 | ckpt = torch.load(load_path, map_location='cuda') 63 | self.model.load_state_dict(ckpt['network'], strict=True) 64 | logging.info(f"Parameter loading done") 65 | 66 | 67 | if __name__ == '__main__': 68 | logging.basicConfig(level=logging.INFO, 69 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s') 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--test_data_root', type=str, required=True) 72 | parser.add_argument('--ckpt_path', type=str, required=True) 73 | parser.add_argument('--src_view', type=int, nargs='+', required=True) 74 | parser.add_argument('--ratio', type=float, default=0.5) 75 | arg = parser.parse_args() 76 | 77 | cfg = config() 78 | cfg_for_train = os.path.join('./config', 'stage2.yaml') 79 | cfg.load(cfg_for_train) 80 | cfg = cfg.get_cfg() 81 | 82 | cfg.defrost() 83 | cfg.batch_size = 1 84 | cfg.dataset.test_data_root = arg.test_data_root 85 | cfg.dataset.use_processed_data = False 86 | cfg.restore_ckpt = arg.ckpt_path 87 | cfg.test_out_path = './test_out' 88 | Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True) 89 | cfg.freeze() 90 | 91 | render = StereoHumanRender(cfg, phase='test') 92 | render.infer_seqence(view_select=arg.src_view, ratio=arg.ratio) 93 | -------------------------------------------------------------------------------- /test_view_interp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import argparse 4 | import logging 5 | import numpy as np 6 | import cv2 7 | import os 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | 11 | from lib.human_loader import StereoHumanDataset 12 | from lib.network import RtStereoHumanModel 13 | from config.stereo_human_config import ConfigStereoHuman as config 14 | from lib.utils import get_novel_calib 15 | from lib.GaussianRender import pts2render 16 | 17 | import torch 18 | import warnings 19 | warnings.filterwarnings("ignore", category=UserWarning) 20 | 21 | 22 | class StereoHumanRender: 23 | def __init__(self, cfg_file, phase): 24 | self.cfg = cfg_file 25 | self.bs = self.cfg.batch_size 26 | 27 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True) 28 | self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase) 29 | self.model.cuda() 30 | if self.cfg.restore_ckpt: 31 | self.load_ckpt(self.cfg.restore_ckpt) 32 | self.model.eval() 33 | 34 | def infer_static(self, view_select, novel_view_nums): 35 | total_samples = len(os.listdir(os.path.join(self.cfg.dataset.test_data_root, 'img'))) 36 | for idx in tqdm(range(total_samples)): 37 | item = self.dataset.get_test_item(idx, source_id=view_select) 38 | data = self.fetch_data(item) 39 | for i in range(novel_view_nums): 40 | ratio_tmp = (i+0.5)*(1/novel_view_nums) 41 | data_i = get_novel_calib(data, self.cfg.dataset, ratio=ratio_tmp, intr_key='intr_ori', extr_key='extr_ori') 42 | with torch.no_grad(): 43 | data_i, _, _ = self.model(data_i, is_train=False) 44 | data_i = pts2render(data_i, bg_color=self.cfg.dataset.bg_color) 45 | 46 | render_novel = self.tensor2np(data['novel_view']['img_pred']) 47 | cv2.imwrite(self.cfg.test_out_path + '/%s_novel%s.jpg' % (data_i['name'], str(i).zfill(2)), render_novel) 48 | 49 | def tensor2np(self, img_tensor): 50 | img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy() 51 | img_np = img_np * 255 52 | img_np = img_np[:, :, ::-1].astype(np.uint8) 53 | return 
img_np 54 | 55 | def fetch_data(self, data): 56 | for view in ['lmain', 'rmain']: 57 | for item in data[view].keys(): 58 | data[view][item] = data[view][item].cuda().unsqueeze(0) 59 | return data 60 | 61 | def load_ckpt(self, load_path): 62 | assert os.path.exists(load_path) 63 | logging.info(f"Loading checkpoint from {load_path} ...") 64 | ckpt = torch.load(load_path, map_location='cuda') 65 | self.model.load_state_dict(ckpt['network'], strict=True) 66 | logging.info(f"Parameter loading done") 67 | 68 | 69 | if __name__ == '__main__': 70 | logging.basicConfig(level=logging.INFO, 71 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s') 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--test_data_root', type=str, required=True) 74 | parser.add_argument('--ckpt_path', type=str, required=True) 75 | parser.add_argument('--novel_view_nums', type=int, default=5) 76 | arg = parser.parse_args() 77 | 78 | cfg = config() 79 | cfg_for_train = os.path.join('./config', 'stage2.yaml') 80 | cfg.load(cfg_for_train) 81 | cfg = cfg.get_cfg() 82 | 83 | cfg.defrost() 84 | cfg.batch_size = 1 85 | cfg.dataset.test_data_root = arg.test_data_root 86 | cfg.dataset.use_processed_data = False 87 | cfg.restore_ckpt = arg.ckpt_path 88 | cfg.test_out_path = './interp_out' 89 | Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True) 90 | cfg.freeze() 91 | 92 | render = StereoHumanRender(cfg, phase='test') 93 | render.infer_static(view_select=[0, 1], novel_view_nums=arg.novel_view_nums) 94 | -------------------------------------------------------------------------------- /train_stage1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import logging 3 | 4 | import numpy as np 5 | import cv2 6 | import os 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | from datetime import datetime 10 | 11 | from lib.human_loader import StereoHumanDataset 12 | from lib.network import RtStereoHumanModel 13 | from config.stereo_human_config import ConfigStereoHuman as config 14 | from lib.train_recoder import Logger, file_backup 15 | from lib.utils import get_novel_calib_for_show as get_novel_calib 16 | from lib.TaichiRender import TaichiRenderBatch 17 | 18 | import torch 19 | import torch.optim as optim 20 | from torch.cuda.amp import GradScaler 21 | from torch.utils.data import DataLoader 22 | import warnings 23 | warnings.filterwarnings("ignore", category=UserWarning) 24 | 25 | 26 | class Trainer: 27 | def __init__(self, cfg_file): 28 | self.cfg = cfg_file 29 | 30 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=False) 31 | self.train_set = StereoHumanDataset(self.cfg.dataset, phase='train') 32 | self.train_loader = DataLoader(self.train_set, batch_size=self.cfg.batch_size, shuffle=True, 33 | num_workers=self.cfg.batch_size*2, pin_memory=True) 34 | self.train_iterator = iter(self.train_loader) 35 | self.val_set = StereoHumanDataset(self.cfg.dataset, phase='val') 36 | self.val_loader = DataLoader(self.val_set, batch_size=2, shuffle=False, num_workers=4, pin_memory=True) 37 | self.len_val = int(len(self.val_loader) / self.val_set.val_boost) # real length of val set 38 | self.val_iterator = iter(self.val_loader) 39 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.wdecay, eps=1e-8) 40 | self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, self.cfg.lr, 100100, pct_start=0.01, 41 | cycle_momentum=False, anneal_strategy='linear') 42 | 43 
| self.logger = Logger(self.scheduler, cfg.record) 44 | self.total_steps = 0 45 | 46 | self.model.cuda() 47 | if self.cfg.restore_ckpt: 48 | self.load_ckpt(self.cfg.restore_ckpt) 49 | self.model.train() 50 | self.model.raft_stereo.freeze_bn() 51 | self.scaler = GradScaler(enabled=self.cfg.raft.mixed_precision) 52 | self.render = TaichiRenderBatch(bs=1, res=self.cfg.dataset.src_res) 53 | 54 | def train(self): 55 | for _ in tqdm(range(self.total_steps, self.cfg.num_steps)): 56 | self.optimizer.zero_grad() 57 | data = self.fetch_data(phase='train') 58 | 59 | # Raft Stereo 60 | _, flow_loss, metrics = self.model(data) 61 | loss = flow_loss 62 | 63 | if self.total_steps and self.total_steps % self.cfg.record.loss_freq == 0: 64 | self.logger.writer.add_scalar(f'lr', self.optimizer.param_groups[0]['lr'], self.total_steps) 65 | self.save_ckpt(save_path=Path('%s/%s_latest.pth' % (cfg.record.ckpt_path, cfg.name)), show_log=False) 66 | self.logger.push(metrics) 67 | 68 | self.scaler.scale(loss).backward() 69 | self.scaler.unscale_(self.optimizer) 70 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 71 | 72 | self.scaler.step(self.optimizer) 73 | self.scheduler.step() 74 | self.scaler.update() 75 | 76 | if self.total_steps and self.total_steps % self.cfg.record.eval_freq == 0: 77 | self.model.eval() 78 | self.run_eval() 79 | self.model.train() 80 | self.model.raft_stereo.freeze_bn() 81 | 82 | self.total_steps += 1 83 | 84 | print("FINISHED TRAINING") 85 | self.logger.close() 86 | self.save_ckpt(save_path=Path('%s/%s_final.pth' % (cfg.record.ckpt_path, cfg.name))) 87 | 88 | def run_eval(self): 89 | logging.info(f"Doing validation ...") 90 | torch.cuda.empty_cache() 91 | epe_list, one_pix_list = [], [] 92 | show_idx = np.random.choice(list(range(self.len_val)), 1) 93 | for idx in range(self.len_val): 94 | data = self.fetch_data(phase='val') 95 | with torch.no_grad(): 96 | data, _, _ = self.model(data, is_train=False) 97 | 98 | if idx == show_idx: 99 | data = get_novel_calib(data, ratio=0.5) 100 | data = self.render.flow2render(data) 101 | tmp_novel = data['novel_view']['img_pred'][0].detach() 102 | tmp_novel = (tmp_novel / 2.0 + 0.5) * 255 103 | tmp_novel = tmp_novel.permute(1, 2, 0).cpu().numpy() 104 | tmp_img_name = '%s/%s.jpg' % (cfg.record.show_path, self.total_steps) 105 | cv2.imwrite(tmp_img_name, tmp_novel[:, :, ::-1].astype(np.uint8)) 106 | 107 | for view in ['lmain', 'rmain']: 108 | valid = (data[view]['valid'] >= 0.5) 109 | epe = torch.sum((data[view]['flow'] - data[view]['flow_pred']) ** 2, dim=1).sqrt() 110 | epe = epe.view(-1)[valid.view(-1)] 111 | one_pix = (epe < 1) 112 | epe_list.append(epe.mean().item()) 113 | one_pix_list.append(one_pix.float().mean().item()) 114 | 115 | val_epe = np.round(np.mean(np.array(epe_list)), 4) 116 | val_one_pix = np.round(np.mean(np.array(one_pix_list)), 4) 117 | logging.info(f"Validation Metrics ({self.total_steps}): epe {val_epe}, 1pix {val_one_pix}") 118 | self.logger.write_dict({'val_epe': val_epe, 'val_1pix': val_one_pix}, write_step=self.total_steps) 119 | torch.cuda.empty_cache() 120 | 121 | def fetch_data(self, phase): 122 | if phase == 'train': 123 | try: 124 | data = next(self.train_iterator) 125 | except: 126 | self.train_iterator = iter(self.train_loader) 127 | data = next(self.train_iterator) 128 | elif phase == 'val': 129 | try: 130 | data = next(self.val_iterator) 131 | except: 132 | self.val_iterator = iter(self.val_loader) 133 | data = next(self.val_iterator) 134 | 135 | for view in ['lmain', 'rmain']: 136 | for item in 
data[view].keys(): 137 | data[view][item] = data[view][item].cuda() 138 | return data 139 | 140 | def load_ckpt(self, load_path, load_optimizer=True, strict=True): 141 | assert os.path.exists(load_path) 142 | logging.info(f"Loading checkpoint from {load_path} ...") 143 | ckpt = torch.load(load_path, map_location='cuda') 144 | self.model.load_state_dict(ckpt['network'], strict=strict) 145 | logging.info(f"Parameter loading done") 146 | if load_optimizer: 147 | self.total_steps = ckpt['total_steps'] + 1 148 | self.logger.total_steps = self.total_steps 149 | self.optimizer.load_state_dict(ckpt['optimizer']) 150 | self.scheduler.load_state_dict(ckpt['scheduler']) 151 | logging.info(f"Optimizer loading done") 152 | 153 | def save_ckpt(self, save_path, show_log=True): 154 | if show_log: 155 | logging.info(f"Save checkpoint to {save_path} ...") 156 | torch.save({ 157 | 'total_steps': self.total_steps, 158 | 'network': self.model.state_dict(), 159 | 'optimizer': self.optimizer.state_dict(), 160 | 'scheduler': self.scheduler.state_dict() 161 | }, save_path) 162 | 163 | 164 | if __name__ == '__main__': 165 | logging.basicConfig(level=logging.INFO, 166 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s') 167 | 168 | cfg = config() 169 | cfg.load("config/stage1.yaml") 170 | cfg = cfg.get_cfg() 171 | 172 | cfg.defrost() 173 | dt = datetime.today() 174 | cfg.exp_name = '%s_%s%s' % (cfg.name, str(dt.month).zfill(2), str(dt.day).zfill(2)) 175 | cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name 176 | cfg.record.show_path = "experiments/%s/show" % cfg.exp_name 177 | cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name 178 | cfg.record.file_path = "experiments/%s/file" % cfg.exp_name 179 | cfg.freeze() 180 | 181 | for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]: 182 | Path(path).mkdir(exist_ok=True, parents=True) 183 | 184 | file_backup(cfg.record.file_path, cfg, train_script=os.path.basename(__file__)) 185 | 186 | torch.manual_seed(1314) 187 | np.random.seed(1314) 188 | 189 | trainer = Trainer(cfg) 190 | trainer.train() 191 | -------------------------------------------------------------------------------- /train_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import logging 4 | 5 | import numpy as np 6 | import cv2 7 | import os 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from datetime import datetime 11 | 12 | from lib.human_loader import StereoHumanDataset 13 | from lib.network import RtStereoHumanModel 14 | from config.stereo_human_config import ConfigStereoHuman as config 15 | from lib.train_recoder import Logger, file_backup 16 | from lib.GaussianRender import pts2render 17 | from lib.loss import l1_loss, ssim, psnr 18 | 19 | import torch 20 | import torch.optim as optim 21 | from torch.cuda.amp import GradScaler 22 | from torch.utils.data import DataLoader 23 | import warnings 24 | warnings.filterwarnings("ignore", category=UserWarning) 25 | 26 | 27 | class Trainer: 28 | def __init__(self, cfg_file): 29 | self.cfg = cfg_file 30 | 31 | self.model = RtStereoHumanModel(self.cfg, with_gs_render=True) 32 | self.train_set = StereoHumanDataset(self.cfg.dataset, phase='train') 33 | self.train_loader = DataLoader(self.train_set, batch_size=self.cfg.batch_size, shuffle=True, 34 | num_workers=self.cfg.batch_size*2, pin_memory=True) 35 | self.train_iterator = iter(self.train_loader) 36 | self.val_set = 
StereoHumanDataset(self.cfg.dataset, phase='val') 37 | self.val_loader = DataLoader(self.val_set, batch_size=2, shuffle=False, num_workers=4, pin_memory=True) 38 | self.len_val = int(len(self.val_loader) / self.val_set.val_boost) # real length of val set 39 | self.val_iterator = iter(self.val_loader) 40 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.wdecay, eps=1e-8) 41 | self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, self.cfg.lr, self.cfg.num_steps + 100, 42 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') 43 | 44 | self.logger = Logger(self.scheduler, cfg.record) 45 | self.total_steps = 0 46 | 47 | self.model.cuda() 48 | if self.cfg.restore_ckpt: 49 | self.load_ckpt(self.cfg.restore_ckpt) 50 | elif self.cfg.stage1_ckpt: 51 | logging.info(f"Using checkpoint from stage1") 52 | self.load_ckpt(self.cfg.stage1_ckpt, load_optimizer=False, strict=False) 53 | self.model.train() 54 | self.model.raft_stereo.freeze_bn() 55 | self.scaler = GradScaler(enabled=self.cfg.raft.mixed_precision) 56 | 57 | def train(self): 58 | for _ in tqdm(range(self.total_steps, self.cfg.num_steps)): 59 | self.optimizer.zero_grad() 60 | data = self.fetch_data(phase='train') 61 | 62 | # Raft Stereo + GS Regresser 63 | data, flow_loss, metrics = self.model(data, is_train=True) 64 | # Gaussian Render 65 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color) 66 | 67 | render_novel = data['novel_view']['img_pred'] 68 | gt_novel = data['novel_view']['img'].cuda() 69 | 70 | Ll1 = l1_loss(render_novel, gt_novel) 71 | Lssim = 1.0 - ssim(render_novel, gt_novel) 72 | loss = 1.0 * flow_loss + 0.8 * Ll1 + 0.2 * Lssim 73 | 74 | if self.total_steps and self.total_steps % self.cfg.record.loss_freq == 0: 75 | self.logger.writer.add_scalar(f'lr', self.optimizer.param_groups[0]['lr'], self.total_steps) 76 | self.save_ckpt(save_path=Path('%s/%s_latest.pth' % (cfg.record.ckpt_path, cfg.name)), show_log=False) 77 | metrics.update({ 78 | 'l1': Ll1.item(), 79 | 'ssim': Lssim.item(), 80 | }) 81 | self.logger.push(metrics) 82 | 83 | self.scaler.scale(loss).backward() 84 | self.scaler.unscale_(self.optimizer) 85 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 86 | 87 | self.scaler.step(self.optimizer) 88 | self.scheduler.step() 89 | self.scaler.update() 90 | 91 | if self.total_steps and self.total_steps % self.cfg.record.eval_freq == 0: 92 | self.model.eval() 93 | self.run_eval() 94 | self.model.train() 95 | self.model.raft_stereo.freeze_bn() 96 | 97 | self.total_steps += 1 98 | 99 | print("FINISHED TRAINING") 100 | self.logger.close() 101 | self.save_ckpt(save_path=Path('%s/%s_final.pth' % (cfg.record.ckpt_path, cfg.name))) 102 | 103 | def run_eval(self): 104 | logging.info(f"Doing validation ...") 105 | torch.cuda.empty_cache() 106 | epe_list, one_pix_list, psnr_list = [], [], [] 107 | show_idx = np.random.choice(list(range(self.len_val)), 1) 108 | for idx in range(self.len_val): 109 | data = self.fetch_data(phase='val') 110 | with torch.no_grad(): 111 | data, _, _ = self.model(data, is_train=False) 112 | data = pts2render(data, bg_color=self.cfg.dataset.bg_color) 113 | 114 | render_novel = data['novel_view']['img_pred'] 115 | gt_novel = data['novel_view']['img'].cuda() 116 | psnr_value = psnr(render_novel, gt_novel).mean().double() 117 | psnr_list.append(psnr_value.item()) 118 | 119 | if idx == show_idx: 120 | tmp_novel = data['novel_view']['img_pred'][0].detach() 121 | tmp_novel *= 255 122 | tmp_novel = tmp_novel.permute(1, 2, 0).cpu().numpy() 
123 |                     tmp_img_name = '%s/%s.jpg' % (cfg.record.show_path, self.total_steps)
124 |                     cv2.imwrite(tmp_img_name, tmp_novel[:, :, ::-1].astype(np.uint8))
125 | 
126 |                 for view in ['lmain', 'rmain']:
127 |                     valid = (data[view]['valid'] >= 0.5)
128 |                     epe = torch.sum((data[view]['flow'] - data[view]['flow_pred']) ** 2, dim=1).sqrt()
129 |                     epe = epe.view(-1)[valid.view(-1)]
130 |                     one_pix = (epe < 1)
131 |                     epe_list.append(epe.mean().item())
132 |                     one_pix_list.append(one_pix.float().mean().item())
133 | 
134 |         val_epe = np.round(np.mean(np.array(epe_list)), 4)
135 |         val_one_pix = np.round(np.mean(np.array(one_pix_list)), 4)
136 |         val_psnr = np.round(np.mean(np.array(psnr_list)), 4)
137 |         logging.info(f"Validation Metrics ({self.total_steps}): epe {val_epe}, 1pix {val_one_pix}, psnr {val_psnr}")
138 |         self.logger.write_dict({'val_epe': val_epe, 'val_1pix': val_one_pix, 'val_psnr': val_psnr}, write_step=self.total_steps)
139 |         torch.cuda.empty_cache()
140 | 
141 |     def fetch_data(self, phase):
142 |         if phase == 'train':
143 |             try:
144 |                 data = next(self.train_iterator)
145 |             except:
146 |                 self.train_iterator = iter(self.train_loader)
147 |                 data = next(self.train_iterator)
148 |         elif phase == 'val':
149 |             try:
150 |                 data = next(self.val_iterator)
151 |             except:
152 |                 self.val_iterator = iter(self.val_loader)
153 |                 data = next(self.val_iterator)
154 | 
155 |         for view in ['lmain', 'rmain']:
156 |             for item in data[view].keys():
157 |                 data[view][item] = data[view][item].cuda()
158 |         return data
159 | 
160 |     def load_ckpt(self, load_path, load_optimizer=True, strict=True):
161 |         assert os.path.exists(load_path)
162 |         logging.info(f"Loading checkpoint from {load_path} ...")
163 |         ckpt = torch.load(load_path, map_location='cuda')
164 |         self.model.load_state_dict(ckpt['network'], strict=strict)
165 |         logging.info(f"Parameter loading done")
166 |         if load_optimizer:
167 |             self.total_steps = ckpt['total_steps'] + 1
168 |             self.logger.total_steps = self.total_steps
169 |             self.optimizer.load_state_dict(ckpt['optimizer'])
170 |             self.scheduler.load_state_dict(ckpt['scheduler'])
171 |             logging.info(f"Optimizer loading done")
172 | 
173 |     def save_ckpt(self, save_path, show_log=True):
174 |         if show_log:
175 |             logging.info(f"Save checkpoint to {save_path} ...")
176 |         torch.save({
177 |             'total_steps': self.total_steps,
178 |             'network': self.model.state_dict(),
179 |             'optimizer': self.optimizer.state_dict(),
180 |             'scheduler': self.scheduler.state_dict()
181 |         }, save_path)
182 | 
183 | 
184 | if __name__ == '__main__':
185 |     logging.basicConfig(level=logging.INFO,
186 |                         format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
187 | 
188 |     cfg = config()
189 |     cfg.load("config/stage2.yaml")
190 |     cfg = cfg.get_cfg()
191 | 
192 |     cfg.defrost()
193 |     dt = datetime.today()
194 |     cfg.exp_name = '%s_%s%s' % (cfg.name, str(dt.month).zfill(2), str(dt.day).zfill(2))
195 |     cfg.record.ckpt_path = "experiments/%s/ckpt" % cfg.exp_name
196 |     cfg.record.show_path = "experiments/%s/show" % cfg.exp_name
197 |     cfg.record.logs_path = "experiments/%s/logs" % cfg.exp_name
198 |     cfg.record.file_path = "experiments/%s/file" % cfg.exp_name
199 |     cfg.freeze()
200 | 
201 |     for path in [cfg.record.ckpt_path, cfg.record.show_path, cfg.record.logs_path, cfg.record.file_path]:
202 |         Path(path).mkdir(exist_ok=True, parents=True)
203 | 
204 |     file_backup(cfg.record.file_path, cfg, train_script=os.path.basename(__file__))
205 | 
206 |     torch.manual_seed(1314)
207 |     np.random.seed(1314)
208 | 
209 |     trainer = Trainer(cfg)
210 |     trainer.train()
211 | 
--------------------------------------------------------------------------------
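Note: the checkpoints written by `save_ckpt` in `train_stage2.py` above are plain `torch.save` dictionaries with the keys `total_steps`, `network`, `optimizer` and `scheduler`, stored under `experiments/<exp_name>/ckpt/`. The snippet below is a minimal sketch, not part of the repository, for inspecting such a file; the path is a placeholder that depends on `cfg.exp_name` and `cfg.name` in your run.

```python
# Minimal sketch (placeholder path) for inspecting a stage-2 checkpoint
# written by Trainer.save_ckpt(). The file is a plain torch.save dictionary.
from pathlib import Path
import torch

ckpt_path = Path('experiments/<exp_name>/ckpt/<name>_latest.pth')  # placeholder, set to your run
ckpt = torch.load(ckpt_path, map_location='cpu')

print('trained steps:', ckpt['total_steps'])        # saved Trainer.total_steps
print('network params:', len(ckpt['network']))      # model state_dict
print('has optimizer state:', 'optimizer' in ckpt)  # AdamW state_dict
print('has scheduler state:', 'scheduler' in ckpt)  # OneCycleLR state_dict

# To resume training, point cfg.restore_ckpt at this file; Trainer.load_ckpt()
# then restores the optimizer/scheduler and continues from total_steps + 1.
```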