├── MPI.py ├── MPV.py ├── README.md ├── config_parser.py ├── configs ├── debug_mpmesh.txt ├── debug_video.txt ├── mpi_base.txt ├── mpis │ ├── 1017palm.txt │ ├── 1017yuan.txt │ ├── 1020rock.txt │ ├── 1020ustfall1.txt │ ├── 1020ustfall2.txt │ ├── 108fall1.txt │ ├── 108fall2.txt │ ├── 108fall3.txt │ ├── 108fall4.txt │ ├── 108fall5.txt │ ├── 1101grass.txt │ ├── 1101towerd.txt │ ├── 110grasstree.txt │ ├── 110pillarrm.txt │ ├── ustfallclose.txt │ └── usttap.txt ├── mpv_base.txt └── mpvs │ ├── 1017palm.txt │ ├── 1017yuan.txt │ ├── 1020rock.txt │ ├── 1020ustfall1.txt │ ├── 1020ustfall2.txt │ ├── 108fall1.txt │ ├── 108fall2.txt │ ├── 108fall3.txt │ ├── 108fall4.txt │ ├── 108fall5.txt │ ├── 1101grass.txt │ ├── 1101towerd.txt │ ├── 110grasstree.txt │ ├── 110pillar.txt │ ├── ustfallclose.txt │ └── usttap.txt ├── dataloader.py ├── evaluations ├── C3D_model.py ├── LPIPS.py ├── NNMSE.py ├── SVFID.py ├── __init__.py ├── c3d_test.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── pretrained_networks.py │ └── weights │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth └── metrics.py ├── requirements.txt ├── run_all.sh ├── scripts ├── colmaps │ ├── __init__.py │ ├── colmap_script │ │ └── colmap_io_model.py │ └── llffposes │ │ ├── __init__.py │ │ ├── colmap_read_model.py │ │ ├── colmap_wrapper.py │ │ └── pose_utils.py ├── script_evaluate_ours.py ├── script_export_mesh.py ├── script_owndata_step1_standardization.py ├── script_owndata_step2_genllffpose.py └── script_render_video.py ├── teaser.jpg ├── train_3d.py ├── train_3dvid.py ├── utils.py ├── utils_mpi.py └── utils_vid.py /README.md: -------------------------------------------------------------------------------- 1 | # 3D Video Loops from Asynchronous Input 2 | This repository is the official code for the CVPR23 paper: **3D Video Loops from Asynchronous Input**. Please visit our [project page](https://limacv.github.io/VideoLoop3D_web/) for more information, such as supplementary, demo and dataset. 3 | 4 | ![Teaser](teaser.jpg) 5 | 6 | ## 1. Introduction 7 | In this project, we construct a 3D video loop from multi-view videos that can be asynchronous. The 3D video loop is represented as MTV, a new representation, which is essentially multiple tiles with dynamic textures. This code implements the following functionality: 8 | 9 | 1. The 2-stage optimization, which is the core of the paper. 10 | 2. An off-line renderer that render using pytorch slowly. 11 | 3. Evaluation code that compute metrics for comparison. 12 | 4. Scripts for data preprocessing and mesh export. 13 | 14 | There is another WebGL based renderer implemented [here](https://github.com/limacv/VideoLoopUI), which renders the exported mesh in real time even on an old iPhone X. 15 | 16 | ## 2. Train on dataset 17 | 18 | ### 2.1 prerequisite 19 | 20 | - The optimization is quite memory consuming. It requires a GPU with memory >= 24GB, e.g. RTX3090. Make sure you have enough GPU memory! 21 | - Install dependencies in the ```requirements.txt``` 22 | ``` 23 | conda create -n vloop3d python==3.8 24 | conda activate vloop3d 25 | pip install -r requirements.txt 26 | ``` 27 | - Install Pytorch3D following the instructures [here](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 28 | 29 | ### 2.2 dataset 30 | Download dataset from the link [here](https://drive.google.com/drive/folders/1sWH2thQgW_aZGRHGtSoEZ6ZN9s0PManK?usp=sharing). Place them somewhere. For example, you've placed ```fall2720p``` in ```/fall2720p```. 
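If you want to double check that a scene is in place before training, the following minimal sketch (an assumption based on how ```dataloader.py``` reads a scene, not an official script) verifies the two things ```load_mv_videos``` needs: the LLFF pose file and the per-view videos at the training factor (the provided configs use ```factor = 2```, hence the ```videos_2``` folder):

```python
import glob
import os

# Hypothetical location: use wherever you actually placed the downloaded scene.
scene = "/path/to/fall2720p"

# poses_bounds.npy holds the LLFF-style camera poses and depth bounds.
assert os.path.isfile(os.path.join(scene, "poses_bounds.npy")), "missing poses_bounds.npy"

# One video per view is expected in videos_<factor>/ (factor = 2 in the provided configs).
videos = sorted(glob.glob(os.path.join(scene, "videos_2", "*")))
print(f"found {len(videos)} views:", [os.path.basename(v) for v in videos])
```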
31 | 32 | ### 2.3 config path 33 | In ```configs/mpi_base.txt``` and ```configs/mpv_base.txt```, change the ```prefix``` entry to the directory where you placed the data. 34 | 35 | All output files will then be stored in ```<prefix>/<expdir>/<expname>```. In the example this will be ```<prefix>/mpis/108fall2``` and ```<prefix>/mpvs/108fall2```. 36 | 37 | ### 2.4 stage 1 38 | In this stage, we generate the static Multi-plane Image (MPI) and the 3D loopable mask (typically 10-15 mins). 39 | Run the following: 40 | ``` 41 | python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/$DATA_NAME.txt 42 | ``` 43 | 44 | ### 2.5 stage 2 45 | After stage 1 finishes, run the following. Note that this will load the **epoch_0119.tar** file generated in stage 1. In stage 2, we generate the final 3D looping video using the looping loss (typically 3-6 h). 46 | ``` 47 | python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/$DATA_NAME.txt 48 | ``` 49 | 50 | After stage 2 finishes, you will get a 3D video loop saved as a *.tar file. 51 | 52 | ## 3. Offline Render 53 | To render the MPV, run the following script: 54 | ``` 55 | export PYTHONPATH=.:$PYTHONPATH 56 | python scripts/script_render_video.py --config configs/mpv_base.txt --config1 configs/mpvs/$DATA_NAME.txt 57 | ``` 58 | We offer very simple control over time and view through extra arguments: 59 | - If no extra argument is specified: render a spiral camera path similar to NeRF. 60 | - If ```--t 0:10 --v 3``` is specified: render the 3rd training pose from frame 0 to frame 10. 61 | 62 | The rendering results will be saved at ```<prefix>/<expdir>/<expname>/renderonly/*```. 63 | 64 | ## 4. Evaluation 65 | To evaluate the results, run the following script: 66 | ``` 67 | export PYTHONPATH=.:$PYTHONPATH 68 | python scripts/script_evaluate_ours.py --config configs/mpv_base.txt --config1 configs/mpvs/$DATA_NAME.txt 69 | ``` 70 | This will generate ```<prefix>/<expdir>/<expname>/eval_metrics.txt```, which contains the value of each metric. 71 | 72 | ## 5. Export mesh 73 | 74 | To export the mesh: 75 | ``` 76 | export PYTHONPATH=.:$PYTHONPATH 77 | python scripts/script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/$DATA_NAME.txt 78 | ``` 79 | This will generate mesh files under a ```meshes/``` folder in the experiment directory. 80 | 81 | ## 6. Online Renderer 82 | 83 | Please refer to [this repo](https://github.com/limacv/VideoLoopUI) for more details. 84 | 85 | ## 7. Using your own data 86 | 87 | The dataset file structure is fairly straightforward, and the camera pose file 88 | follows that of the [LLFF dataset](https://github.com/Fyusion/LLFF) used in the NeRF paper, 89 | so it should be easy to create and structure your own dataset. 90 | 91 | Still, we provide some simple scripts to help create your own dataset. 92 | 93 | ### 7.0 capture your data 94 | 95 | One limitation of our method is that 96 | we only support scenes with textural motion and repetitive patterns, 97 | so the best practice is to capture scenes such as flowing water. 98 | 99 | During capture, the best option is to use a tripod, 100 | which guarantees that the camera pose stays static. 101 | If you hand-hold your camera, keep it as steady as possible; 102 | you can later stabilize the video in software, 103 | but this may lead to artifacts in the results (addressing this is future work). 104 | 105 | Also make sure there are enough static regions so that COLMAP works. 106 | 107 | ### 7.1 structure the data 108 | 109 | We usually start by using video editing software to 110 | concatenate the multi-view videos into one long video, 111 | with some black frames between consecutive views. 112 | This also helps to standardize the fps, remove the sound, stabilize the footage, etc.
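If you would rather script this step than use a video editor, the sketch below shows one possible way to build such a concatenated video with black separator frames using ```imageio``` (which the data loader already uses); the file names, frame rate, and number of black frames are assumptions, not values required by the pipeline:

```python
import imageio
import numpy as np

clips = ["view0.mp4", "view1.mp4", "view2.mp4"]  # hypothetical per-view recordings
out_path, n_black, fps = "concatenated.mp4", 10, 30

frames, black = [], None
for path in clips:
    clip = imageio.mimread(path, memtest=False)  # list of H x W x 3 uint8 frames
    if black is None:
        black = np.zeros_like(clip[0])           # black frame at the same resolution
    frames += clip + [black] * n_black           # black frames mark the view boundary
imageio.mimwrite(out_path, frames, fps=fps)
```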
113 | 114 | Then run the following to create a structured dataset: 115 | ``` 116 | cd scripts 117 | python script_owndata_step1_standardization.py \ 118 | --input_path \ 119 | --output_prefix 120 | ``` 121 | 122 | ### 7.2 pose estimation 123 | 124 | Make sure you have installed [COLMAP](https://colmap.github.io/). 125 | Then assign the COLMAP executable path to the variable ```colmap_path``` 126 | in the file ```scripts/colmaps/llffposes/colmap_wrapper.py```. 127 | 128 | Then run: 129 | ``` 130 | python script_owndata_step2_genllffpose.py \ 131 | --scenedir 132 | ``` 133 | 134 | ### 7.3 decide the config file 135 | 136 | Create your own config file. Pay attention to the following configs: 137 | - ```near_factor``` and ```far_factor```: control the near and far planes of the MPI. ```close``` and ```far``` values are automatically computed from the point cloud reconstructed by the LLFF pose estimation script. The final near and far planes will be ```near_factor * near``` and ```far_factor * far```, so you can manually tune them via these factors. 138 | - ```sparsify_rmfirstlayer```: this config is admittedly a bit hacky. We find that sometimes in the 1st stage, view inconsistency in the input gets baked into the nearest planes. You can choose to manually filter out these planes; how many planes to remove in the tile culling process is controlled by ```sparsify_rmfirstlayer```. 139 | - ```mpv_frm_num```: the number of frames to be optimized. 140 | - ```loss_ref_idx```: when computing the looping loss, we find that setting a large patch size for every view leads to blurry results, while a small patch size leads to spatial inconsistency. Therefore we use a large patch size for only a few "reference" views, which are specified by ```loss_ref_idx```. 141 | 142 | ## Other Notes 143 | 144 | - Implementation details for the looping loss: 145 | - PyTorch's unfold consumes a lot of GPU memory. Since the looping loss is computed for each pixel location, we save memory by looping over macro blocks, which yields the same result with lower memory usage. 146 | - Instead of directly computing the loss between Q and K, we first assemble a retargeted video loop by folding the Kf(i). We find that this operation greatly reduces the training memory and training time. 147 | - We use different patch sizes for different views, as described in 7.3. 148 | - In each iteration, we randomly perturb the camera intrinsics by half a pixel (i.e. cx += rand() - 0.5, and similarly for cy; see the sketch after this list). We find this reduces the tiling artifact. See the demo [here](https://limacv.github.io/VideoLoopUI/?dataurl=assets/ustfall1_tiling) for adding this perturbation and [here](https://limacv.github.io/VideoLoopUI/?dataurl=assets/ustfall1) for without it. There is still some artifact when rendering at high resolution (the training is conducted at 640x360). 149 | - Adaptive learning rate. We find that directly optimizing this many parameters, with each iteration touching only a small subset of them, leads to very noisy optimization results. This is because the momentum term in the optimizer keeps updating some parameters even when they receive no gradient. We find that scaling the learning rate by the average frequency with which a parameter receives a gradient solves the problem.
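To make the half-pixel intrinsic perturbation above concrete, here is a minimal sketch of the idea (the function name and the use of a plain 3x3 intrinsic matrix are assumptions for illustration; in the training code the behaviour is toggled by the ```add_intrin_noise``` option in ```config_parser.py```):

```python
import numpy as np

def perturb_intrinsic(K: np.ndarray, rng=np.random) -> np.ndarray:
    """Jitter the principal point by up to half a pixel.

    K is a 3x3 pinhole intrinsic matrix; only cx (K[0, 2]) and cy (K[1, 2])
    are shifted by a uniform offset in [-0.5, 0.5), i.e. cx += rand() - 0.5.
    """
    K = K.copy()
    K[0, 2] += rng.rand() - 0.5  # cx
    K[1, 2] += rng.rand() - 0.5  # cy
    return K
```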
150 | 151 | 152 | ## Star & Citation 153 | If you feel this repo is useful, please consider 154 | starring this project, or citing the paper: 155 | ``` 156 | @misc{videoloop, 157 | doi = {10.48550/ARXIV.2303.05312}, 158 | url = {https://arxiv.org/abs/2303.05312}, 159 | author = {Ma, Li and Li, Xiaoyu and Liao, Jing and Sander, Pedro V.}, 160 | title = {3D Video Loops from Asynchronous Input}, 161 | publisher = {arXiv}, 162 | year = {2023}, 163 | } 164 | 165 | ``` 166 | -------------------------------------------------------------------------------- /config_parser.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | 4 | def config_parser(): 5 | parser = configargparse.ArgumentParser() 6 | # Two sets of config for naive hierarchical config structure 7 | parser.add_argument('--config', is_config_file=True, 8 | help='config file path for base') 9 | parser.add_argument('--config1', is_config_file=True, default='', 10 | help='config file path for each data') 11 | parser.add_argument("--expname", type=str, 12 | help='experiment name') 13 | parser.add_argument("--expname_postfix", type=str, default='', 14 | help='experiment name = expname + expname_postfix') 15 | parser.add_argument("--test_view_idx", type=str, default='', 16 | help='#,#,#') 17 | 18 | parser.add_argument("--prefix", type=str, default='', 19 | help='the root of everything') 20 | parser.add_argument("--datadir", type=str, 21 | help='input data directory') 22 | parser.add_argument("--expdir", type=str, 23 | help='where to store ckpts and logs') 24 | parser.add_argument("--seed", type=int, default=666, 25 | help='random seed') 26 | parser.add_argument("--factor", type=int, default=2, 27 | help='downsample factor for LLFF images') 28 | parser.add_argument("--near_factor", type=float, default=0.9, help='the actual near plane will be near_factor * near') 29 | parser.add_argument("--far_factor", type=float, default=2, help='the actual far plane will be far_factor * far') 30 | parser.add_argument("--chunk", type=int, default=1024 * 32, 31 | help='unused') 32 | parser.add_argument("--fp16", action='store_true', 33 | help='use half precision to train, currently still have bug, do NOT use') 34 | parser.add_argument("--bg_color", type=str, default="", 35 | help='0#0#0, or random, the background color') 36 | parser.add_argument("--scale_invariant", action='store_true', 37 | help='scale_invariant rgb loss, scaling before compute the MSE') 38 | 39 | # for MPV only, not used for MPMesh 40 | parser.add_argument("--mpv_frm_num", type=int, default=90, 41 | help='frame number of the mpv') 42 | parser.add_argument("--mpv_isloop", action='store_true', 43 | help='whether to produce looping videos') 44 | parser.add_argument("--init_from", type=str, default='', 45 | help='path to ckpt, will add prefix, currently only support reload from MPI') 46 | parser.add_argument("--init_std", type=float, default=0, 47 | help='noise std of the dynamic MPV') 48 | parser.add_argument("--add_uv_noise", action='store_true', 49 | help='add noise to uv, unused') 50 | parser.add_argument("--add_intrin_noise", action='store_true', 51 | help='add noise to intrinsic, to prevent tiling artifact') 52 | 53 | # loss config 54 | parser.add_argument("--loss_ref_idx", type=str, default='0', 55 | help='#,#,# swd_alpha = ref if view==swd_alpha_reference_viewidx else other') 56 | parser.add_argument("--loss_name", type=str, default='gpnn', 57 | help='gpnn, mse, swd, avg. 
gpnn_x to specify alpha==x') 58 | parser.add_argument("--loss_name_ref", type=str, default='gpnn', 59 | help='gpnn, mse, swd, avg. gpnn_x to specify alpha==x') 60 | parser.add_argument("--swd_macro_block", type=int, default=65, 61 | help='used for gpnn low mem') 62 | parser.add_argument("--swd_patch_size_ref", type=int, default=5, 63 | help='gpnn patch size for reference view') 64 | parser.add_argument("--swd_patch_size", type=int, default=5, 65 | help='gpnn patch size for other view') 66 | parser.add_argument("--swd_patcht_size_ref", type=int, default=5, 67 | help='gpnn temporal patch size for reference view') 68 | parser.add_argument("--swd_patcht_size", type=int, default=5, 69 | help='gpnn temporal patch size for other view') 70 | parser.add_argument("--swd_stride_ref", type=int, default=2, 71 | help='gpnn stride size for reference view') 72 | parser.add_argument("--swd_stride", type=int, default=2, 73 | help='gpnn stride size for other view') 74 | parser.add_argument("--swd_stridet", type=int, default=2, 75 | help='gpnn temporal stride size for reference view') 76 | parser.add_argument("--swd_stridet_ref", type=int, default=2, 77 | help='gpnn temporal stride size for other view') 78 | parser.add_argument("--swd_rou", type=str, default='0', 79 | help='parameter of robustness term, can also be mse, abs') 80 | parser.add_argument("--swd_rou_ref", type=str, default='0', 81 | help='parameter of robustness term, can also be mse, abs') 82 | parser.add_argument("--swd_scaling", type=float, default=0.2, 83 | help='parameter of robustness term') 84 | parser.add_argument("--swd_scaling_ref", type=float, default=0.2, 85 | help='parameter of robustness term') 86 | parser.add_argument("--swd_alpha", type=float, default=0, 87 | help='alpha, bigger than 100 is equivalent to None, (the rou in paper)') 88 | parser.add_argument("--swd_alpha_ref", type=float, default=0, 89 | help='alpha, bigger than 100 is equivalent to None, (the rou in paper)') 90 | parser.add_argument("--swd_dist_fn", type=str, default='mse', 91 | help='distance function, currently not setable') 92 | parser.add_argument("--swd_dist_fn_ref", type=str, default='mse', 93 | help='distance function, currently not setable') 94 | parser.add_argument("--swd_factor", type=int, default=1, 95 | help='factor, will compute NN in factored images') 96 | parser.add_argument("--swd_factor_ref", type=int, default=1, 97 | help='factor, will compute NN in factored images') 98 | parser.add_argument("--swd_loss_gain_ref", type=float, default=1, 99 | help='alpha, bigger than 100 is equivalent to None') 100 | 101 | # pyramid configuration 102 | parser.add_argument("--pyr_stage", type=str, default='', 103 | help='x,y,z,... 
iteration to upsample') 104 | parser.add_argument("--pyr_minimal_dim", type=int, default=60, 105 | help='if > 0, will determine the pyr_stage') 106 | parser.add_argument("--pyr_num_epoch", type=int, default=600, 107 | help='iter num in each level') 108 | parser.add_argument("--pyr_factor", type=float, default=0.5, 109 | help='factor in each pyr level') 110 | parser.add_argument("--pyr_init_level", type=int, default=-1, 111 | help='before that, use mse') 112 | 113 | # for mpi 114 | parser.add_argument("--sparsify_epoch", type=int, default=-1, 115 | help='sparsify the MPMesh in epoch') 116 | parser.add_argument("--sparsify_rmfirstlayer", type=int, default=0, 117 | help='if true, will remove the first #i layer') 118 | parser.add_argument("--sparsify_erode", type=int, default=2, 119 | help='iters to dilate the alpha channel') 120 | parser.add_argument("--learn_loop_mask", action='store_true', 121 | help='if true, will learn a loop_mask jointly') 122 | 123 | parser.add_argument("--direct2sh_epoch", type=int, default=-1, 124 | help='converting direct to sh, unused now') 125 | parser.add_argument("--sparsify_alpha_thresh", type=float, default=0.03, 126 | help='alpha thresh for tile culling') 127 | parser.add_argument("--vid2img_mode", type=str, default='average', 128 | help='choose among average, median, static, dynamic') 129 | parser.add_argument("--mpi_h_scale", type=float, default=1, 130 | help='the height of the stored MPI is ') 131 | parser.add_argument("--mpi_w_scale", type=float, default=1, 132 | help='the width of the stored MPI is ') 133 | parser.add_argument("--mpi_h_verts", type=int, default=12, 134 | help='number of vertices, decide the tile size') 135 | parser.add_argument("--mpi_w_verts", type=int, default=15, 136 | help='number of vertices, decide the tile size') 137 | parser.add_argument("--mpi_d", type=int, default=64, 138 | help='number of the MPI layer') 139 | parser.add_argument("--atlas_grid_h", type=int, default=8, 140 | help='atlas_grid_h * atlas_grid_w == mpi_d') 141 | parser.add_argument("--atlas_size_scale", type=float, default=1, 142 | help='atlas_size = mpi_d * H * W * atlas_size_scale') 143 | parser.add_argument("--atlas_cnl", type=int, default=4, 144 | help='channel num, currently not setable, much be 4') 145 | parser.add_argument("--model_type", type=str, default="MPMesh", 146 | help='currently not setable, much be MPMesh') 147 | parser.add_argument("--rgb_mlp_type", type=str, default='direct', 148 | help='not used, must be direct') 149 | parser.add_argument("--rgb_activate", type=str, default='sigmoid', 150 | help='activate function for rgb output, choose among "none", "sigmoid"') 151 | parser.add_argument("--alpha_activate", type=str, default='sigmoid', 152 | help='activate function for alpha output, choose among "none", "sigmoid"') 153 | parser.add_argument("--optimize_geo_start", type=int, default=10000000, 154 | help='iteration to start optimizing verts and uvs, currently not used') 155 | parser.add_argument("--optimize_verts_gain", type=float, default=1, 156 | help='set 0 to disable the vertices optimization') 157 | parser.add_argument("--normalize_verts", action='store_true', 158 | help='if true, the parameter is normalized') 159 | 160 | # about training 161 | parser.add_argument("--upsample_stage", type=str, default="", 162 | help='x,y,z,... 
stage to perform upsampling') 163 | parser.add_argument("--rgb_smooth_loss_weight", type=float, default=0, 164 | help='rgb spatial smooth loss') 165 | parser.add_argument("--a_smooth_loss_weight", type=float, default=0, 166 | help='alpha spatial smooth loss') 167 | parser.add_argument("--d_smooth_loss_weight", type=float, default=0, 168 | help='depth smooth loss') 169 | parser.add_argument("--l_smooth_loss_weight", type=float, default=0, 170 | help='loop mask (label) smooth loss') 171 | parser.add_argument("--edge_scale", type=float, default=4, 172 | help='edge aware smooth loss, 0 to disable edge aware') 173 | parser.add_argument("--normalize_blendweight_fordepth", action='store_true', 174 | help='edge aware smooth loss, 0 to disable edge aware') 175 | parser.add_argument("--density_loss_weight", type=float, default=0, 176 | help='density loss') 177 | parser.add_argument("--density_loss_epoch", type=int, default=0, 178 | help='gradually grow the density to epoch') 179 | parser.add_argument("--sparsity_loss_weight", type=float, default=0, 180 | help='sparsity loss weight') 181 | 182 | # training options 183 | parser.add_argument("--N_iters", type=int, default=30) 184 | parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'sgd'], 185 | help='optmizer') 186 | parser.add_argument("--patch_h_size", type=int, default=512, 187 | help='patch size for each iteration') 188 | parser.add_argument("--patch_w_size", type=int, default=512, 189 | help='patch size for each iteration') 190 | parser.add_argument("--patch_h_stride", type=int, default=128, 191 | help='stride size for each iteration') 192 | parser.add_argument("--patch_w_stride", type=int, default=128, 193 | help='stride size for each iteration') 194 | parser.add_argument("--lrate", type=float, default=5e-4, 195 | help='learning rate') 196 | parser.add_argument("--lrate_adaptive", action='store_true', 197 | help='adaptively adjust learning rate based on patch size, or it will generate noise') 198 | parser.add_argument("--lrate_decay", type=int, default=30, 199 | help='exponential learning rate decay (in 1000 steps)') 200 | 201 | # logging options 202 | parser.add_argument("--i_img", type=int, default=300, 203 | help='frequency of tensorboard image logging') 204 | parser.add_argument("--i_print", type=int, default=300, 205 | help='frequency of console printout and metric loggin') 206 | parser.add_argument("--i_weights", type=int, default=20000, 207 | help='frequency of weight ckpt saving') 208 | parser.add_argument("--i_video", type=int, default=10000, 209 | help='frequency of render_poses video saving') 210 | 211 | # multiprocess learning 212 | parser.add_argument("--gpu_num", type=int, default='-1', 213 | help='number of processes, currently only support 1 gpu') 214 | return parser 215 | -------------------------------------------------------------------------------- /configs/debug_mpmesh.txt: -------------------------------------------------------------------------------- 1 | gpu_num = 1 2 | # Dataset related 3 | prefix = D:\MSI_NB\source\data\VideoLoops 4 | expname = ustfallclose720p32layer_debug 5 | datadir = data/fall4720p 6 | expdir = meshlogs 7 | factor = 2 8 | l_smooth_loss_weight = 0.1 9 | # sparsify_epoch = 20000 10 | patch_h_size = 360 11 | patch_w_size = 640 -------------------------------------------------------------------------------- /configs/debug_video.txt: -------------------------------------------------------------------------------- 1 | gpu_num = 1 2 | # Dataset related 3 | prefix = 
D:\MSI_NB\source\data\VideoLoops 4 | expname = usttap_debug 5 | datadir = data/ustfallclose720p 6 | expdir = postrebuttal_mpv 7 | factor = 2 8 | seed = 2 9 | model_type = MPMesh 10 | 11 | # mpi configuration 12 | mpi_h_scale = 0.1 13 | mpi_w_scale = 0.1 14 | mpi_h_verts = 27 15 | mpi_w_verts = 48 16 | mpi_d = 2 17 | atlas_grid_h = 2 18 | atlas_size_scale = 1 19 | 20 | # Training related 21 | patch_h_size = 16 22 | patch_w_size = 16 23 | patch_h_stride = 15 24 | patch_w_stride = 15 25 | lrate = 0.5 26 | lrate_decay = 100 27 | lrate_adaptive 28 | rgb_mlp_type = direct 29 | rgb_activate = sigmoid 30 | alpha_activate = sigmoid 31 | 32 | sparsity_loss_weight = 0.004 33 | rgb_smooth_loss_weight = 0.2 34 | a_smooth_loss_weight = 0.2 35 | 36 | i_img = 20 37 | i_print = 10 38 | i_weight = 50 39 | i_video = 2 40 | 41 | # mpv configuration 42 | pyr_minimal_dim = 65 43 | pyr_num_epoch = 50 44 | pyr_factor = 0.75 45 | init_std = 0.02 46 | 47 | loss_ref_idx = 1,6 48 | swd_macro_block = 45 49 | swd_loss_gain_ref = 3.5 50 | loss_name_ref = gpnn_lm 51 | swd_alpha_ref = 10000 52 | swd_patch_size_ref = 5 53 | swd_patcht_size_ref = 3 54 | swd_stride_ref = 2 55 | swd_stridet_ref = 1 56 | swd_dist_fn_ref = mse 57 | swd_rou_ref = -2 58 | swd_scaling_ref = 0.1 59 | loss_name = gpnn_lm 60 | swd_alpha = 10000 61 | swd_patch_size = 3 62 | swd_patcht_size = 3 63 | swd_stride = 2 64 | swd_stridet = 1 65 | swd_dist_fn = mse 66 | swd_rou = -2 67 | swd_scaling = 0.1 68 | -------------------------------------------------------------------------------- /configs/mpi_base.txt: -------------------------------------------------------------------------------- 1 | gpu_num = 1 2 | # Dataset related 3 | prefix = /d1/scratch/PI/psander/data/VideoLoops 4 | expdir = mpis 5 | factor = 2 6 | seed = 2 7 | model_type = MPMesh 8 | 9 | # mpi configuration 10 | vid2img_mode = dynamic 11 | learn_loop_mask 12 | mpi_h_scale = 1.6 13 | mpi_w_scale = 1.6 14 | mpi_h_verts = 36 15 | mpi_w_verts = 64 16 | mpi_d = 32 17 | atlas_grid_h = 4 18 | atlas_size_scale = 1 19 | 20 | # Training related 21 | scale_invariant 22 | add_intrin_noise 23 | sparsify_epoch = 119 24 | sparsify_alpha_thresh = 0.05 25 | sparsify_erode = 2 26 | N_iters = 140 27 | patch_h_size = 180 28 | patch_w_size = 320 29 | patch_h_stride = 90 30 | patch_w_stride = 160 31 | lrate = 0.05 32 | lrate_decay = 100 33 | rgb_mlp_type = direct 34 | rgb_activate = sigmoid 35 | alpha_activate = sigmoid 36 | 37 | sparsity_loss_weight = 0.004 38 | rgb_smooth_loss_weight = 0.2 39 | a_smooth_loss_weight = 0.5 40 | density_loss_weight = 0.02 41 | density_loss_epoch = 60 42 | 43 | i_img = 50 44 | i_print = 10 45 | i_weight = 60 46 | i_video = 20 47 | -------------------------------------------------------------------------------- /configs/mpis/1017palm.txt: -------------------------------------------------------------------------------- 1 | expname = 1017palm 2 | datadir = data/1017palm720p 3 | near_factor = 0.98 4 | far_factor = 3 5 | sparsify_rmfirstlayer = 1 6 | 7 | -------------------------------------------------------------------------------- /configs/mpis/1017yuan.txt: -------------------------------------------------------------------------------- 1 | expname = 1017yuan720p 2 | datadir = data/1017yuan720p 3 | near_factor = 0.95 4 | far_factor = 4 5 | sparsify_rmfirstlayer = 1 6 | -------------------------------------------------------------------------------- /configs/mpis/1020rock.txt: -------------------------------------------------------------------------------- 1 | expname = 1020rock720p 
2 | datadir = data/1020rock720p 3 | near_factor = 0.98 4 | far_factor = 3 5 | sparsify_rmfirstlayer = 2 6 | sparsify_alpha_thresh = 0.02 7 | -------------------------------------------------------------------------------- /configs/mpis/1020ustfall1.txt: -------------------------------------------------------------------------------- 1 | expname = 1020ustfall1720p 2 | datadir = data/1020ustfall1720p 3 | near_factor = 0.98 4 | far_factor = 1.5 5 | sparsify_rmfirstlayer = 1 6 | 7 | 8 | sparsity_loss_weight = 0.004 9 | rgb_smooth_loss_weight = 0.2 10 | a_smooth_loss_weight = 0.5 11 | density_loss_weight = 0.02 12 | density_loss_epoch = 60 -------------------------------------------------------------------------------- /configs/mpis/1020ustfall2.txt: -------------------------------------------------------------------------------- 1 | expname = 1020ustfall2720p 2 | datadir = data/1020ustfall2720p 3 | near_factor = 0.75 4 | far_factor = 1.3 5 | sparsify_rmfirstlayer = 4 6 | -------------------------------------------------------------------------------- /configs/mpis/108fall1.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall1 2 | datadir = data/fall1narrow720p 3 | near_factor = 0.90 4 | far_factor = 2 5 | sparsify_rmfirstlayer = 4 6 | 7 | 8 | -------------------------------------------------------------------------------- /configs/mpis/108fall2.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall2 2 | datadir = data/fall2720p 3 | near_factor = 0.90 4 | far_factor = 1.5 5 | sparsify_rmfirstlayer = 4 6 | sparsify_alpha_thresh = 0.02 7 | -------------------------------------------------------------------------------- /configs/mpis/108fall3.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall3 2 | datadir = data/fall3720p 3 | near_factor = 0.98 4 | far_factor = 5 5 | sparsify_rmfirstlayer = 0 6 | sparsify_alpha_thresh = 0.03 7 | -------------------------------------------------------------------------------- /configs/mpis/108fall4.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall4 2 | datadir = data/fall4720p 3 | near_factor = 0.95 4 | far_factor = 1.2 5 | sparsify_rmfirstlayer = 2 6 | sparsify_alpha_thresh = 0.05 7 | -------------------------------------------------------------------------------- /configs/mpis/108fall5.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall5 2 | datadir = data/fall5720p 3 | near_factor = 0.98 4 | far_factor = 2 5 | sparsify_rmfirstlayer = 0 6 | sparsify_alpha_thresh = 0.03 7 | -------------------------------------------------------------------------------- /configs/mpis/1101grass.txt: -------------------------------------------------------------------------------- 1 | expname = 1101grass 2 | datadir = data/1101grass720p 3 | near_factor = 0.99 4 | far_factor = 2.5 5 | sparsify_rmfirstlayer = 1 6 | 7 | -------------------------------------------------------------------------------- /configs/mpis/1101towerd.txt: -------------------------------------------------------------------------------- 1 | expname = 1101towerd 2 | datadir = data/1101towerd720p 3 | near_factor = 0.95 4 | far_factor = 2 5 | sparsify_rmfirstlayer = 2 6 | sparsify_alpha_thresh = 0.05 7 | 8 | sparsity_loss_weight = 0.001 9 | -------------------------------------------------------------------------------- 
/configs/mpis/110grasstree.txt: -------------------------------------------------------------------------------- 1 | expname = 110grasstree 2 | datadir = data/grasstree720p 3 | near_factor = 0.95 4 | far_factor = 5 5 | sparsify_rmfirstlayer = 2 6 | 7 | -------------------------------------------------------------------------------- /configs/mpis/110pillarrm.txt: -------------------------------------------------------------------------------- 1 | expname = 110pillar 2 | datadir = data/pillar720p 3 | near_factor = 0.9 4 | far_factor = 1.2 5 | sparsify_rmfirstlayer = 4 6 | sparsify_alpha_thresh = 0.03 7 | -------------------------------------------------------------------------------- /configs/mpis/ustfallclose.txt: -------------------------------------------------------------------------------- 1 | expname = ustfallclose 2 | datadir = data/ustfallclose720p 3 | near_factor = 1 4 | far_factor = 1.2 5 | sparsify_rmfirstlayer = 1 6 | -------------------------------------------------------------------------------- /configs/mpis/usttap.txt: -------------------------------------------------------------------------------- 1 | expname = usttap720p 2 | datadir = data/usttap720p 3 | near_factor = 0.95 4 | far_factor = 1.5 5 | sparsify_rmfirstlayer = 3 6 | sparsify_alpha_thresh = 0.03 7 | -------------------------------------------------------------------------------- /configs/mpv_base.txt: -------------------------------------------------------------------------------- 1 | gpu_num = 1 2 | # Dataset related 3 | prefix = /d1/scratch/PI/psander/data/VideoLoops 4 | expdir = mpvs 5 | factor = 2 6 | seed = 2 7 | model_type = MPMesh 8 | 9 | # mpi configuration, not important since we load from ckpt 10 | mpi_h_scale = 1.1 11 | mpi_w_scale = 1.1 12 | mpi_h_verts = 27 13 | mpi_w_verts = 48 14 | mpi_d = 32 15 | atlas_grid_h = 4 16 | atlas_size_scale = 1 17 | 18 | # Training related 19 | scale_invariant 20 | add_intrin_noise 21 | patch_h_size = 180 22 | patch_h_stride = 90 23 | patch_w_size = 320 24 | patch_w_stride = 160 25 | lrate = 0.5 26 | lrate_decay = 100 27 | lrate_adaptive 28 | rgb_mlp_type = direct 29 | rgb_activate = sigmoid 30 | alpha_activate = sigmoid 31 | 32 | sparsity_loss_weight = 0 33 | rgb_smooth_loss_weight = 0.2 34 | a_smooth_loss_weight = 0.2 35 | 36 | i_img = 20 37 | i_print = 10 38 | i_weight = 50 39 | i_video = 50 40 | 41 | # mpv configuration 42 | pyr_minimal_dim = 65 43 | pyr_num_epoch = 50 44 | pyr_factor = 0.75 45 | init_std = 0.02 46 | mpv_isloop 47 | 48 | 49 | swd_macro_block = 65 50 | swd_loss_gain_ref = 3.5 51 | loss_name_ref = gpnn_lm 52 | swd_alpha_ref = 0 53 | swd_patch_size_ref = 11 54 | swd_patcht_size_ref = 3 55 | swd_stride_ref = 4 56 | swd_stridet_ref = 1 57 | swd_dist_fn_ref = mse 58 | swd_rou_ref = -2 59 | swd_scaling_ref = 0.1 60 | loss_name = gpnn_lm 61 | swd_alpha = 10000 62 | swd_patch_size = 3 63 | swd_patcht_size = 3 64 | swd_stride = 2 65 | swd_stridet = 1 66 | swd_dist_fn = mse 67 | swd_rou = -2 68 | swd_scaling = 0.1 69 | 70 | -------------------------------------------------------------------------------- /configs/mpvs/1017palm.txt: -------------------------------------------------------------------------------- 1 | expname = 1017palm 2 | datadir = data/1017palm720p 3 | init_from = mpis/1017palm/epoch_0119.tar 4 | mpv_frm_num = 60 5 | test_view_idx = 0 6 | 7 | loss_ref_idx = 1,6 -------------------------------------------------------------------------------- /configs/mpvs/1017yuan.txt: 
-------------------------------------------------------------------------------- 1 | expname = 1017yuan 2 | datadir = data/1017yuan720p 3 | init_from = mpis/1017yuan720p/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 1 6 | 7 | loss_ref_idx = 2,6 -------------------------------------------------------------------------------- /configs/mpvs/1020rock.txt: -------------------------------------------------------------------------------- 1 | expname = 1020rock 2 | datadir = data/1020rock720p 3 | init_from = mpis/1020rock720p/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 0 6 | 7 | loss_ref_idx = 1,6 -------------------------------------------------------------------------------- /configs/mpvs/1020ustfall1.txt: -------------------------------------------------------------------------------- 1 | expname = 1020ustfall1 2 | datadir = data/1020ustfall1720p 3 | init_from = mpis/1020ustfall1720p/epoch_0119.tar 4 | mpv_frm_num = 60 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 3,7 -------------------------------------------------------------------------------- /configs/mpvs/1020ustfall2.txt: -------------------------------------------------------------------------------- 1 | expname = 1020ustfall2 2 | datadir = data/1020ustfall2720p 3 | init_from = mpis/1020ustfall2720p/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 1 6 | 7 | loss_ref_idx = 0,5 -------------------------------------------------------------------------------- /configs/mpvs/108fall1.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall1 2 | datadir = data/fall1narrow720p 3 | init_from = mpis/108fall1/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 2,5 -------------------------------------------------------------------------------- /configs/mpvs/108fall2.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall2 2 | datadir = data/fall2720p 3 | init_from = mpis/108fall2/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 0,6 -------------------------------------------------------------------------------- /configs/mpvs/108fall3.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall3 2 | datadir = data/fall3720p 3 | init_from = mpis/108fall3/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 0,6 -------------------------------------------------------------------------------- /configs/mpvs/108fall4.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall4 2 | datadir = data/fall4720p 3 | init_from = mpis/108fall4/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 1,8 -------------------------------------------------------------------------------- /configs/mpvs/108fall5.txt: -------------------------------------------------------------------------------- 1 | expname = 108fall5 2 | datadir = data/fall5720p 3 | init_from = mpis/108fall5/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 3,8 -------------------------------------------------------------------------------- /configs/mpvs/1101grass.txt: -------------------------------------------------------------------------------- 1 | expname = 1101grass 2 | datadir = data/1101grass720p 3 | init_from = mpis/1101grass/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 5 6 | 7 | loss_ref_idx = 3,9 
-------------------------------------------------------------------------------- /configs/mpvs/1101towerd.txt: -------------------------------------------------------------------------------- 1 | expname = 1101towerd 2 | datadir = data/1101towerd720p 3 | init_from = mpis/1101towerd/epoch_0119.tar 4 | mpv_frm_num = 60 5 | test_view_idx = 6 6 | 7 | loss_ref_idx = 3,8 -------------------------------------------------------------------------------- /configs/mpvs/110grasstree.txt: -------------------------------------------------------------------------------- 1 | expname = 110grasstree 2 | datadir = data/grasstree720p 3 | init_from = mpis/110grasstree/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 1,7 -------------------------------------------------------------------------------- /configs/mpvs/110pillar.txt: -------------------------------------------------------------------------------- 1 | expname = 110pillar 2 | datadir = data/pillar720p 3 | init_from = mpis/110pillar/epoch_0119.tar 4 | mpv_frm_num = 50 5 | test_view_idx = 3 6 | 7 | loss_ref_idx = 2,8 -------------------------------------------------------------------------------- /configs/mpvs/ustfallclose.txt: -------------------------------------------------------------------------------- 1 | expname = ustfallclose 2 | datadir = data/ustfallclose720p 3 | init_from = mpis/ustfallclose/epoch_0119.tar 4 | mpv_frm_num = 60 5 | test_view_idx = 1 6 | 7 | loss_ref_idx = 0,2 -------------------------------------------------------------------------------- /configs/mpvs/usttap.txt: -------------------------------------------------------------------------------- 1 | expname = usttap 2 | datadir = data/usttap720p 3 | init_from = mpis/usttap720p/epoch_0119.tar 4 | mpv_frm_num = 60 5 | test_view_idx = 4 6 | 7 | loss_ref_idx = 0,2 -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import glob 5 | import imageio 6 | import numpy as np 7 | 8 | 9 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 10 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 11 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 12 | bds = poses_arr[:, -2:].transpose([1, 0]) 13 | 14 | # img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 15 | # if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 16 | # sh = imageio.imread(img0).shape 17 | 18 | sfx = '' 19 | 20 | if factor is not None: 21 | sfx = '_{}'.format(factor) 22 | factor = factor 23 | else: 24 | factor = 1 25 | 26 | poses[:2, 4, :] = poses[:2, 4, :] / factor # hw 27 | poses[2, 4, :] = poses[2, 4, :] / factor # intrin 28 | 29 | if not load_imgs: 30 | return poses, bds, None 31 | 32 | imgdir = os.path.join(basedir, 'images' + sfx) 33 | if not os.path.exists(imgdir): 34 | print(imgdir, 'does not exist, returning') 35 | return 36 | 37 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if 38 | f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 39 | if poses.shape[-1] != len(imgfiles): 40 | print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1])) 41 | return 42 | 43 | def imread(f): 44 | if f.endswith('png'): 45 | return imageio.imread(f, ignoregamma=True) 46 | else: 47 | return imageio.imread(f) 48 | 49 | imgs = [imread(f)[..., 
:3] / 255. for f in imgfiles] 50 | imgs = np.stack(imgs, -1) 51 | 52 | print('Loaded image data', imgs.shape, poses[:, -1, 0]) 53 | return poses, bds, imgs 54 | 55 | 56 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=(1, 1), spherify=False, path_epi=False, 57 | load_img=True, render_frm=120, render_scaling=1.): 58 | poses, bds, imgs = _load_data(basedir, factor=factor, load_imgs=load_img) 59 | # factor=8 downsamples original imgs by 8x 60 | print('Loaded', basedir, bds.min(), bds.max()) 61 | # for debug only 62 | # selected_idx = [0, 1, 2, 9, 8, 7, 10, 11, 19, 18] 63 | # imgs = imgs[..., selected_idx] 64 | # poses = poses[..., selected_idx] 65 | # bds = bds[..., selected_idx] 66 | 67 | # Correct rotation matrix ordering and move variable dim to axis 0 68 | poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], poses[:, 3:, :]], 1) 69 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 70 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) if imgs is not None else None 71 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 72 | 73 | # Rescale if bd_factor is provided 74 | bds = np.array([bds.min(), bds.max()]).astype(poses.dtype) 75 | sc = 1. / bds[0] 76 | poses[:, :3, 3] *= sc 77 | bds *= sc 78 | if bd_factor is not None: 79 | bds *= bd_factor 80 | 81 | if recenter: 82 | poses = recenter_poses(poses) 83 | 84 | # generate render_poses for video generation 85 | if spherify: 86 | poses, render_poses, bds = spherify_poses(poses, bds) 87 | 88 | else: 89 | c2w = poses_avg(poses) 90 | print('recentered', c2w.shape) 91 | print(c2w[:3, :4]) 92 | 93 | ## Get spiral 94 | # Get average pose 95 | up = normalize(poses[:, :3, 1].sum(0)) 96 | 97 | # Find a reasonable "focus depth" for this dataset 98 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5. 99 | dt = .75 100 | mean_dz = 1. / (((1. 
- dt) / close_depth + dt / inf_depth)) 101 | focal = mean_dz 102 | # Get radii for spiral path 103 | shrink_factor = .8 104 | zdelta = close_depth * .2 105 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 106 | rads = np.abs(tt).max(0) * 0.8 * render_scaling 107 | # rads = np.percentile(np.abs(tt), 90, 0) 108 | c2w_path = c2w 109 | N_views = render_frm 110 | N_rots = 2 111 | # Generate poses for spiral path 112 | # rads = [0.7, 0.2, 0.7] 113 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zrate=.5, zdelta=zdelta, rots=N_rots, N=N_views) 114 | 115 | if path_epi: 116 | # zloc = np.percentile(tt, 10, 0)[2] 117 | rads[0] = rads[0] / 2 118 | render_poses = render_path_epi(c2w_path, up, rads[0], N_views) 119 | 120 | render_poses = np.array(render_poses).astype(np.float32) 121 | 122 | poses = poses.astype(np.float32) 123 | 124 | H, W, focal = poses[:, :3, -1].T 125 | poses = poses[:, :3, :4] 126 | intrins = np.zeros_like(poses[:, :3, :3]) 127 | intrins[:, -1, -1] = 1 128 | intrins[:, 0, 0] = focal 129 | intrins[:, 1, 1] = focal 130 | intrins[:, 0, 2] = 0.5 * W 131 | intrins[:, 1, 2] = 0.5 * H 132 | 133 | render_intrins = np.repeat(intrins[:1, ...], len(render_poses), 0) 134 | return imgs, poses, intrins, bds, render_poses, render_intrins 135 | 136 | 137 | def load_mv_videos(basedir, factor=1, recenter=True, bd_factor=(1, 1), render_frm=120, render_scaling=1): 138 | _, poses, intrins, bds, render_poses, render_intrins = load_llff_data(basedir, factor, recenter, 139 | bd_factor=bd_factor, 140 | load_img=False, 141 | render_frm=render_frm, 142 | render_scaling=render_scaling) 143 | videos_path = sorted(glob.glob(basedir + f"/videos_{factor}/*")) 144 | videos = [imageio.mimread(vp, memtest=False) for vp in videos_path] 145 | cap = cv2.VideoCapture(videos_path[0]) 146 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 147 | return videos, fps, poses, intrins, bds, render_poses, render_intrins 148 | 149 | 150 | def load_masks(imgpaths): 151 | msklist = [] 152 | 153 | for imgdir in imgpaths: 154 | mskdir = imgdir.replace('images', 'masks').replace('.jpg', '.png') 155 | msk = imageio.imread(mskdir) 156 | 157 | H, W = msk.shape[0:2] 158 | msk = msk / 255.0 159 | 160 | # msk = np.sum(msk, axis=2) 161 | # msk[msk < 3.0] = 0.0 162 | # msk[msk == 3.0] = 1.0 163 | # msk = 1.0 - msk 164 | 165 | newmsk = np.zeros((H, W), dtype=np.float32) 166 | newmsk[np.logical_and((msk[:, :, 0] == 0), (msk[:, :, 1] == 0), (msk[:, :, 2] == 1.0))] = 1.0 167 | 168 | # imageio.imwrite('newmsk.png', newmsk) 169 | # print(imgpaths, mskdir, H, W) 170 | # print(sss) 171 | 172 | msklist.append(newmsk) 173 | 174 | msklist = np.stack(msklist, 0) 175 | 176 | return msklist 177 | 178 | 179 | def has_matted(imgpaths): 180 | exampledir = imgpaths[-1].replace('images', 'images_rgba').replace('.jpg', '.png') 181 | return os.path.exists(exampledir) 182 | 183 | 184 | def load_matted(imgpaths): 185 | imglist = [] 186 | for imgdir in imgpaths: 187 | imgdir = imgdir.replace('images', 'images_rgba').replace('.jpg', '.png') 188 | rgba = imageio.imread(imgdir) 189 | assert rgba.shape[-1] == 4, "cannot load rgba png" 190 | rgba = rgba / 255.0 191 | rgba[..., :3] = rgba[..., :3] * rgba[..., 3:4] 192 | imglist.append(rgba) 193 | 194 | imglist = np.stack(imglist, 0) 195 | return imglist 196 | 197 | 198 | def load_images(imgpaths): 199 | imglist = [] 200 | 201 | for imgdir in imgpaths: 202 | img = imageio.imread(imgdir) 203 | img = img / 255.0 204 | imglist.append(img) 205 | 206 | imglist = np.stack(imglist, 0) 207 | 208 | return imglist 209 | 
210 | 211 | def normalize(x): 212 | return x / np.linalg.norm(x) 213 | 214 | 215 | def viewmatrix(z, up, pos): 216 | vec2 = normalize(z) 217 | vec1_avg = up 218 | vec0 = normalize(np.cross(vec1_avg, vec2)) 219 | vec1 = normalize(np.cross(vec2, vec0)) 220 | m = np.stack([vec0, vec1, vec2, pos], 1) 221 | return m 222 | 223 | 224 | def poses_avg(poses): 225 | hwf = poses[0, :3, -1:] 226 | 227 | center = poses[:, :3, 3].mean(0) 228 | vec2 = normalize(poses[:, :3, 2].sum(0)) 229 | up = poses[:, :3, 1].sum(0) 230 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 231 | 232 | return c2w 233 | 234 | 235 | def recenter_poses(poses): 236 | poses_ = poses + 0 237 | bottom = np.reshape([0, 0, 0, 1.], [1, 4]) 238 | c2w = poses_avg(poses) 239 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 240 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 241 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 242 | 243 | poses = np.linalg.inv(c2w) @ poses 244 | poses_[:, :3, :4] = poses[:, :3, :4] 245 | poses = poses_ 246 | return poses 247 | 248 | 249 | def render_path_spiral(c2w, up, rads, focal, zrate, zdelta, rots, N): 250 | render_poses = [] 251 | rads = np.array(list(rads) + [1.]) 252 | 253 | for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]: 254 | # view direction 255 | # c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 256 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), (np.cos(theta * zrate) * zdelta) ** 2, 1.]) * rads) 257 | # camera poses 258 | z = normalize(np.array([0, 0, focal] - c)) 259 | render_poses.append(viewmatrix(z, up, c)) 260 | return np.stack(render_poses) 261 | -------------------------------------------------------------------------------- /evaluations/C3D_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class C3D(nn.Module): 5 | """ 6 | The C3D network as described in [1]. 
7 | """ 8 | 9 | def __init__(self): 10 | super(C3D, self).__init__() 11 | 12 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 13 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 14 | 15 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 16 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 17 | 18 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 19 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 20 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 21 | 22 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 23 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 24 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 25 | 26 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 27 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 28 | self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)) 29 | 30 | self.fc6 = nn.Linear(8192, 4096) 31 | self.fc7 = nn.Linear(4096, 4096) 32 | self.fc8 = nn.Linear(4096, 487) 33 | 34 | self.dropout = nn.Dropout(p=0.5) 35 | 36 | self.relu = nn.ReLU() 37 | self.softmax = nn.Softmax() 38 | 39 | def forward(self, x): 40 | 41 | h = self.relu(self.conv1(x)) 42 | h = self.pool1(h) 43 | 44 | h = self.relu(self.conv2(h)) 45 | h = self.pool2(h) 46 | 47 | h = self.relu(self.conv3a(h)) 48 | h = self.relu(self.conv3b(h)) 49 | h = self.pool3(h) 50 | 51 | h = self.relu(self.conv4a(h)) 52 | h = self.relu(self.conv4b(h)) 53 | h = self.pool4(h) 54 | 55 | h = self.relu(self.conv5a(h)) 56 | h = self.relu(self.conv5b(h)) 57 | h = self.pool5(h) 58 | 59 | # h = h.view(-1, 8192) 60 | # h = self.relu(self.fc6(h)) 61 | # h = self.dropout(h) 62 | # h = self.relu(self.fc7(h)) 63 | # h = self.dropout(h) 64 | # 65 | # logits = self.fc8(h) 66 | # probs = self.softmax(logits) 67 | 68 | return h 69 | 70 | """ 71 | References 72 | ---------- 73 | [1] Tran, Du, et al. "Learning spatiotemporal features with 3d convolutional networks." 74 | Proceedings of the IEEE international conference on computer vision. 2015. 
75 | """ -------------------------------------------------------------------------------- /evaluations/LPIPS.py: -------------------------------------------------------------------------------- 1 | from .lpips.lpips import LPIPS 2 | import numpy as np 3 | 4 | LPIPS_network = None 5 | 6 | 7 | def _prepare_lpips(src, tar): 8 | global LPIPS_network 9 | if LPIPS_network is None: 10 | LPIPS_network = LPIPS() 11 | 12 | LPIPS_network.to(src.device) 13 | src = src.permute(0, 3, 1, 2) / (255 / 2) - 1 14 | tar = tar.permute(0, 3, 1, 2) / (255 / 2) - 1 15 | return src, tar 16 | 17 | 18 | def compute_lpips(src, tar): 19 | """ 20 | src/tar: tensor of F x H x W x 3, in (0, 255), rgb 21 | """ 22 | global LPIPS_network 23 | src, tar = _prepare_lpips(src, tar) 24 | 25 | def compute_one_frame(frame, tar): 26 | scores = [LPIPS_network(frame, tar_[None]).item() for tar_ in tar] 27 | return min(scores) 28 | 29 | err = [compute_one_frame(f[None], tar) for f in src] 30 | return np.array(err).mean() 31 | 32 | 33 | def compute_lpips_slidewindow(src, tar): 34 | """ 35 | src/tar: tensor of F x H x W x 3, in (0, 255), rgb 36 | """ 37 | global LPIPS_network 38 | if len(src) > len(tar): 39 | src, tar = tar, src 40 | src, tar = _prepare_lpips(src, tar) 41 | 42 | def compute_aligned_lpips(s, t): 43 | scores = [LPIPS_network(sf[None], tf[None]).item() for sf, tf in zip(s, t)] 44 | return np.mean(scores) 45 | 46 | err = [compute_aligned_lpips(src, tar[i: i + len(src)]) for i in range(len(tar) - len(src))] 47 | return np.array(err).min() 48 | -------------------------------------------------------------------------------- /evaluations/NNMSE.py: -------------------------------------------------------------------------------- 1 | from utils_vid import extract_3Dpatches, get_NN_indices_low_memory 2 | import warnings 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def compute_nnerr(src, tar, 8 | patch_size=7, stride=2, patcht_size=7, stridet=2, 9 | macro_block=65): 10 | """ 11 | x, tar: shape of B x 3 x f x h x w 12 | """ 13 | # standardlize the input 14 | t, h, w = src.shape[-3:] 15 | 16 | def fit_patch(s_, name, p_, st_): 17 | if (s_ - p_) % st_ != 0: 18 | new_s_ = (s_ - p_) // st_ * st_ + p_ 19 | warnings.warn(f'{name} doesnot satisfy ({name} - patch_size) % stride == 0. 
' 20 | f'changing {name} from {s_} to {new_s_}') 21 | return new_s_ 22 | return s_ 23 | 24 | macro_block = fit_patch(macro_block, "macro_block", patch_size, stride) 25 | h = fit_patch(h, "patch_height", patch_size, stride) 26 | w = fit_patch(w, "patch_width", patch_size, stride) 27 | t = fit_patch(t, "frame_num", patcht_size, stridet) 28 | src = src[..., :t, :h, :w] 29 | tar = tar[..., :h, :w] 30 | 31 | with torch.no_grad(): 32 | macro_stride = macro_block - patch_size + stride 33 | h_starts = np.arange(0, h - macro_block + macro_stride, macro_stride) 34 | w_starts = np.arange(0, w - macro_block + macro_stride, macro_stride) 35 | errs = [] 36 | for h_start in h_starts: 37 | # if h - h_start < patch_size: # this checking is nolonger needed due to the fit_patch 38 | # h_start -= patch_size 39 | for w_start in w_starts: 40 | # if w - w_start < patch_size: 41 | # w_start -= patch_size 42 | src_crop = src[..., h_start: h_start + macro_block, w_start: w_start + macro_block] 43 | tar_crop = tar[..., h_start: h_start + macro_block, w_start: w_start + macro_block] 44 | # partation input into different patches and process individually 45 | projsrc = extract_3Dpatches(src_crop, patch_size, patcht_size, stride, stridet) # b, c, d, h, w 46 | b, c, d, h, w = projsrc.shape 47 | B = b * h * w 48 | D = d * h * w 49 | projsrc = projsrc.permute(0, 3, 4, 2, 1).reshape(B, -1, 3, patcht_size, patch_size, patch_size) 50 | projtar = extract_3Dpatches(tar_crop, patch_size, patcht_size, stride, stridet) # b, c, d, h, w 51 | projtar = projtar.permute(0, 3, 4, 2, 1).reshape(B, -1, 3, patcht_size, patch_size, patch_size) 52 | nns = get_NN_indices_low_memory(projsrc, projtar, None, 1024) 53 | projtar2src = projtar[torch.arange(B, device=nns.device)[:, None], nns] 54 | 55 | err = (projtar2src - projsrc).abs().mean().item() 56 | errs.append(err) 57 | 58 | return np.array(errs).mean() 59 | -------------------------------------------------------------------------------- /evaluations/SVFID.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .C3D_model import C3D 4 | from scipy import linalg 5 | 6 | 7 | C3D_network = None 8 | 9 | 10 | def calculate_batched_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 11 | diff = mu1 - mu2 12 | offset = np.eye(sigma1.shape[1])[None] * eps 13 | mats = (sigma1 + offset) @ (sigma2 + offset) 14 | # Product might be almost singular 15 | covmean = np.array([linalg.sqrtm(mat, disp=False)[0] for mat in mats]) 16 | 17 | # Numerical error might give slight imaginary component 18 | if np.iscomplexobj(covmean): 19 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 20 | m = np.max(np.abs(covmean.imag)) 21 | raise ValueError('Imaginary component {}'.format(m)) 22 | covmean = covmean.real 23 | 24 | tr_covmean = np.trace(covmean, axis1=1, axis2=2) 25 | 26 | return ((diff * diff).sum(axis=-1) + np.trace(sigma1, axis1=1, axis2=2) 27 | + np.trace(sigma2, axis1=1, axis2=2) - 2 * tr_covmean) 28 | 29 | 30 | def svfid(src, tar): 31 | """ 32 | src/tar: tensor of F x H x W x 3, in (0, 255), rgb 33 | """ 34 | global C3D_network 35 | if C3D_network is None: 36 | C3D_network = C3D() 37 | C3D_network.load_state_dict(torch.load('evaluations/c3d.pickle')) 38 | C3D_network.eval() 39 | 40 | C3D_network.to(src.device) 41 | with torch.no_grad(): 42 | src = src.permute(3, 0, 1, 2)[None] 43 | tar = tar.permute(3, 0, 1, 2)[None] # c, frm, h, w 44 | src_feat = C3D_network(src) 45 | tar_feat = C3D_network(tar) 46 | 47 | src_feat = 
src_feat[0, :50].permute(2, 3, 1, 0).flatten(0, 1) 48 | tar_feat = tar_feat[0, :50].permute(2, 3, 1, 0).flatten(0, 1) 49 | 50 | def batch_cov(points): 51 | B, N, D = points.size() 52 | mean = points.mean(dim=1, keepdims=True) 53 | diffs = (points - mean).reshape(B * N, D) 54 | prods = torch.bmm(diffs.unsqueeze(2), diffs.unsqueeze(1)).reshape(B, N, D, D) 55 | bcov = prods.sum(dim=1) / (N - 1) # Unbiased estimate 56 | return mean[:, 0], bcov # (B, D, D) 57 | 58 | src_mean, src_cov = batch_cov(src_feat) 59 | tar_mean, tar_cov = batch_cov(tar_feat) 60 | 61 | fid = calculate_batched_frechet_distance(src_mean.cpu().numpy(), 62 | src_cov.cpu().numpy(), 63 | tar_mean.cpu().numpy(), 64 | tar_cov.cpu().numpy()) 65 | return fid.mean() 66 | -------------------------------------------------------------------------------- /evaluations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/evaluations/__init__.py -------------------------------------------------------------------------------- /evaluations/c3d_test.py: -------------------------------------------------------------------------------- 1 | """ How to use C3D network. """ 2 | import numpy as np 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from os.path import join 8 | from glob import glob 9 | 10 | import skimage.io as io 11 | from skimage.transform import resize 12 | 13 | from C3D_model import C3D 14 | 15 | 16 | def get_sport_clip(clip_name, verbose=True): 17 | """ 18 | Loads a clip to be fed to C3D for classification. 19 | 20 | Parameters 21 | ---------- 22 | clip_name: str 23 | the name of the clip (subfolder in 'data'). 24 | verbose: bool 25 | if True, shows the unrolled clip (default is True). 26 | Returns 27 | ------- 28 | Tensor 29 | a pytorch batch (n, ch, fr, h, w). 30 | """ 31 | 32 | clip = sorted(glob(join('D:\\MSI_NB\\source\\repos\\c3d-pytorch\\data', clip_name, '*.png'))) 33 | clip = np.array([resize(io.imread(frame), output_shape=(112, 200), preserve_range=True) for frame in clip]) 34 | clip = clip[:, :, 44:44 + 112, :] # crop centrally 35 | 36 | if verbose: 37 | clip_img = np.reshape(clip.transpose(1, 0, 2, 3), (112, 16 * 112, 3)) 38 | io.imshow(clip_img.astype(np.uint8)) 39 | io.show() 40 | 41 | clip = clip.transpose(3, 0, 1, 2) # ch, fr, h, w 42 | clip = np.expand_dims(clip, axis=0) # batch axis 43 | clip = np.float32(clip) 44 | 45 | return torch.from_numpy(clip) 46 | 47 | 48 | def read_labels_from_file(filepath): 49 | """ 50 | Reads Sport1M labels from file 51 | 52 | Parameters 53 | ---------- 54 | filepath: str 55 | the file. 56 | 57 | Returns 58 | ------- 59 | list 60 | list of sport names. 61 | """ 62 | with open(filepath, 'r') as f: 63 | labels = [line.strip() for line in f.readlines()] 64 | return labels 65 | 66 | 67 | def main(): 68 | """ 69 | Main function. 
70 | """ 71 | 72 | # load a clip to be predicted 73 | X = get_sport_clip('roger') 74 | X = Variable(X) 75 | X = X.cuda() 76 | 77 | # get network pretrained model 78 | net = C3D() 79 | net.load_state_dict(torch.load('c3d.pickle')) 80 | net.cuda() 81 | net.evaluate() 82 | 83 | # perform prediction 84 | prediction = net(X) 85 | prediction = prediction.data.cpu().numpy() 86 | 87 | # read labels 88 | labels = read_labels_from_file('labels.txt') 89 | 90 | # print top predictions 91 | top_inds = prediction[0].argsort()[::-1][:5] # reverse sort and take five largest items 92 | print('\nTop 5:') 93 | for i in top_inds: 94 | print('{:.5f} {}'.format(prediction[0][i], labels[i])) 95 | 96 | 97 | # entry point 98 | if __name__ == '__main__': 99 | main() -------------------------------------------------------------------------------- /evaluations/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | # from torch.autograd import Variable 8 | 9 | from .lpips import * 10 | -------------------------------------------------------------------------------- /evaluations/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from . import pretrained_networks as pn 9 | import torch.nn 10 | 11 | 12 | def normalize_tensor(in_feat, eps=1e-10): 13 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 14 | return in_feat / (norm_factor + eps) 15 | 16 | 17 | def l2(p0, p1, range=255.): 18 | return .5 * np.mean((p0 / range - p1 / range) ** 2) 19 | 20 | 21 | def psnr(p0, p1, peak=255.): 22 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) 23 | 24 | 25 | def dssim(p0, p1, range=255.): 26 | from skimage.measure import compare_ssim 27 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 28 | 29 | 30 | def rgb2lab(in_img, mean_cent=False): 31 | from skimage import color 32 | img_lab = color.rgb2lab(in_img) 33 | if mean_cent: 34 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 35 | return img_lab 36 | 37 | 38 | def tensor2np(tensor_obj): 39 | # change dimension of a tensor object into a numpy array 40 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 41 | 42 | 43 | def np2tensor(np_obj): 44 | # change dimenion of np array into tensor array 45 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 46 | 47 | 48 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 49 | # image tensor to lab tensor 50 | from skimage import color 51 | 52 | img = tensor2im(image_tensor) 53 | img_lab = color.rgb2lab(img) 54 | if mc_only: 55 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 56 | if to_norm and not mc_only: 57 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 58 | img_lab = img_lab / 100. 59 | 60 | return np2tensor(img_lab) 61 | 62 | 63 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 64 | from skimage import color 65 | import warnings 66 | warnings.filterwarnings("ignore") 67 | 68 | lab = tensor2np(lab_tensor) * 100. 69 | lab[:, :, 0] = lab[:, :, 0] + 50 70 | 71 | rgb_back = 255. 
* np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 72 | if return_inbnd: 73 | # convert back to lab, see if we match 74 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 75 | mask = 1. * np.isclose(lab_back, lab, atol=2.) 76 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 77 | return im2tensor(rgb_back), mask 78 | else: 79 | return im2tensor(rgb_back) 80 | 81 | 82 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 83 | image_numpy = image_tensor[0].cpu().float().numpy() 84 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 85 | return image_numpy.astype(imtype) 86 | 87 | 88 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 89 | return torch.tensor((image / factor - cent) 90 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 91 | 92 | 93 | def tensor2vec(vector_tensor): 94 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 95 | 96 | 97 | def voc_ap(rec, prec, use_07_metric=False): 98 | """ ap = voc_ap(rec, prec, [use_07_metric]) 99 | Compute VOC AP given precision and recall. 100 | If use_07_metric is true, uses the 101 | VOC 07 11 point method (default:False). 102 | """ 103 | if use_07_metric: 104 | # 11 point metric 105 | ap = 0. 106 | for t in np.arange(0., 1.1, 0.1): 107 | if np.sum(rec >= t) == 0: 108 | p = 0 109 | else: 110 | p = np.max(prec[rec >= t]) 111 | ap = ap + p / 11. 112 | else: 113 | # correct AP calculation 114 | # first append sentinel values at the end 115 | mrec = np.concatenate(([0.], rec, [1.])) 116 | mpre = np.concatenate(([0.], prec, [0.])) 117 | 118 | # compute the precision envelope 119 | for i in range(mpre.size - 1, 0, -1): 120 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 121 | 122 | # to calculate area under PR curve, look for points 123 | # where X axis (recall) changes value 124 | i = np.where(mrec[1:] != mrec[:-1])[0] 125 | 126 | # and sum (\Delta recall) * prec 127 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 128 | return ap 129 | 130 | def spatial_average(in_tens, keepdim=True): 131 | return in_tens.mean([2, 3], keepdim=keepdim) 132 | 133 | 134 | def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W 135 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 136 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 137 | 138 | 139 | # Learned perceptual metric 140 | class LPIPS(nn.Module): 141 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 142 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 143 | # lpips - [True] means with linear calibration on top of base network 144 | # pretrained - [True] means load linear weights 145 | 146 | super(LPIPS, self).__init__() 147 | if verbose: 148 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % 149 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 150 | 151 | self.pnet_type = net 152 | self.pnet_tune = pnet_tune 153 | self.pnet_rand = pnet_rand 154 | self.spatial = spatial 155 | self.lpips = lpips # false means baseline of just averaging all layers 156 | self.version = version 157 | self.scaling_layer = ScalingLayer() 158 | 159 | if self.pnet_type in ['vgg', 'vgg16']: 160 | net_type = pn.vgg16 161 | self.chns = [64, 128, 256, 512, 512] 162 | elif self.pnet_type == 'alex': 163 | net_type = pn.alexnet 164 | self.chns = [64, 192, 384, 256, 256] 165 | elif self.pnet_type == 'squeeze': 166 | net_type = pn.squeezenet 167 | self.chns = 
[64, 128, 256, 384, 384, 512, 512] 168 | self.L = len(self.chns) 169 | 170 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 171 | 172 | if lpips: 173 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 174 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 175 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 176 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 177 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 178 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 179 | if self.pnet_type == 'squeeze': # 7 layers for squeezenet 180 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 181 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 182 | self.lins += [self.lin5, self.lin6] 183 | self.lins = nn.ModuleList(self.lins) 184 | 185 | if pretrained: 186 | if model_path is None: 187 | import inspect 188 | import os 189 | model_path = os.path.abspath( 190 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 191 | 192 | if verbose: 193 | print('Loading model from: %s' % model_path) 194 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 195 | 196 | if eval_mode: 197 | self.eval() 198 | 199 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 200 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 201 | in0 = 2 * in0 - 1 202 | in1 = 2 * in1 - 1 203 | 204 | # v0.0 - original release had a bug, where input was not scaled 205 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( 206 | in0, in1) 207 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 208 | feats0, feats1, diffs = {}, {}, {} 209 | 210 | for kk in range(self.L): 211 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 212 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 213 | 214 | if self.lpips: 215 | if self.spatial: 216 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 217 | else: 218 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 219 | else: 220 | if self.spatial: 221 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 222 | else: 223 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] 224 | 225 | val = res[0] 226 | for l in range(1, self.L): 227 | val += res[l] 228 | 229 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 230 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 231 | # for kk in range(self.L): 232 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 233 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 234 | # a = a/self.L 235 | # from IPython import embed 236 | # embed() 237 | # return 10*torch.log10(b/a) 238 | 239 | if retPerLayer: 240 | return val, res 241 | else: 242 | return val 243 | 244 | 245 | class ScalingLayer(nn.Module): 246 | def __init__(self): 247 | super(ScalingLayer, self).__init__() 248 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 249 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 250 | 251 | def forward(self, inp): 252 | return (inp - self.shift) / self.scale 253 | 254 | 255 | class NetLinLayer(nn.Module): 256 | ''' A single linear 
layer which does a 1x1 conv ''' 257 | 258 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 259 | super(NetLinLayer, self).__init__() 260 | 261 | layers = [nn.Dropout(), ] if use_dropout else [] 262 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 263 | self.model = nn.Sequential(*layers) 264 | 265 | def forward(self, x): 266 | return self.model(x) 267 | 268 | 269 | class Dist2LogitLayer(nn.Module): 270 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 271 | 272 | def __init__(self, chn_mid=32, use_sigmoid=True): 273 | super(Dist2LogitLayer, self).__init__() 274 | 275 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] 276 | layers += [nn.LeakyReLU(0.2, True), ] 277 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] 278 | layers += [nn.LeakyReLU(0.2, True), ] 279 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] 280 | if use_sigmoid: 281 | layers += [nn.Sigmoid(), ] 282 | self.model = nn.Sequential(*layers) 283 | 284 | def forward(self, d0, d1, eps=0.1): 285 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) 286 | 287 | 288 | class BCERankingLoss(nn.Module): 289 | def __init__(self, chn_mid=32): 290 | super(BCERankingLoss, self).__init__() 291 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 292 | # self.parameters = list(self.net.parameters()) 293 | self.loss = torch.nn.BCELoss() 294 | 295 | def forward(self, d0, d1, judge): 296 | per = (judge + 1.) / 2. 297 | self.logit = self.net.forward(d0, d1) 298 | return self.loss(self.logit, per) 299 | 300 | 301 | # L2, DSSIM metrics 302 | class FakeNet(nn.Module): 303 | def __init__(self, use_gpu=True, colorspace='Lab'): 304 | super(FakeNet, self).__init__() 305 | self.use_gpu = use_gpu 306 | self.colorspace = colorspace 307 | 308 | 309 | class L2(FakeNet): 310 | def forward(self, in0, in1, retPerLayer=None): 311 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 312 | 313 | if self.colorspace == 'RGB': 314 | (N, C, X, Y) = in0.size() 315 | value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), 316 | dim=3).view(N) 317 | return value 318 | elif self.colorspace == 'Lab': 319 | value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), 320 | tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 321 | 'float') 322 | ret_var = Variable(torch.Tensor((value,))) 323 | if self.use_gpu: 324 | ret_var = ret_var.cuda() 325 | return ret_var 326 | 327 | 328 | class DSSIM(FakeNet): 329 | 330 | def forward(self, in0, in1, retPerLayer=None): 331 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 332 | 333 | if self.colorspace == 'RGB': 334 | value = dssim(1. * tensor2im(in0.data), 1. 
* tensor2im(in1.data), range=255.).astype( 335 | 'float') 336 | elif self.colorspace == 'Lab': 337 | value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), 338 | tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 339 | 'float') 340 | ret_var = Variable(torch.Tensor((value,))) 341 | if self.use_gpu: 342 | ret_var = ret_var.cuda() 343 | return ret_var 344 | 345 | 346 | def print_network(net): 347 | num_params = 0 348 | for param in net.parameters(): 349 | num_params += param.numel() 350 | print('Network', net) 351 | print('Total number of parameters: %d' % num_params) 352 | -------------------------------------------------------------------------------- /evaluations/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2, 5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) 52 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 
12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | 98 | class vgg16(torch.nn.Module): 99 | def __init__(self, requires_grad=False, pretrained=True): 100 | super(vgg16, self).__init__() 101 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.N_slices = 5 108 | for x in range(4): 109 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(4, 9): 111 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(9, 16): 113 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(16, 23): 115 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(23, 30): 117 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h = self.slice1(X) 124 | h_relu1_2 = h 125 | h = self.slice2(h) 126 | h_relu2_2 = h 127 | h = self.slice3(h) 128 | h_relu3_3 = h 129 | h = self.slice4(h) 130 | h_relu4_3 = h 131 | h = self.slice5(h) 132 | h_relu5_3 = h 133 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 134 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 135 | 136 | return out 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if (num == 18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif (num == 34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif (num == 50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif (num == 101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif (num == 152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /evaluations/lpips/weights/v0.1/alex.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/evaluations/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /evaluations/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/evaluations/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /evaluations/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/evaluations/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /evaluations/metrics.py: -------------------------------------------------------------------------------- 1 | from skimage import metrics 2 | import torch.hub 3 | from evaluations.lpips.lpips import LPIPS 4 | import os 5 | import numpy as np 6 | 7 | photometric = { 8 | "mse": None, 9 | "ssim": None, 10 | "psnr": None, 11 | "lpips": None 12 | } 13 | 14 | 15 | def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor, 16 | metric="mse", mask=None, range01=True): 17 | """ 18 | Args: 19 | im1t: tensor that has shape of batched images, *range from [-1, 1]* 20 | im2t: tensor that has shape of batched images, *range from [-1, 1]* 21 | metric: choose among mse, psnr, ssim, lpips 22 | mask: optional mask, tensor of shape [B, H, W] or [B, H, W, 1] 23 | """ 24 | if metric not in photometric.keys(): 25 | raise RuntimeError(f"img_utils:: metric {metric} not recognized") 26 | if photometric[metric] is None: 27 | if metric == "mse": 28 | photometric[metric] = metrics.mean_squared_error 29 | elif metric == "ssim": 30 | photometric[metric] = metrics.structural_similarity 31 | elif metric == "psnr": 32 | photometric[metric] = metrics.peak_signal_noise_ratio 33 | elif metric == "lpips": 34 | photometric[metric] = LPIPS().cpu() 35 | 36 | if mask is not None: 37 | if mask.dim() == 3: 38 | mask = mask.unsqueeze(1) 39 | if mask.shape[1] == 1: 40 | mask = mask.permute(0, 2, 3, 1).cpu() 41 | batchsz, hei, wid, _ = mask.shape 42 | 43 | if range01: 44 | im1t = im1t * 2 - 1 45 | im2t = im2t * 2 - 1 46 | 47 | im1t = im1t.clamp(-1, 1).detach().cpu() 48 | im2t = im2t.clamp(-1, 1).detach().cpu() 49 | 50 | if im1t.shape[-1] == 3: 51 | im1t = im1t.permute(0, 3, 1, 2) 52 | im2t = im2t.permute(0, 3, 1, 2) 53 | 54 | if mask is not None: 55 | im1t = im1t * mask.permute(0, 3, 1, 2) 56 | im2t = im2t * mask.permute(0, 3, 1, 2) 57 | 58 | im1 = im1t.permute(0, 2, 3, 1).numpy() 59 | im2 = im2t.permute(0, 2, 3, 1).numpy() 60 | mask = mask.numpy() 61 | batchsz, hei, wid, _ = im1.shape 62 | values = [] 63 | 64 | for i in range(batchsz): 65 | if metric in ["mse", "psnr"]: 66 | value = photometric[metric]( 67 | im1[i], im2[i] 68 | ) 69 | if mask is not None: 70 | pixelnum = mask[i % len(mask), ..., 0].sum() 71 | if metric == "mse": 72 | value = value * hei * wid / pixelnum 73 | else: 74 | value = value - 10 * np.log10(hei * wid / pixelnum) 75 | elif metric in ["ssim"]: 76 | value, ssimmap = photometric["ssim"]( 77 | im1[i], im2[i], channel_axis=-1, full=True 78 | ) 79 | if mask is not None: 80 | value = (ssimmap * mask[i % len(mask)]).sum() / mask[i % len(mask)].sum() / 3 81 | 
elif metric in ["lpips"]: 82 | value = photometric[metric]( 83 | im1t[i:i + 1], im2t[i:i + 1] 84 | ) 85 | else: 86 | raise NotImplementedError 87 | values.append(value) 88 | 89 | return sum(values) / len(values) 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10 2 | imageio 3 | imageio-ffmpeg 4 | tensorboard 5 | setuptools==59.5.0 6 | opencv-python 7 | unfoldNd 8 | pytorch_msssim 9 | scikit-image 10 | configargparse 11 | tqdm -------------------------------------------------------------------------------- /run_all.sh: -------------------------------------------------------------------------------- 1 | { 2 | 3 | ############################################# 4 | # Run all the 1st stage 5 | ############################################# 6 | 7 | CUDA_VISIBLE_DEVICES=0 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/108fall1.txt & 8 | CUDA_VISIBLE_DEVICES=1 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/108fall2.txt & 9 | CUDA_VISIBLE_DEVICES=2 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/108fall3.txt & 10 | CUDA_VISIBLE_DEVICES=3 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/108fall4.txt & 11 | CUDA_VISIBLE_DEVICES=4 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/108fall5.txt & 12 | CUDA_VISIBLE_DEVICES=5 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/110grasstree.txt & 13 | CUDA_VISIBLE_DEVICES=6 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/110pillarrm.txt & 14 | CUDA_VISIBLE_DEVICES=7 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1017palm.txt & 15 | 16 | wait 17 | 18 | CUDA_VISIBLE_DEVICES=8 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1017yuan.txt & 19 | CUDA_VISIBLE_DEVICES=7 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1020rock.txt & 20 | CUDA_VISIBLE_DEVICES=6 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1020ustfall1.txt & 21 | CUDA_VISIBLE_DEVICES=5 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1020ustfall2.txt & 22 | CUDA_VISIBLE_DEVICES=4 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1101grass.txt & 23 | CUDA_VISIBLE_DEVICES=3 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/1101towerd.txt & 24 | CUDA_VISIBLE_DEVICES=1 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/ustfallclose.txt & 25 | CUDA_VISIBLE_DEVICES=0 python train_3d.py --config configs/mpi_base.txt --config1 configs/mpis/usttap.txt & 26 | 27 | wait 28 | 29 | ############################################# 30 | # Run all the 2nd stage 31 | ############################################# 32 | 33 | CUDA_VISIBLE_DEVICES=0 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall1.txt & 34 | CUDA_VISIBLE_DEVICES=1 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall2.txt & 35 | CUDA_VISIBLE_DEVICES=2 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall3.txt & 36 | CUDA_VISIBLE_DEVICES=3 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall4.txt & 37 | CUDA_VISIBLE_DEVICES=4 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall5.txt & 38 | CUDA_VISIBLE_DEVICES=5 
python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/110grasstree.txt & 39 | CUDA_VISIBLE_DEVICES=6 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/110pillar.txt & 40 | CUDA_VISIBLE_DEVICES=7 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1017palm.txt & 41 | 42 | 43 | wait 44 | 45 | CUDA_VISIBLE_DEVICES=8 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1017yuan.txt & 46 | CUDA_VISIBLE_DEVICES=7 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1020rock.txt & 47 | CUDA_VISIBLE_DEVICES=6 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1020ustfall1.txt & 48 | CUDA_VISIBLE_DEVICES=5 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1020ustfall2.txt & 49 | CUDA_VISIBLE_DEVICES=4 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1101grass.txt & 50 | CUDA_VISIBLE_DEVICES=3 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/1101towerd.txt & 51 | CUDA_VISIBLE_DEVICES=1 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/ustfallclose.txt & 52 | CUDA_VISIBLE_DEVICES=0 python train_3dvid.py --config configs/mpv_base.txt --config1 configs/mpvs/usttap.txt & 53 | 54 | 55 | ############################################# 56 | #export all meshes 57 | ############################################# 58 | 59 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall1.txt & 60 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall2.txt 61 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall3.txt & 62 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall4.txt 63 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/108fall5.txt & 64 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/110grasstree.txt 65 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/110pillar.txt & 66 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1017palm.txt 67 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1017yuan.txt & 68 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1020rock.txt 69 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1020ustfall1.txt & 70 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1020ustfall2.txt 71 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1101grass.txt & 72 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/1101towerd.txt 73 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/ustfallclose.txt & 74 | python script_export_mesh.py --config configs/mpv_base.txt --config1 configs/mpvs/usttap.txt 75 | 76 | 77 | 78 | wait 79 | 80 | exit 81 | } 82 | -------------------------------------------------------------------------------- /scripts/colmaps/__init__.py: -------------------------------------------------------------------------------- 1 | from colmaps.llffposes.pose_utils import gen_poses -------------------------------------------------------------------------------- /scripts/colmaps/llffposes/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/scripts/colmaps/llffposes/__init__.py -------------------------------------------------------------------------------- /scripts/colmaps/llffposes/colmap_read_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
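    Example: read_next_bytes(fid, 8, "Q") yields a 1-tuple holding one little-endian unsigned 64-bit integer, as used below to read the leading element counts.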
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec 297 | 298 | 299 | def main(): 300 | if len(sys.argv) != 3: 301 | print("Usage: python read_model.py path/to/model/folder [.txt,.bin]") 302 | return 303 | 304 | cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2]) 305 | 306 | print("num_cameras:", len(cameras)) 307 | print("num_images:", len(images)) 308 | print("num_points3D:", len(points3D)) 309 | 310 | 311 | if __name__ == "__main__": 312 | main() 313 | -------------------------------------------------------------------------------- /scripts/colmaps/llffposes/colmap_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | 5 | # $ DATASET_PATH=/path/to/dataset 6 | 7 | # $ colmap feature_extractor \ 8 | # --database_path $DATASET_PATH/database.db \ 9 | # --image_path $DATASET_PATH/images 10 | 11 | # $ colmap exhaustive_matcher \ 12 | # --database_path $DATASET_PATH/database.db 13 | 14 | # $ mkdir $DATASET_PATH/sparse 15 | 16 | # $ colmap mapper \ 17 | # --database_path $DATASET_PATH/database.db \ 18 | # --image_path $DATASET_PATH/images \ 19 | # --output_path $DATASET_PATH/sparse 20 | 21 | # $ mkdir $DATASET_PATH/dense 22 | colmap_path = "D:\\MSI_NB\\source\\util\\COLMAP-3.6-exe\\COLMAP.bat" 23 | 24 | 25 | def run_colmap(basedir, match_type, pipeline, imagedir='images', share_intrin=True): 26 | logfile_name = os.path.join(basedir, 'colmap_output.txt') 27 | logfile = open(logfile_name, 'w') 28 | 29 | if "feature_extractor" in pipeline: 30 | feature_extractor_args = [ 31 | colmap_path, 'feature_extractor', 32 | '--database_path', os.path.join(basedir, 'database.db'), 33 | '--image_path', os.path.join(basedir, imagedir), 34 | '--ImageReader.camera_model', 'SIMPLE_PINHOLE' 35 | # '--SiftExtraction.use_gpu', '0', 36 | ] 37 | if share_intrin: 38 | feature_extractor_args += ['--ImageReader.single_camera', '1'] 39 | feat_output = (subprocess.check_output(feature_extractor_args, universal_newlines=True)) 40 | logfile.write(feat_output) 41 | print('Features extracted') 42 | 43 | if "matcher" in pipeline: 44 | exhaustive_matcher_args = [ 45 | colmap_path, match_type, 46 | '--database_path', os.path.join(basedir, 'database.db'), 47 | ] 48 | 49 | match_output = (subprocess.check_output(exhaustive_matcher_args, universal_newlines=True)) 50 | logfile.write(match_output) 51 | print('Features matched') 52 | 53 | if 
"mapper" in pipeline: 54 | p = os.path.join(basedir, 'sparse') 55 | if not os.path.exists(p): 56 | os.makedirs(p) 57 | 58 | # mapper_args = [ 59 | # 'colmap', 'mapper', 60 | # '--database_path', os.path.join(basedir, 'database.db'), 61 | # '--image_path', os.path.join(basedir, 'images'), 62 | # '--output_path', os.path.join(basedir, 'sparse'), 63 | # '--Mapper.num_threads', '16', 64 | # '--Mapper.init_min_tri_angle', '4', 65 | # ] 66 | mapper_args = [ 67 | colmap_path, 'mapper', 68 | '--database_path', os.path.join(basedir, 'database.db'), 69 | '--image_path', os.path.join(basedir, imagedir), 70 | '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6 71 | '--Mapper.num_threads', '12', 72 | '--Mapper.init_min_tri_angle', '4', 73 | '--Mapper.multiple_models', '0', 74 | # '--Mapper.extract_colors', '0', 75 | ] 76 | 77 | map_output = (subprocess.check_output(mapper_args, universal_newlines=True)) 78 | 79 | logfile.write(map_output) 80 | 81 | if "convert" in pipeline: 82 | converter_args = [ 83 | colmap_path, 'model_converter', 84 | '--input_path', os.path.join(basedir, 'sparse/0'), 85 | '--output_path', os.path.join(basedir, 'sparse/0'), 86 | '--output_type', 'TXT', 87 | ] 88 | 89 | converter_output = (subprocess.check_output(converter_args, universal_newlines=True)) 90 | print('Txt model converted') 91 | 92 | logfile.write(converter_output) 93 | logfile.close() 94 | print('Sparse map created') 95 | 96 | print('Finished running COLMAP, see {} for logs'.format(logfile_name)) 97 | -------------------------------------------------------------------------------- /scripts/colmaps/llffposes/pose_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import imageio 5 | from glob import glob 6 | from numpy.lib.arraysetops import isin 7 | from numpy.lib.function_base import iterable 8 | import skimage.transform 9 | 10 | from .colmap_wrapper import run_colmap 11 | from . 
import colmap_read_model as read_model 12 | 13 | 14 | def load_colmap_data(realdir): 15 | camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin') 16 | camdata = read_model.read_cameras_binary(camerasfile) 17 | 18 | # cam = camdata[camdata.keys()[0]] 19 | list_of_keys = list(camdata.keys()) 20 | cams = [camdata[k] for k in list_of_keys] 21 | print('Cameras', len(cams)) 22 | 23 | hwf = [(cam.height, cam.width, cam.params[0]) for cam in cams] 24 | # w, h, f = factor * w, factor * h, factor * f 25 | hwf = np.array(hwf).reshape([-1, 3]).T 26 | 27 | imagesfile = os.path.join(realdir, 'sparse/0/images.bin') 28 | imdata = read_model.read_images_binary(imagesfile) 29 | 30 | w2c_mats = [] 31 | bottom = np.array([0, 0, 0, 1.]).reshape([1, 4]) 32 | 33 | names = [imdata[k].name for k in imdata] 34 | print('Images #', len(names)) 35 | perm = np.argsort(names) 36 | for k in imdata: 37 | im = imdata[k] 38 | R = im.qvec2rotmat() 39 | t = im.tvec.reshape([3, 1]) 40 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 41 | w2c_mats.append(m) 42 | 43 | w2c_mats = np.stack(w2c_mats, 0) 44 | c2w_mats = np.linalg.inv(w2c_mats) 45 | 46 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0]) 47 | hwf = hwf[:, None, :] 48 | if hwf.shape[-1] != poses.shape[-1]: 49 | hwf = hwf.repeat(poses.shape[-1], -1) 50 | poses = np.concatenate([poses, hwf], 1) 51 | 52 | points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin') 53 | pts3d = read_model.read_points3d_binary(points3dfile) 54 | 55 | # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t] 56 | poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], poses[:, 3:4, :], poses[:, 4:5, :]], 57 | 1) 58 | 59 | return poses, pts3d, perm, names 60 | 61 | 62 | def save_poses(basedir, poses, pts3d, perm): 63 | pts_arr = [] 64 | vis_arr = [] 65 | for k in pts3d: 66 | pts_arr.append(pts3d[k].xyz) 67 | cams = [0] * poses.shape[-1] 68 | for ind in pts3d[k].image_ids: 69 | if len(cams) < ind - 1: 70 | print('ERROR: the correct camera poses for current points cannot be accessed') 71 | return 72 | cams[ind - 1] = 1 73 | vis_arr.append(cams) 74 | 75 | pts_arr = np.array(pts_arr) 76 | vis_arr = np.array(vis_arr) 77 | print('Points', pts_arr.shape, 'Visibility', vis_arr.shape) 78 | 79 | zvals = np.sum(-(pts_arr[:, np.newaxis, :].transpose([2, 0, 1]) - poses[:3, 3:4, :]) * poses[:3, 2:3, :], 0) 80 | valid_z = zvals[vis_arr == 1] 81 | print('Depth stats', valid_z.min(), valid_z.max(), valid_z.mean()) 82 | 83 | save_arr = [] 84 | for i in perm: 85 | vis = vis_arr[:, i] 86 | zs = zvals[:, i] 87 | zs = zs[vis == 1] 88 | close_depth, inf_depth = np.percentile(zs, .1), np.percentile(zs, 99.9) 89 | # print( i, close_depth, inf_depth ) 90 | 91 | save_arr.append(np.concatenate([poses[..., i].ravel(), np.array([close_depth, inf_depth])], 0)) 92 | save_arr = np.array(save_arr) 93 | 94 | np.save(os.path.join(basedir, 'poses_bounds.npy'), save_arr) 95 | 96 | 97 | def minify_v0(basedir, factors=[], resolutions=[]): 98 | needtoload = False 99 | for r in factors: 100 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 101 | if not os.path.exists(imgdir): 102 | needtoload = True 103 | for r in resolutions: 104 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 105 | if not os.path.exists(imgdir): 106 | needtoload = True 107 | if not needtoload: 108 | return 109 | 110 | def downsample(imgs, f): 111 | sh = list(imgs.shape) 112 | sh = sh[:-3] + [sh[-3] // f, f, sh[-2] // f, f, sh[-1]] 113 | imgs = np.reshape(imgs, sh) 114 | imgs = np.mean(imgs, (-2, 
-4)) 115 | return imgs 116 | 117 | imgdir = os.path.join(basedir, 'images') 118 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 119 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 120 | imgs = np.stack([imageio.imread(img) / 255. for img in imgs], 0) 121 | 122 | for r in factors + resolutions: 123 | if isinstance(r, int): 124 | name = 'images_{}'.format(r) 125 | else: 126 | name = 'images_{}x{}'.format(r[1], r[0]) 127 | imgdir = os.path.join(basedir, name) 128 | if os.path.exists(imgdir): 129 | continue 130 | print('Minifying', r, basedir) 131 | 132 | if isinstance(r, int): 133 | imgs_down = downsample(imgs, r) 134 | else: 135 | imgs_down = skimage.transform.resize(imgs, [imgs.shape[0], r[0], r[1], imgs.shape[-1]], 136 | order=1, mode='constant', cval=0, clip=True, preserve_range=False, 137 | anti_aliasing=True, anti_aliasing_sigma=None) 138 | 139 | os.makedirs(imgdir) 140 | for i in range(imgs_down.shape[0]): 141 | imageio.imwrite(os.path.join(imgdir, 'image{:03d}.png'.format(i)), (255 * imgs_down[i]).astype(np.uint8)) 142 | 143 | 144 | def minify(basedir, factors=[], resolutions=[]): 145 | needtoload = False 146 | for r in factors: 147 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 148 | if not os.path.exists(imgdir): 149 | needtoload = True 150 | for r in resolutions: 151 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 152 | if not os.path.exists(imgdir): 153 | needtoload = True 154 | if not needtoload: 155 | return 156 | 157 | from shutil import copy 158 | from subprocess import check_output 159 | 160 | imgdir = os.path.join(basedir, 'images') 161 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 162 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 163 | imgdir_orig = imgdir 164 | 165 | wd = os.getcwd() 166 | 167 | for r in factors + resolutions: 168 | 169 | if isinstance(r, int): 170 | name = 'images_{}'.format(r) 171 | resizearg = '{}%'.format(int(100. / r)) 172 | resizeargcv = 1. 
/ r 173 | else: 174 | name = 'images_{}x{}'.format(r[1], r[0]) 175 | resizearg = '{}x{}'.format(r[0], r[1]) 176 | resizeargcv = (r[1], r[0]) 177 | imgdir = os.path.join(basedir, name) 178 | if os.path.exists(imgdir): 179 | continue 180 | 181 | print('Minifying', r, basedir) 182 | print("now using opencv") 183 | import cv2 184 | from glob import glob 185 | os.makedirs(imgdir) 186 | 187 | for image_path in glob(imgdir_orig + "/*.jpg"): 188 | tempimg = cv2.imread(image_path) 189 | if isinstance(resizeargcv, tuple): 190 | tempimg = cv2.resize(tempimg, resizeargcv, interpolation=cv2.INTER_AREA) 191 | else: 192 | tempimg = cv2.resize(tempimg, None, fx=resizeargcv, fy=resizeargcv, interpolation=cv2.INTER_AREA) 193 | 194 | image_file = os.path.basename(image_path).split('.')[0] 195 | output_path = imgdir + '/' + image_file + ".png" 196 | cv2.imwrite(output_path, tempimg) 197 | 198 | # check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 199 | # 200 | # ext = imgs[0].split('.')[-1] 201 | # args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 202 | # print(args) 203 | # os.chdir(imgdir) 204 | # check_output(args, shell=True) 205 | # os.chdir(wd) 206 | 207 | # if ext != 'png': 208 | # check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 209 | # print('Removed duplicates') 210 | print('Done') 211 | 212 | 213 | def load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 214 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 215 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 216 | bds = poses_arr[:, -2:].transpose([1, 0]) 217 | 218 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 219 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 220 | sh = imageio.imread(img0).shape 221 | 222 | sfx = '' 223 | 224 | if factor is not None: 225 | sfx = '_{}'.format(factor) 226 | minify(basedir, factors=[factor]) 227 | factor = factor 228 | elif height is not None: 229 | factor = sh[0] / float(height) 230 | width = int(sh[1] / factor) 231 | minify(basedir, resolutions=[[height, width]]) 232 | sfx = '_{}x{}'.format(width, height) 233 | elif width is not None: 234 | factor = sh[1] / float(width) 235 | height = int(sh[0] / factor) 236 | minify(basedir, resolutions=[[height, width]]) 237 | sfx = '_{}x{}'.format(width, height) 238 | else: 239 | factor = 1 240 | 241 | imgdir = os.path.join(basedir, 'images' + sfx) 242 | if not os.path.exists(imgdir): 243 | print(imgdir, 'does not exist, returning') 244 | return 245 | 246 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if 247 | f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 248 | if poses.shape[-1] != len(imgfiles): 249 | print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1])) 250 | return 251 | 252 | sh = imageio.imread(imgfiles[0]).shape 253 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 254 | poses[2, 4, :] = poses[2, 4, :] * 1. / factor 255 | 256 | if not load_imgs: 257 | return poses, bds 258 | 259 | # imgs = [imageio.imread(f, ignoregamma=True)[...,:3]/255. for f in imgfiles] 260 | def imread(f): 261 | if f.endswith('png'): 262 | return imageio.imread(f, ignoregamma=True) 263 | else: 264 | return imageio.imread(f) 265 | 266 | imgs = imgs = [imread(f)[..., :3] / 255. 
for f in imgfiles] 267 | imgs = np.stack(imgs, -1) 268 | 269 | print('Loaded image data', imgs.shape, poses[:, -1, 0]) 270 | return poses, bds, imgs 271 | 272 | 273 | def gen_poses(basedir, match_type, factors=None, usedown=False, share_intrin=True): 274 | if factors is not None and not os.path.exists(f"{basedir}/images_{factors}"): 275 | print("Minify") 276 | minify(basedir, [factors] if not isinstance(factors, list) else factors) 277 | 278 | if os.path.exists(os.path.join(basedir, "poses_bounds.npy")): 279 | print("exists poses_bounds.npy, will do nothing and exit") 280 | exit() 281 | files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']] 282 | if os.path.exists(os.path.join(basedir, 'sparse/0')): 283 | files_had = os.listdir(os.path.join(basedir, 'sparse/0')) 284 | else: 285 | files_had = [] 286 | if not all([f in files_had for f in files_needed]): 287 | print('Need to run COLMAP') 288 | run_colmap(basedir, match_type, ["feature_extractor", "matcher", "mapper"], 289 | imagedir=f"images_{factors}" if usedown else "images", 290 | share_intrin=share_intrin) 291 | else: 292 | print('Don\'t need to run COLMAP, only convert') 293 | run_colmap(basedir, match_type, []) 294 | 295 | print('Post-colmap') 296 | 297 | poses, pts3d, perm, names = load_colmap_data(basedir) 298 | # check whether all images are successfully registered 299 | all_names = glob(os.path.join(basedir, "images", "*.jpg")) + glob(os.path.join(basedir, "images", "*.png")) 300 | all_names = [os.path.basename(name_) for name_ in all_names] 301 | names = {name_.split('.')[0] for name_ in names} 302 | all_names = {name_.split('.')[0] for name_ in all_names} 303 | failed_names = all_names - names 304 | if len(failed_names) > 0: 305 | print("Oops! some images failed to register, the names are:\n") 306 | print(', '.join(failed_names) + '\n') 307 | print("Please delete and run again") 308 | exit() 309 | 310 | if usedown: # scale the hwf to the original resolution 311 | poses[:, -1, :] *= factors 312 | 313 | save_poses(basedir, poses, pts3d, perm) 314 | 315 | print('Done with imgs2poses') 316 | 317 | return True 318 | -------------------------------------------------------------------------------- /scripts/script_evaluate_ours.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | import math 6 | import torch.nn as nn 7 | import time 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torch.utils.data import DataLoader, Dataset 10 | from MPV import * 11 | 12 | from dataloader import load_mv_videos, poses_avg 13 | from utils import * 14 | import shutil 15 | from datetime import datetime 16 | import cv2 17 | from config_parser import config_parser 18 | from tqdm import tqdm, trange 19 | from copy import deepcopy 20 | from evaluations.SVFID import svfid 21 | from evaluations.LPIPS import compute_lpips, compute_lpips_slidewindow 22 | from evaluations.NNMSE import compute_nnerr 23 | 24 | # Flag 25 | COMPUTE_STATIC = True 26 | COMPUTE_DYN = True 27 | COMPUTE_LPIPS = True 28 | COMPUTE_NNMSE = True 29 | COMPUTE_LOOPQ = True 30 | COMPUTE_SVFID = True 31 | 32 | 33 | def evaluate(args): 34 | device = 'cuda:0' 35 | if args.gpu_num <= 0: 36 | device = 'cpu' 37 | print(f"Using CPU for training") 38 | 39 | expname = args.expname + args.expname_postfix 40 | print(f"Evaluating: {expname}") 41 | args.datadir = args.datadir.rstrip('/\\') 42 | if args.datadir.endswith("_loop"): 43 | print(f"Warning!!! 
Detected data pointing to the looping dataset, " 44 | f"will change from {args.datadir} to {args.datadir[:-5]}") 45 | args.datadir = args.datadir[:-5] 46 | 47 | datadir = os.path.join(args.prefix, args.datadir) 48 | expdir = os.path.join(args.prefix, args.expdir) 49 | videos, FPS, poses, intrins, bds, render_poses, render_intrins = \ 50 | load_mv_videos(basedir=datadir, 51 | factor=args.factor, 52 | bd_factor=(args.near_factor, args.far_factor), 53 | recenter=True) 54 | 55 | H, W = videos[0][0].shape[0:2] 56 | print('Loaded llff', H, W, poses.shape, intrins.shape, render_poses.shape, bds.shape) 57 | test_view = args.test_view_idx 58 | test_view = list(map(int, test_view.split(','))) if len(test_view) > 0 else list(range(len(videos))) 59 | # filter out test view 60 | videos = [videos[train_i] for train_i in test_view] 61 | videos = [np.array(vid) for vid in videos] 62 | poses = poses[test_view] 63 | intrins = intrins[test_view] 64 | print(f'Test view: {test_view}') 65 | V = len(videos) 66 | 67 | # generate loopmask 68 | loopmasks = [compute_loopable_mask(v_ / 255) for v_ in videos] 69 | loopmasks = [- np.array(m_).astype(np.float32) + 1 for m_ in loopmasks] 70 | ref_pose = poses_avg(poses)[:, :4] 71 | ref_extrin = pose2extrin_np(ref_pose) 72 | ref_intrin = intrins[0] 73 | ref_near, ref_far = bds.min(), bds.max() 74 | 75 | # Create nerf model 76 | if args.model_type == "MPMesh": 77 | args.mpi_h_scale = args.mpi_w_scale = 0.01 78 | nerf = MPMeshVid(args, H, W, ref_extrin, ref_intrin, ref_near, ref_far) 79 | else: 80 | raise RuntimeError(f"Unrecognized model type {args.model_type}") 81 | 82 | nerf = nn.DataParallel(nerf, list(range(args.gpu_num))) 83 | nerf.to(device) 84 | extrins = pose2extrin_np(poses) 85 | extrins = torch.tensor(extrins).float() 86 | poses = torch.tensor(poses).float() 87 | intrins = torch.tensor(intrins).float() 88 | 89 | ########################## 90 | # load from checkpoint 91 | ckpts = [os.path.join(expdir, expname, f) 92 | for f in sorted(os.listdir(os.path.join(expdir, expname))) if 'tar' in f] 93 | if len(ckpts) > 0: 94 | ckpt_path = ckpts[-1] 95 | print(f"Using ckpt {ckpt_path}") 96 | else: 97 | raise RuntimeError("Failed, cannot find any ckpts") 98 | print('Reloading from', ckpt_path) 99 | ckpt = torch.load(ckpt_path) 100 | 101 | state_dict = ckpt['network_state_dict'] 102 | nerf.module.init_from_mpi(state_dict) 103 | nerf.to(device) 104 | 105 | # ########################## 106 | # evaluating ours 107 | # ########################## 108 | ours_rgb = [] 109 | print('Begin') 110 | moviebase = os.path.join(expdir, expname, f'eval_') 111 | with torch.no_grad(): 112 | nerf.eval() 113 | 114 | for viewi in range(V): 115 | torch.cuda.empty_cache() 116 | r_pose = extrins[viewi: viewi + 1] 117 | r_intrin = intrins[viewi: viewi + 1] 118 | ts = torch.arange(nerf.module.frm_num).long() 119 | rgb = [nerf(H, W, r_pose, r_intrin, ts[ti: ti + 2])[0] for ti in range(0, len(ts), 2)] 120 | rgb = torch.concat(rgb) 121 | rgb = rgb.permute(0, 2, 3, 1).cpu().numpy() 122 | rgb = to8b(rgb) 123 | ours_rgb.append(rgb) 124 | 125 | # ######################## 126 | # Computing metrics. 
gt, pred are videos F x H x W x 3, in (0, 255), rgb 127 | # ######################## 128 | crop = 40 129 | videos = [vid[:, crop:-crop, crop:-crop] for vid in videos] 130 | ours_rgb = [vid[:, crop:-crop, crop:-crop] for vid in ours_rgb] 131 | loopmasks = [m_[crop:-crop, crop:-crop] for m_ in loopmasks] 132 | # torch.cuda.empty_cache() 133 | # fids = [] 134 | # print("computing svfid error") 135 | # for viewi in trange(V): 136 | # gt = videos[viewi] 137 | # pred = ours_rgb[viewi] 138 | # gt = [cv2.resize(gt_[12:12 + 336, 152: 152 + 336], (112, 112)) for gt_ in gt] 139 | # pred = [cv2.resize(p_[12:12 + 336, 152: 152 + 336], (112, 112)) for p_ in pred] 140 | # gt = torch.tensor(np.array(gt)).cuda().float() / 255 141 | # pred = torch.tensor(np.array(pred)).cuda().float() / 255 142 | # try: 143 | # fid = svfid(gt, pred) 144 | # except Exception as e: 145 | # print(e) 146 | # fid = -1 147 | # 148 | # fids.append(fid) 149 | if COMPUTE_STATIC: 150 | torch.cuda.empty_cache() 151 | print("computing static error") 152 | static_psnr = [] 153 | static_ssim = [] 154 | from evaluations.metrics import compute_img_metric 155 | for viewi in trange(V): 156 | gt = videos[viewi] 157 | pred = ours_rgb[viewi] 158 | frm_min = min(len(gt), len(pred)) 159 | gt, pred = gt[:frm_min] / 255, pred[:frm_min] / 255 160 | lmask = loopmasks[viewi] 161 | psnr = compute_img_metric(torch.tensor(gt), torch.tensor(pred), "psnr", torch.tensor(lmask[None])) 162 | ssim = compute_img_metric(torch.tensor(gt), torch.tensor(pred), "ssim", torch.tensor(lmask[None])) 163 | static_psnr.append(psnr) 164 | static_ssim.append(ssim) 165 | else: 166 | static_psnr = [0] * V 167 | static_ssim = [1] * V 168 | 169 | if COMPUTE_DYN: 170 | torch.cuda.empty_cache() 171 | dyns = [] 172 | print("computing dynamic error") 173 | for viewi in trange(V): 174 | gt = videos[viewi] 175 | pred = ours_rgb[viewi] 176 | stdgt = np.std(gt, axis=0) 177 | stdpred = np.std(pred, axis=0) 178 | err = ((stdgt - stdpred) ** 2).mean() 179 | dyns.append(err) 180 | else: 181 | dyns = [0] * V 182 | 183 | if COMPUTE_LPIPS: 184 | torch.cuda.empty_cache() 185 | lpips = [] 186 | lpips_sw = [] 187 | print("computing lpips error") 188 | for viewi in trange(V): 189 | gt = videos[viewi] 190 | pred = ours_rgb[viewi] 191 | gt = torch.tensor(np.array(gt)).cuda().float() 192 | pred = torch.tensor(np.array(pred)).cuda().float() 193 | lpip = compute_lpips(pred, gt) 194 | lpipsw = compute_lpips_slidewindow(pred, gt) 195 | lpips.append(lpip) 196 | lpips_sw.append(lpipsw) 197 | else: 198 | lpips = [0] * V 199 | lpips_sw = [0] * V 200 | 201 | patch_sizes = [5, 11, 17] 202 | stride_sizes = [2, 4, 6] 203 | patcht_sizes = [7, 5, 3] 204 | stridet_sizes = [1, 1, 1] 205 | if COMPUTE_LOOPQ: 206 | torch.cuda.empty_cache() 207 | loop_qualitys = [] 208 | print("computing Loop Quality") 209 | for viewi in trange(V): 210 | gt = videos[viewi] 211 | pred = ours_rgb[viewi] 212 | gt = torch.tensor(np.array(gt)).cuda().float().permute(3, 0, 1, 2)[None] 213 | pred = torch.tensor(np.array(pred)).cuda().float().permute(3, 0, 1, 2)[None] 214 | 215 | loop_quality = [] 216 | for i, (psz, ssz, pszt, sszt) in enumerate(zip(patch_sizes, stride_sizes, patcht_sizes, stridet_sizes)): 217 | pred_seam = torch.cat([ 218 | pred[:, :, -pszt + 1:], pred[:, :, :pszt - 1] 219 | ], dim=2) 220 | loop_quality.append(compute_nnerr(pred_seam, gt, psz, ssz, pszt, sszt)) 221 | 222 | loop_qualitys.append(loop_quality) 223 | else: 224 | loop_qualitys = [[0] * len(patch_sizes)] * V 225 | 226 | if COMPUTE_NNMSE: 227 | 
torch.cuda.empty_cache() 228 | nnmses_complete = [] 229 | nnmses_coherent = [] 230 | print("computing NN error") 231 | for viewi in trange(V): 232 | gt = videos[viewi] 233 | pred = ours_rgb[viewi] 234 | gt = torch.tensor(np.array(gt)).cuda().float().permute(3, 0, 1, 2)[None] 235 | pred = torch.tensor(np.array(pred)).cuda().float().permute(3, 0, 1, 2)[None] 236 | 237 | complete, coherent = [], [] 238 | for i, (psz, ssz, pszt, sszt) in enumerate(zip(patch_sizes, stride_sizes, patcht_sizes, stridet_sizes)): 239 | complete.append(compute_nnerr(gt, pred, psz, ssz, pszt, sszt)) 240 | coherent.append(compute_nnerr(pred, gt, psz, ssz, pszt, sszt)) 241 | 242 | nnmses_complete.append(complete) # forward 243 | nnmses_coherent.append(coherent) # backward 244 | else: 245 | nnmses_complete = [[0] * len(patch_sizes)] * V 246 | nnmses_coherent = [[0] * len(patch_sizes)] * V 247 | 248 | mean = lambda x: sum(x) / len(x) 249 | names = ["name", "nnf", "nnb", "dyn", "lpips", "lpips_sw", "loop", "psnr", "ssim"] + \ 250 | [f"nnf_p{p}s{s}pt{pt}st{st}" for p, s, pt, st in zip(patch_sizes, stride_sizes, patcht_sizes, stridet_sizes)] + \ 251 | [f"nnb_p{p}s{s}pt{pt}st{st}" for p, s, pt, st in zip(patch_sizes, stride_sizes, patcht_sizes, stridet_sizes)] + \ 252 | [f"loop_p{p}s{s}pt{pt}st{st}" for p, s, pt, st in zip(patch_sizes, stride_sizes, patcht_sizes, stridet_sizes)] 253 | with open(moviebase + "metrics.txt", 'w') as f: 254 | f.write(", ".join(names) + "\n") 255 | dataname = os.path.basename(datadir) 256 | 257 | forwards = np.zeros(len(patch_sizes) + 1) 258 | backwards = np.zeros(len(patch_sizes) + 1) 259 | loops = np.zeros(len(patch_sizes) + 1) 260 | for viewi in range(V): 261 | f.write(f"{dataname}_view{viewi}, ") 262 | f.write(", ".join(map(str, 263 | [mean(nnmses_complete[viewi]), mean(nnmses_coherent[viewi]), 264 | dyns[viewi], lpips[viewi], lpips_sw[viewi], mean(loop_qualitys[viewi]), 265 | static_psnr[viewi], static_ssim[viewi]]))) 266 | f.write(", ") 267 | f.write(", ".join(map(str, nnmses_complete[viewi]))) 268 | f.write(", ") 269 | f.write(", ".join(map(str, nnmses_coherent[viewi]))) 270 | f.write(", ") 271 | f.write(", ".join(map(str, loop_qualitys[viewi]))) 272 | f.write("\n") 273 | 274 | forwards[:len(patch_sizes)] += nnmses_complete[viewi] 275 | forwards[-1] += mean(nnmses_complete[viewi]) 276 | backwards[:len(patch_sizes)] += nnmses_coherent[viewi] 277 | backwards[-1] += mean(nnmses_coherent[viewi]) 278 | loops[:len(patch_sizes)] += loop_qualitys[viewi] 279 | loops[-1] += mean(loop_qualitys[viewi]) 280 | 281 | forwards = forwards / V 282 | backwards = backwards / V 283 | loops = loops / V 284 | f.write(f"{dataname}, ") 285 | f.write(", ".join(map(str, 286 | [forwards[-1], backwards[-1], 287 | mean(dyns), mean(lpips), mean(lpips_sw), loops[-1], 288 | mean(static_psnr), mean(static_ssim)]))) 289 | f.write(", ") 290 | f.write(", ".join(map(str, forwards[:-1].tolist()))) 291 | f.write(", ") 292 | f.write(", ".join(map(str, backwards[:-1].tolist()))) 293 | f.write(", ") 294 | f.write(", ".join(map(str, loops[:-1].tolist()))) 295 | f.write("\n") 296 | 297 | 298 | if __name__ == '__main__': 299 | parser = config_parser() 300 | args = parser.parse_args() 301 | np.random.seed(args.seed) 302 | torch.manual_seed(args.seed) 303 | torch.cuda.manual_seed_all(args.seed) 304 | 305 | evaluate(args) 306 | 307 | -------------------------------------------------------------------------------- /scripts/script_export_mesh.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import os 3 | import numpy as np 4 | import imageio 5 | import json 6 | from dataloader import load_llff_data 7 | from config_parser import config_parser 8 | from utils import save_obj_with_vcolor, save_obj_multimaterial, normalize_uv, cull_unused 9 | 10 | PATCH_SIZE = 16 11 | 12 | 13 | # merge the neighbor pixels to prevent tiling artifact 14 | # failed experiment, turns out this will leads to artifact 15 | # so this function is currently unused 16 | def merge_neighbor_pixels(v, f, uv, uvf, atlas): # len(f) == len(uvf) 17 | # faces of a quad: 0 - 1 two face is arraged as (0, 1, 3), (3, 2, 0) 18 | # | \ | 19 | # 2 - 3 20 | h, w, _ = atlas.shape 21 | f = f.reshape(-1, 6) 22 | uvf = uvf.reshape(-1, 6) 23 | 24 | # find neighbor edge (horizon 25 | edge_h = f[:, [0, 4, 1, 2]].reshape(-1, 2) 26 | edge_h_uv = uvf[:, [0, 4, 1, 2]].reshape(-1, 2) 27 | edge_h_flat = edge_h[:, 0] + edge_h[:, 1] * len(v) 28 | sortidx = np.argsort(edge_h_flat) 29 | edge_h_flat = edge_h_flat[sortidx] 30 | edge_h_uv = edge_h_uv[sortidx] 31 | _, idx, counts = np.unique(edge_h_flat, return_index=True, return_counts=True) 32 | idx0 = idx[np.argwhere(counts == 2)][:, 0] 33 | idx1 = idx0 + 1 34 | 35 | uv_select = uv[edge_h_uv[idx0]] 36 | uv_select1 = uv[edge_h_uv[idx1]] 37 | # | | 38 | x_idx0, x_idx1, y_idx0, y_idx1 = uv_select[:, 0, 0], uv_select1[:, 0, 0], uv_select[:, 0, 1], uv_select1[:, 0, 1] 39 | x_idx0 = np.round((x_idx0 + 1) / (2 / (w - 1))).astype(np.int32) 40 | x_idx1 = np.round((x_idx1 + 1) / (2 / (w - 1))).astype(np.int32) 41 | y_idx0 = np.round((y_idx0 + 1) / (2 / (h - 1))).astype(np.int32) 42 | y_idx1 = np.round((y_idx1 + 1) / (2 / (h - 1))).astype(np.int32) 43 | rang = np.arange(PATCH_SIZE) 44 | pix_loc0 = [np.stack([x_idx0[:, None].repeat(len(rang), 1), y_idx0[:, None] + rang[None]], axis=-1)] 45 | pix_loc1 = [np.stack([x_idx1[:, None].repeat(len(rang), 1), y_idx1[:, None] + rang[None]], axis=-1)] 46 | 47 | # find neighbor edge (vertical) 48 | edge_v = f[:, [0, 1, 4, 2]].reshape(-1, 2) 49 | edge_v_uv = uvf[:, [0, 1, 4, 2]].reshape(-1, 2) 50 | edge_v_flat = edge_v[:, 0] + edge_v[:, 1] * len(v) 51 | sortidx = np.argsort(edge_v_flat) 52 | edge_v_flat = edge_v_flat[sortidx] 53 | edge_v_uv = edge_v_uv[sortidx] 54 | _, idx, counts = np.unique(edge_v_flat, return_index=True, return_counts=True) 55 | idx0 = idx[np.argwhere(counts == 2)][:, 0] 56 | idx1 = idx0 + 1 57 | 58 | uv_select = uv[edge_v_uv[idx0]] 59 | uv_select1 = uv[edge_v_uv[idx1]] 60 | # -- 61 | # -- 62 | x_idx0, x_idx1, y_idx0, y_idx1 = uv_select[:, 0, 0], uv_select1[:, 0, 0], uv_select[:, 0, 1], uv_select1[:, 0, 1] 63 | x_idx0 = np.round((x_idx0 + 1) / (2 / (w - 1))).astype(np.int32) 64 | x_idx1 = np.round((x_idx1 + 1) / (2 / (w - 1))).astype(np.int32) 65 | y_idx0 = np.round((y_idx0 + 1) / (2 / (h - 1))).astype(np.int32) 66 | y_idx1 = np.round((y_idx1 + 1) / (2 / (h - 1))).astype(np.int32) 67 | rang = np.arange(PATCH_SIZE) 68 | pix_loc0 += [np.stack([x_idx0[:, None] + rang[None], y_idx0[:, None].repeat(len(rang), 1)], axis=-1)] 69 | pix_loc1 += [np.stack([x_idx1[:, None] + rang[None], y_idx1[:, None].repeat(len(rang), 1)], axis=-1)] 70 | 71 | pix1 = np.concatenate(pix_loc0).reshape(-1, 2) 72 | pix2 = np.concatenate(pix_loc1).reshape(-1, 2) 73 | return pix1, pix2 74 | 75 | 76 | def export_mpv_repr(args): 77 | prefix = args.prefix 78 | expname = args.expname + args.expname_postfix 79 | outpath = os.path.join(prefix, args.mesh_folder, expname) 80 | os.makedirs(outpath, exist_ok=True) 81 | 82 | data_dir = os.path.join(prefix, args.datadir) 83 | _, 
poses, intrins, bds, render_poses, render_intrins = \ 84 | load_llff_data(data_dir, args.factor, False, 85 | bd_factor=(args.near_factor, args.far_factor), 86 | load_img=False) 87 | 88 | # figuring out the camera geometry 89 | normalize = lambda x: x / np.linalg.norm(x) 90 | up = normalize(poses[:, :3, 1].sum(0)).tolist() 91 | up[1] = -up[1] 92 | 93 | # Find a reasonable "focus depth" for this dataset 94 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5. 95 | focal = 1. / (((1. - .75) / close_depth + .75 / inf_depth)) 96 | # Get radii for spiral path 97 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 98 | rads = np.abs(tt).max(0) * 0.8 99 | f, cy = intrins[:, 0, 0].mean(), intrins[:, 1, -1].mean() 100 | 101 | json_dict = { 102 | "fps": 25, 103 | "fov": np.arctan(cy / f) * 2 / np.pi * 180, 104 | "frame_count": args.mpv_frm_num, 105 | "near": float(bds.min()), 106 | "far": float(bds.max()), 107 | 108 | "up": up, 109 | "lookat": [0, 0, focal], 110 | "limit": rads.tolist(), 111 | } 112 | jsonobj = json.dumps(json_dict, indent=4) 113 | with open(os.path.join(outpath, "meta.json"), 'w') as f: 114 | f.write(jsonobj) 115 | 116 | # saving others 117 | ckpt_file = os.path.join(prefix, args.expdir, expname, "l5_epoch_0049.tar") 118 | state_dict = torch.load(ckpt_file) 119 | state_dict = state_dict['network_state_dict'] 120 | 121 | atlas_h_static, atlas_w_static = state_dict["self.atlas_full_h"], state_dict["self.atlas_full_w"] 122 | atlas_h_dynamic, atlas_w_dynamic = state_dict["self.atlas_full_dyn_h"], state_dict["self.atlas_full_dyn_w"] 123 | 124 | verts = state_dict['_verts'].cpu().numpy() 125 | 126 | # static mesh 127 | uvs_static = state_dict['uvs'].cpu().numpy() 128 | faces_static = state_dict['faces'].cpu().numpy() 129 | uvfaces_static = state_dict['uvfaces'].cpu().numpy() 130 | atlas_static = torch.sigmoid(state_dict['atlas']) 131 | atlas_static = np.clip(atlas_static[0].permute(1, 2, 0).cpu().numpy() * 255, 0, 255).astype(np.uint8) 132 | 133 | # dynamic mesh 134 | uvs_dynamic = state_dict['uvs_dyn'].cpu().numpy() 135 | faces_dynamic = state_dict['faces_dyn'].cpu().numpy() 136 | uvfaces_dynamic = state_dict['uvfaces_dyn'].cpu().numpy() 137 | atlas_dynamic = torch.sigmoid(state_dict['atlas_dyn']) 138 | atlas_dynamic = np.clip(atlas_dynamic.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) 139 | frame_num = len(atlas_dynamic) 140 | assert frame_num == args.mpv_frm_num, "Error: detect unmatched frame count" 141 | # saving geometry 142 | 143 | # pix1, pix2 = merge_neighbor_pixels(verts, faces_static, uvs_static, uvfaces_static, atlas_static) 144 | # color1, color2 = atlas_static[pix1[:, 1], pix1[:, 0]], atlas_static[pix2[:, 1], pix2[:, 0]] 145 | # color = np.minimum(color1, color2) 146 | # atlas_static[pix1[:, 1], pix1[:, 0]] = color 147 | # atlas_static[pix2[:, 1], pix2[:, 0]] = color 148 | # # will chagne atlas_dynamic 149 | # pix1, pix2 = merge_neighbor_pixels(verts, faces_dynamic, uvs_dynamic, uvfaces_dynamic, atlas_dynamic[0]) 150 | # color1, color2 = atlas_dynamic[:, pix1[:, 1], pix1[:, 0]], atlas_dynamic[:, pix2[:, 1], pix2[:, 0]] 151 | # color = np.minimum(color1, color2) 152 | # atlas_dynamic[:, pix1[:, 1], pix1[:, 0]] = color 153 | # atlas_dynamic[:, pix2[:, 1], pix2[:, 0]] = color 154 | 155 | uvs_static = normalize_uv(uvs_static, atlas_h_static, atlas_w_static) 156 | uvs_dynamic = normalize_uv(uvs_dynamic, atlas_h_dynamic, atlas_w_dynamic) 157 | 158 | # save static 159 | staticv, staticf = cull_unused(verts, faces_static) 160 | staticuv, staticuvf = 
cull_unused(uvs_static, uvfaces_static) 161 | staticcolor = np.zeros_like(staticv) 162 | staticcolor[:, 0] = 1 163 | staticvc = np.concatenate([staticv, staticcolor], -1) 164 | 165 | dynamicv, dynamicf = cull_unused(verts, faces_dynamic) 166 | dynamicuv, dynamicuvf = cull_unused(uvs_dynamic, uvfaces_dynamic) 167 | dynamiccolor = np.zeros_like(dynamicv) 168 | dynamiccolor[:, 1] = 1 169 | dynamicvc = np.concatenate([dynamicv, dynamiccolor], -1) 170 | 171 | # concatenate the two meshes 172 | newv = np.concatenate([staticvc, dynamicvc]) 173 | newuv = np.concatenate([staticuv, dynamicuv]) 174 | newf = np.concatenate([staticf, dynamicf + len(staticvc)]) 175 | newuvf = np.concatenate([staticuvf, dynamicuvf + len(staticuv)]) 176 | 177 | # order the faces from far to near (back-to-front) 178 | depth = newv[newf[:, 0]][:, 2] 179 | ordr = np.argsort(depth)[::-1] 180 | newf = newf[ordr] 181 | newuvf = newuvf[ordr] 182 | 183 | save_obj_with_vcolor(os.path.join(outpath, "geometry.obj"), 184 | newv, newf, newuv, newuvf) 185 | 186 | imageio.imwrite(os.path.join(outpath, "static.png"), atlas_static) 187 | vidoutpath = os.path.join(outpath, "dynamic") 188 | os.makedirs(vidoutpath, exist_ok=True) 189 | for i in range(frame_num): 190 | imageio.imwrite(os.path.join(vidoutpath, f"{i:04d}.png"), 191 | atlas_dynamic[i]) 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = config_parser() 196 | parser.add_argument("--mesh_folder", type=str, default="meshes", 197 | help='folder name under the prefix dir to store the exported meshes') 198 | args = parser.parse_args() 199 | export_mpv_repr(args) 200 | -------------------------------------------------------------------------------- /scripts/script_owndata_step1_standardization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import imageio 4 | import os 5 | import numpy as np 6 | 7 | parser = argparse.ArgumentParser(description='split a concatenated multi-view video into per-view images and video clips') 8 | parser.add_argument('--input_path', required=True) 9 | parser.add_argument('--output_prefix', default="../data", help="Where to put the results") 10 | parser.add_argument('--factor', default="1,2", help="factors") 11 | args = parser.parse_args() 12 | 13 | # args = argparse.Namespace( 14 | # input_path = "../data/1017ustspring.mov", 15 | # output_prefix = "../data/1017ustspring720p/", 16 | # factor = [1, 2], 17 | # ) 18 | 19 | if isinstance(args.factor, str): 20 | args.factor = list(map(int, args.factor.split(','))) 21 | print(f"Saving to {args.output_prefix}") 22 | 23 | 24 | def saving2prefix(frames, prefix, factors): 25 | avg_img = np.array(frames) 26 | avg_img = np.mean(avg_img, 0) 27 | avg_img = avg_img.astype(np.uint8) 28 | avg_outp = prefix + f"/images/{clip_id:04d}.png" 29 | os.makedirs(os.path.dirname(avg_outp), exist_ok=True) 30 | imageio.imwrite(avg_outp, avg_img) 31 | 32 | for factor in factors: 33 | vid_outp = prefix + f"/videos_{factor}/{clip_id:04d}.mp4" 34 | os.makedirs(os.path.dirname(vid_outp), exist_ok=True) 35 | images = [cv2.resize(im, None, None, 1 / factor, 1 / factor) for im in frames] 36 | imageio.mimwrite(vid_outp, images, fps=25, macro_block_size=1, quality=8) 37 | 38 | 39 | cap = cv2.VideoCapture(args.input_path) 40 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 41 | factors = args.factor 42 | if isinstance(factors, (int, float)): 43 | factors = [factors] 44 | 45 | clip_id = 0 46 | images = [] 47 | imgsums = [] 48 | while True: 49 | ret, img = cap.read() 50 | if not ret: 51 | if len(images) > 0: # saving 52 | saving2prefix(images, args.output_prefix, factors) 53 | print(f'saving videos of frame {len(images)}') 54 | clip_id += 
1 55 | break 56 | 57 | sum_ = img.mean() 58 | imgsums.append(sum_) 59 | if sum_ < 10: # new sequence 60 | if len(images) > 0: # saving 61 | saving2prefix(images, args.output_prefix, factors) 62 | print(f'saving videos of frame {len(images)}') 63 | clip_id += 1 64 | # reinitialize 65 | images = [] 66 | else: 67 | images.append(img[..., ::-1]) 68 | 69 | import matplotlib.pyplot as plt 70 | 71 | plt.plot(imgsums) 72 | plt.show() 73 | -------------------------------------------------------------------------------- /scripts/script_owndata_step2_genllffpose.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Takes in the scenedir and register the camera poses 3 | ''' 4 | from colmaps import gen_poses 5 | import sys 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description='colmap') 9 | parser.add_argument('--scenedir', type=str, required=True) 10 | parser.add_argument('--share_intrin', action='store_true') 11 | args = parser.parse_args() 12 | 13 | scenedir = args.scenedir 14 | share_intrin = args.share_intrin 15 | factors = 1 16 | use_lowres = False 17 | match_type = 'exhaustive_matcher' # exhaustive_matcher or sequential_matcher 18 | 19 | # ======================================================= 20 | 21 | if __name__ == '__main__': 22 | gen_poses(scenedir, match_type, factors, usedown=use_lowres, share_intrin=share_intrin) 23 | -------------------------------------------------------------------------------- /scripts/script_render_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | import math 6 | import torch.nn as nn 7 | import time 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torch.utils.data import DataLoader, Dataset 10 | from MPV import * 11 | 12 | from dataloader import load_mv_videos, poses_avg 13 | from utils import * 14 | import shutil 15 | from datetime import datetime 16 | import cv2 17 | from config_parser import config_parser 18 | from tqdm import tqdm, trange 19 | from copy import deepcopy 20 | 21 | 22 | def evaluate(args): 23 | device = 'cuda:0' 24 | if args.gpu_num <= 0: 25 | device = 'cpu' 26 | print(f"Using CPU for training") 27 | 28 | expname = args.expname + args.expname_postfix 29 | print(f"Rendering: {expname}") 30 | datadir = os.path.join(args.prefix, args.datadir) 31 | expdir = os.path.join(args.prefix, args.expdir) 32 | 33 | # figure out render_frm to be consistent 34 | render_frm = args.f if args.f > 0 else (120 // args.mpv_frm_num + 1) * args.mpv_frm_num 35 | print(f"loading render pose with {render_frm} frames") 36 | videos, FPS, poses, intrins, bds, render_poses, render_intrins = \ 37 | load_mv_videos(basedir=datadir, 38 | factor=args.factor, 39 | bd_factor=(args.near_factor, args.far_factor), 40 | recenter=True, 41 | render_frm=render_frm, 42 | render_scaling=args.render_scaling) 43 | 44 | H, W = videos[0][0].shape[0:2] 45 | V = len(videos) 46 | print('Loaded llff', V, H, W, poses.shape, intrins.shape, render_poses.shape, bds.shape) 47 | 48 | # figure out view to be rendered 49 | view_poses, view_intrins = render_poses.copy(), render_intrins.copy() 50 | render_t = np.arange(len(render_poses)) % args.mpv_frm_num 51 | if args.v == 'test': 52 | args.v = args.test_view_idx.split(',')[0] 53 | 54 | if len(args.v) > 0: 55 | render_t = render_t[:args.mpv_frm_num] 56 | if args.v[0] == 'r': 57 | v = int(args.v[1:]) 58 | view_poses[:] = view_poses[v:v+1] 59 | view_intrins[:] = 
render_intrins[v:v+1] 60 | print(f"Rendering view {v} in render_pose") 61 | else: 62 | v = int(args.v) 63 | view_poses[:] = poses[v:v+1] 64 | view_intrins[:] = intrins[v:v+1] 65 | print(f"Rendering view {v}") 66 | 67 | # figure out time to be rendered 68 | if len(args.t) > 0: 69 | if ',' in args.t and ':' not in args.t: 70 | time_range = list(map(int, args.t.split(','))) 71 | render_t = render_t[time_range] 72 | elif ':' in args.t: 73 | slices = args.t.split(',') 74 | render_t = [] 75 | for slic in slices: 76 | start, end = list(map(int, slic.split(':'))) 77 | step = 1 if start <= end else -1 78 | render_t.append(np.arange(start, end, step)) 79 | render_t = np.concatenate(render_t) 80 | else: 81 | time_range = [int(args.t)] 82 | render_t = render_t[time_range] 83 | 84 | view_poses = view_poses[:len(render_t)] 85 | view_intrins = view_intrins[:len(render_t)] 86 | print(f"Rendering time {render_t}") 87 | 88 | ref_pose = poses_avg(poses)[:, :4] 89 | ref_extrin = pose2extrin_np(ref_pose) 90 | ref_intrin = intrins[0] 91 | ref_near, ref_far = bds.min(), bds.max() 92 | 93 | # Create nerf model 94 | if args.model_type == "MPMesh": 95 | nerf = MPMeshVid(args, H, W, ref_extrin, ref_intrin, ref_near, ref_far) 96 | else: 97 | raise RuntimeError(f"Unrecognized model type {args.model_type}") 98 | 99 | nerf = nn.DataParallel(nerf, list(range(args.gpu_num))) 100 | nerf.to(device) 101 | 102 | view_extrins = pose2extrin_np(view_poses) 103 | view_extrins = torch.tensor(view_extrins).float() 104 | view_intrins = torch.tensor(view_intrins).float() 105 | 106 | ########################## 107 | # load from checkpoint 108 | ckpts = [os.path.join(expdir, expname, f) 109 | for f in sorted(os.listdir(os.path.join(expdir, expname))) if 'tar' in f] 110 | if len(ckpts) > 0: 111 | ckpt_path = ckpts[-1] 112 | print(f"Using ckpt {ckpt_path}") 113 | else: 114 | raise RuntimeError("Failed, cannot find any ckpts") 115 | print('Reloading from', ckpt_path) 116 | ckpt = torch.load(ckpt_path) 117 | 118 | state_dict = ckpt['network_state_dict'] 119 | nerf.module.init_from_mpi(state_dict) 120 | nerf.to(device) 121 | 122 | # ########################## 123 | # start rendering 124 | # ########################## 125 | 126 | print('Begin') 127 | moviebase = os.path.join(expdir, expname, f'renderonly') 128 | os.makedirs(moviebase, exist_ok=True) 129 | with torch.no_grad(): 130 | nerf.eval() 131 | 132 | rgbs = [] 133 | for viewi in trange(len(view_poses)): 134 | r_pose = view_extrins[viewi: viewi + 1] 135 | r_intrin = view_intrins[viewi: viewi + 1] 136 | t = render_t[viewi: viewi + 1] 137 | rgb, extra = nerf(H, W, r_pose, r_intrin, t) 138 | rgb = rgb.permute(0, 2, 3, 1).cpu().numpy()[0] 139 | rgbs.append(to8b(rgb)) 140 | 141 | if len(rgbs) < 3: 142 | print("too less frames, force to write images") 143 | args.type += 'seq' 144 | 145 | if 'seq' in args.type: 146 | for i, rgb in enumerate(rgbs): 147 | imageio.imwrite(moviebase + f'/view{args.v}t{args.t}_{i:04d}.png', rgb) 148 | else: 149 | imageio.mimwrite(moviebase + f'/view{args.v}t{args.t}.mp4', rgbs, fps=25, quality=8, macro_block_size=1) 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = config_parser() 154 | parser.add_argument("--v", type=str, default='', 155 | help='render view control, empty to be render_pose, r# to be #-th render pose, ' 156 | '# to be #-th training pose') 157 | parser.add_argument("--t", type=str, default='', 158 | help='render time control, empty to be arange(len(views)), ' 159 | '#,#,# to be #-th frame, #:# to be #(include) to #(exclude) frame, :, can 
be mixed') 160 | parser.add_argument("--f", type=int, default=-1, 161 | help='overwrite the frame number when loading the render pose') 162 | parser.add_argument("--type", type=str, default='vid', 163 | help='choose among seq, vid') 164 | parser.add_argument("--render_scaling", type=float, default=1, 165 | help='radius of the render spire') 166 | 167 | args = parser.parse_args() 168 | np.random.seed(args.seed) 169 | torch.manual_seed(args.seed) 170 | torch.cuda.manual_seed_all(args.seed) 171 | 172 | evaluate(args) 173 | 174 | -------------------------------------------------------------------------------- /teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limacv/VideoLoop3D/615c665a7e3ffaf079c4fb5c9cc596a3bcf4e136/teaser.jpg -------------------------------------------------------------------------------- /train_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | import math 6 | import torch.nn as nn 7 | import time 8 | from torch.utils.tensorboard import SummaryWriter 9 | from MPI import * 10 | from torch.utils.data import Dataset, DataLoader 11 | from dataloader import load_mv_videos, poses_avg, load_llff_data 12 | from utils import * 13 | import shutil 14 | from datetime import datetime 15 | import cv2 16 | from tqdm import tqdm, trange 17 | from config_parser import config_parser 18 | 19 | 20 | class MVPatchDataset(Dataset): 21 | def __init__(self, resize_hw, videos, patch_size, patch_stride, poses, intrins, mode='average'): 22 | super().__init__() 23 | h_raw, w_raw, _ = videos[0][0].shape[-3:] 24 | self.h, self.w = resize_hw 25 | self.v = len(videos) 26 | self.poses = poses.clone().cpu() 27 | self.intrins = intrins.clone().cpu() 28 | self.intrins[:, :2] *= torch.tensor([self.w / w_raw, self.h / h_raw]).reshape(1, 2, 1).type_as(intrins) 29 | self.patch_h_size, self.patch_w_size = patch_size 30 | self.mode = mode 31 | if self.h * self.w < self.patch_h_size * self.patch_w_size: 32 | patch_wh_start = torch.tensor([[0, 0]]).long().reshape(-1, 2) 33 | pad_info = [0, 0, 0, 0] 34 | self.patch_h_size, self.patch_w_size = self.h, self.w 35 | else: 36 | patch_wh_start, pad_info = generate_patchinfo(self.h, self.w, patch_size, patch_stride) 37 | 38 | patch_wh_start = patch_wh_start[None, ...].expand(self.v, -1, 2) 39 | view_index = np.arange(self.v)[:, None, None].repeat(patch_wh_start.shape[1], axis=1) 40 | self.patch_wh_start = patch_wh_start.reshape(-1, 2).cpu() 41 | self.view_index = view_index.reshape(-1).tolist() 42 | 43 | self.images = [] 44 | self.dynmask = [] 45 | # TODO: for debug only, delete this 46 | # self.images = [torch.rand(3, self.h, self.w)] * self.v 47 | # self.dynmask = [torch.rand(self.h, self.w)] * self.v 48 | # return 49 | for video in videos: 50 | vid = np.array([cv2.resize(img, (self.w, self.h)) for img in video]) / 255 51 | # mid 52 | if self.mode == 'median': 53 | img = np.median(vid, axis=0) 54 | elif self.mode == 'average': 55 | # aveage 56 | img = vid.mean(axis=0) 57 | elif self.mode == 'first': 58 | img = vid[0] 59 | elif self.mode.startswith('dynamic'): 60 | # emphsize the dynamics 61 | weight = np.linalg.norm(vid - vid.mean(axis=0, keepdims=True), axis=-1, keepdims=True) 62 | k = self.mode.lstrip('dynamic') 63 | k = 1 if len(k) == 0 else float(k) 64 | weight = k * weight + (1 - k) 65 | weight = np.clip(weight, 1e-10, 999999) 66 | img = (vid * weight).sum(axis=0) / 
weight.sum(axis=0) 67 | elif self.mode.startswith('blur'): 68 | b = self.mode.lstrip('blur') 69 | b = 11 if len(b) == 0 else int(b) 70 | vid_blur = np.array([cv2.GaussianBlur(v_, (b, b), 0) for v_ in vid]) 71 | vid_blur_avg = vid_blur.mean(axis=0, keepdims=True) 72 | weight = np.linalg.norm(vid_blur - vid_blur_avg, axis=-1, keepdims=True) 73 | weight = np.clip(weight * 3, 0.001, 3) 74 | img = (vid_blur * weight).sum(axis=0) / weight.sum(axis=0) 75 | else: 76 | raise RuntimeError(f"Unrecognized vid2img_mode={self.mode}") 77 | 78 | img = torch.tensor(img).permute(2, 0, 1) 79 | loopmask = compute_loopable_mask(vid) 80 | loopmask = torch.tensor(loopmask).type_as(img) 81 | self.images.append(img) 82 | self.dynmask.append(loopmask) 83 | print(f"Dataset: generate {len(self)} patches for training, pad {pad_info} to videos") 84 | 85 | def __len__(self): 86 | return len(self.patch_wh_start) 87 | 88 | def __getitem__(self, item): 89 | w_start, h_start = self.patch_wh_start[item] 90 | view_idx = self.view_index[item] 91 | pose = self.poses[view_idx] 92 | intrin = get_new_intrin(self.intrins[view_idx], h_start, w_start).float() 93 | crops = self.images[view_idx][..., h_start: h_start + self.patch_h_size, w_start: w_start + self.patch_w_size] 94 | crops_ma = self.dynmask[view_idx][h_start: h_start + self.patch_h_size, w_start: w_start + self.patch_w_size] 95 | return w_start, h_start, pose, intrin, crops.cuda(), crops_ma.cuda() 96 | 97 | 98 | def train(): 99 | parser = config_parser() 100 | args = parser.parse_args() 101 | np.random.seed(args.seed) 102 | torch.manual_seed(args.seed) 103 | torch.cuda.manual_seed_all(args.seed) 104 | 105 | expname = args.expname + args.expname_postfix 106 | # set up multi-processing 107 | if args.gpu_num == -1: 108 | args.gpu_num = torch.cuda.device_count() 109 | print(f"Using {args.gpu_num} GPU(s)") 110 | 111 | print(f"Training: {expname}") 112 | datadir = os.path.join(args.prefix, args.datadir) 113 | expdir = os.path.join(args.prefix, args.expdir) 114 | videos, _, poses, intrins, bds, render_poses, render_intrins = \ 115 | load_mv_videos(basedir=datadir, 116 | factor=args.factor, 117 | bd_factor=(args.near_factor, args.far_factor), 118 | recenter=True) 119 | 120 | H, W = videos[0][0].shape[0:2] 121 | V = len(videos) 122 | print('Loaded llff', V, H, W, poses.shape, intrins.shape, render_poses.shape, bds.shape) 123 | 124 | ref_pose = poses_avg(poses)[:, :4] 125 | ref_extrin = pose2extrin_np(ref_pose) 126 | ref_intrin = intrins.mean(0) 127 | ref_near, ref_far = bds.min(), bds.max() 128 | 129 | # Summary writers 130 | writer = SummaryWriter(os.path.join(expdir, expname)) 131 | 132 | # Create log dir and copy the config file 133 | file_path = os.path.join(expdir, expname, f"source_{datetime.now().timestamp():.0f}") 134 | os.makedirs(file_path, exist_ok=True) 135 | f = os.path.join(file_path, 'args.txt') 136 | with open(f, 'w') as file: 137 | for arg in sorted(vars(args)): 138 | attr = getattr(args, arg) 139 | file.write('{} = {}\n'.format(arg, attr)) 140 | if args.config is not None: 141 | f = os.path.join(file_path, 'config.txt') 142 | with open(f, 'w') as file: 143 | file.write(open(args.config, 'r').read()) 144 | if args.config1 is not None and len(args.config1) > 0: 145 | f = os.path.join(file_path, 'config1.txt') 146 | with open(f, 'w') as file: 147 | file.write(open(args.config1, 'r').read()) 148 | files_copy = [fs for fs in os.listdir(".") if ".py" in fs] 149 | for fc in files_copy: 150 | shutil.copyfile(f"./{fc}", os.path.join(file_path, fc)) 151 | 152 | # Create 
nerf model 153 | if args.model_type == "MPMesh": 154 | nerf = MPMesh(args, H, W, ref_extrin, ref_intrin, ref_near, ref_far) 155 | else: 156 | raise RuntimeError(f"Unrecognized model type {args.model_type}") 157 | 158 | nerf = nn.DataParallel(nerf, list(range(args.gpu_num))) 159 | optimizer = nerf.module.get_optimizer() 160 | 161 | render_extrins = pose2extrin_np(render_poses) 162 | render_extrins = torch.tensor(render_extrins).float() 163 | render_intrins = torch.tensor(render_intrins).float() 164 | 165 | ###################### 166 | # if optimize poses 167 | poses = torch.tensor(poses) 168 | intrins = torch.tensor(intrins) 169 | 170 | ########################## 171 | # Load checkpoints 172 | ckpts = [os.path.join(expdir, expname, f) 173 | for f in sorted(os.listdir(os.path.join(expdir, expname))) if 'tar' in f] 174 | print('Found ckpts', ckpts) 175 | 176 | start = 0 177 | if len(args.init_from) > 0: 178 | ckpt_path = os.path.join(args.prefix, args.init_from) 179 | assert os.path.exists(ckpt_path), f"Trying to load from {ckpt_path} but it doesn't exist" 180 | print('Reloading from', ckpt_path) 181 | ckpt = torch.load(ckpt_path) 182 | 183 | start = ckpt['epoch_i'] 184 | state_dict = ckpt['network_state_dict'] 185 | nerf.module.init_from_mpi(state_dict) 186 | nerf.cuda() 187 | 188 | # begin of run one iteration (one patch) 189 | def run_iter(stepi, optimizer_, datainfo_): 190 | h_starts, w_starts, b_pose, b_intrin, b_rgbs, b_loopmask = datainfo_ 191 | b_extrin = pose2extrin_torch(b_pose) 192 | patch_h, patch_w = b_rgbs.shape[-2:] 193 | 194 | if args.add_intrin_noise: 195 | dxy = torch.rand(2).type_as(b_intrin) - 0.5 # half pixel 196 | b_intrin = b_intrin.clone() 197 | b_intrin[:, :2, 2] += dxy 198 | 199 | nerf.train() 200 | rgb, extra = nerf(patch_h, patch_w, b_extrin, b_intrin) 201 | if args.learn_loop_mask: 202 | loop_mask = rgb[:, -1] 203 | # simple MSE 204 | # loop_loss = img2mse(loop_mask, b_loopmask) 205 | 206 | # entropy loss 207 | loop_mask = torch.clamp(loop_mask, 0.001, 1 - 0.001) 208 | entropy = b_loopmask * torch.log(loop_mask) + (1 - b_loopmask) * torch.log(1 - loop_mask) 209 | loop_loss = - entropy.mean() 210 | 211 | rgb = rgb[:, :3] 212 | else: 213 | loop_loss = 0 214 | 215 | # RGB loss 216 | if args.scale_invariant: 217 | scale = torch.exp(torch.log((b_rgbs + 0.01) / (rgb.detach() + 0.01)).mean()) 218 | scale = (scale + 3) / 4 # prevent scaling ambiguouity 219 | rgb = rgb * scale 220 | img_loss = img2mse(rgb, b_rgbs) 221 | psnr = mse2psnr(img_loss) 222 | 223 | # define extra losses here 224 | args_var = vars(args) 225 | extra_losses = {} 226 | for k, v in extra.items(): 227 | if args_var[f"{k}_loss_weight"] > 0: 228 | extra_losses[k] = extra[k].mean() * args_var[f"{k}_loss_weight"] 229 | 230 | loss = img_loss + loop_loss 231 | for v in extra_losses.values(): 232 | loss = loss + v 233 | 234 | optimizer_.zero_grad() 235 | loss.backward() 236 | if hasattr(nerf.module, "post_backward"): 237 | nerf.module.post_backward() 238 | optimizer_.step() 239 | 240 | if stepi % args.i_img == 0: 241 | writer.add_scalar('aloss/psnr', psnr, stepi) 242 | writer.add_scalar('aloss/mse_loss', loss, stepi) 243 | for k, v in extra.items(): 244 | writer.add_scalar(f'{k}', float(v.mean()), stepi) 245 | for name, newlr in name_lrates: 246 | writer.add_scalar(f'lr/{name}', newlr, stepi) 247 | 248 | if stepi % args.i_print == 0: 249 | epoch_tqdm.set_description(f"[TRAIN] Iter: {stepi} Loss: {loss.item():.4f} PSNR: {psnr.item():.4f}", 250 | "|".join([f"{k}: {v.item():.4f}" for k, v in 
extra_losses.items()])) 251 | 252 | # end of run one iteration 253 | 254 | # ########################## 255 | # start training 256 | # ########################## 257 | print('Begin') 258 | old_density_loss_weight = args.density_loss_weight 259 | 260 | dataset = MVPatchDataset((H, W), videos, 261 | (args.patch_h_size, args.patch_w_size), 262 | (args.patch_h_stride, args.patch_w_stride), 263 | poses, intrins, args.vid2img_mode) 264 | 265 | # visualize the image 266 | for viewi, (img, loopma) in enumerate(zip(dataset.images, dataset.dynmask)): 267 | p = os.path.join(expdir, expname, f"imgvis_{args.vid2img_mode}", f"{viewi:04d}.png") 268 | os.makedirs(os.path.dirname(p), exist_ok=True) 269 | imageio.imwrite(p, to8b(img.permute(1, 2, 0).cpu().numpy())) 270 | pm = os.path.join(expdir, expname, f"loopvis", f"{viewi:04d}.png") 271 | os.makedirs(os.path.dirname(pm), exist_ok=True) 272 | imageio.imwrite(pm, to8b(loopma.cpu().numpy())) 273 | dataloader = DataLoader(dataset, 1, shuffle=True) 274 | 275 | iter_total_step = 0 276 | epoch_tqdm = trange(args.N_iters) 277 | for epoch_i in epoch_tqdm: 278 | if epoch_i < start: 279 | continue 280 | 281 | # doing epoch specific task 282 | if epoch_i == args.sparsify_epoch: 283 | print("Sparsifying mesh models") 284 | nerf.module.sparsify_faces(erode_num=args.sparsify_erode, alpha_thresh=args.sparsify_alpha_thresh) 285 | optimizer = nerf.module.get_optimizer() 286 | 287 | if epoch_i == args.direct2sh_epoch: 288 | print("Converting direct to data_sh") 289 | nerf.module.direct2sh() 290 | optimizer = nerf.module.get_optimizer() 291 | 292 | pct = np.clip(epoch_i / (args.density_loss_epoch + 1), 0, 1) 293 | args.density_loss_weight = pct * pct * old_density_loss_weight 294 | # print(f"densitylossweight = {args.density_loss_weight}") 295 | 296 | for iter_i, datainfo in enumerate(dataloader): 297 | if hasattr(nerf.module, "update_step"): 298 | nerf.module.update_step(iter_total_step) 299 | 300 | ### update learning rate ### 301 | name_lrates = nerf.module.get_lrate(iter_total_step) 302 | for (lrname, new_lrate), param_group in zip(name_lrates, optimizer.param_groups): 303 | param_group['lr'] = new_lrate 304 | 305 | # train for one interation 306 | datainfo = [d.cuda() for d in datainfo] 307 | run_iter(iter_total_step, optimizer, datainfo) 308 | 309 | iter_total_step += 1 310 | 311 | ################################ 312 | if (epoch_i + 1) % args.i_weights == 0: 313 | save_path = os.path.join(expdir, expname, f'epoch_{epoch_i:04d}.tar') 314 | save_dict = { 315 | 'epoch_i': epoch_i, 316 | 'network_state_dict': nerf.module.state_dict() 317 | } 318 | torch.save(save_dict, save_path) 319 | 320 | if (epoch_i + 1) % args.i_video == 0: 321 | moviebase = os.path.join(expdir, expname, f'epoch_{epoch_i:04d}_') 322 | if hasattr(nerf.module, "save_mesh"): 323 | prefix = os.path.join(expdir, expname, f"mesh_epoch_{epoch_i:04d}") 324 | nerf.module.save_mesh(prefix) 325 | 326 | if hasattr(nerf.module, "save_texture"): 327 | prefix = os.path.join(expdir, expname, f"texture_epoch_{epoch_i:04d}") 328 | nerf.module.save_texture(prefix) 329 | 330 | if args.learn_loop_mask and hasattr(nerf.module, "save_loopmask"): 331 | prefix = os.path.join(expdir, expname, f"loopable_epoch_{epoch_i:04d}") 332 | nerf.module.save_loopmask(prefix) 333 | 334 | print('render poses shape', render_extrins.shape, render_intrins.shape) 335 | with torch.no_grad(): 336 | nerf.eval() 337 | 338 | rgbs = [] 339 | loopmasks = [] 340 | for ri in range(len(render_extrins)): 341 | r_pose = render_extrins[ri:ri + 1] 342 | 
r_intrin = render_intrins[ri:ri + 1] 343 | 344 | rgbl, extra = nerf(H, W, r_pose, r_intrin) 345 | if args.learn_loop_mask: 346 | rgb, loopmask = rgbl[:, :3], rgbl[:, -1] 347 | loopmask = loopmask[0].cpu().numpy() 348 | loopmask = np.stack([np.zeros_like(loopmask), loopmask, np.zeros_like(loopmask)], -1) 349 | loopmasks.append(loopmask) 350 | else: 351 | rgb = rgbl 352 | rgb = rgb[0].permute(1, 2, 0).cpu().numpy() 353 | rgbs.append(rgb) 354 | 355 | rgbs = np.array(rgbs) 356 | imageio.mimwrite(moviebase + '_rgb.mp4', to8b(rgbs), fps=25, quality=8) 357 | if len(loopmasks) > 0: 358 | loopmasks = np.array(loopmasks) 359 | imageio.mimwrite(moviebase + '_loopable.mp4', to8b(loopmasks), fps=25, quality=8) 360 | 361 | 362 | if __name__ == '__main__': 363 | train() 364 | -------------------------------------------------------------------------------- /train_3dvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | import math 6 | import torch.nn as nn 7 | import time 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torch.utils.data import DataLoader, Dataset 10 | from MPV import * 11 | 12 | from dataloader import load_mv_videos, poses_avg 13 | from utils import * 14 | import shutil 15 | from datetime import datetime 16 | import cv2 17 | from config_parser import config_parser 18 | from tqdm import tqdm, trange 19 | from copy import deepcopy 20 | 21 | 22 | class MVVidPatchDataset(Dataset): 23 | def __init__(self, resize_hw, videos, patch_size, patch_stride, poses, intrins, 24 | loss_configs=None): 25 | super().__init__() 26 | h_raw, w_raw, _ = videos[0][0].shape[-3:] 27 | self.h, self.w = resize_hw 28 | self.v = len(videos) 29 | self.poses = poses.clone().cpu() 30 | self.intrins = intrins.clone().cpu() 31 | self.intrins[:, :2] *= torch.tensor([self.w / w_raw, self.h / h_raw]).reshape(1, 2, 1).type_as(intrins) 32 | self.patch_h_size, self.patch_w_size = patch_size 33 | if self.h * self.w < self.patch_h_size * self.patch_w_size: 34 | patch_wh_start = torch.tensor([[0, 0]]).long().reshape(-1, 2) 35 | pad_info = [0, 0, 0, 0] 36 | self.patch_h_size, self.patch_w_size = self.h, self.w 37 | else: 38 | patch_wh_start, pad_info = generate_patchinfo(self.h, self.w, patch_size, patch_stride) 39 | 40 | patch_wh_start = patch_wh_start[None, ...].expand(self.v, -1, 2) 41 | view_index = np.arange(self.v)[:, None, None].repeat(patch_wh_start.shape[1], axis=1) 42 | self.patch_wh_start = patch_wh_start.reshape(-1, 2).cpu() 43 | self.view_index = view_index.reshape(-1).tolist() 44 | self.loss_configs = loss_configs 45 | assert len(self.loss_configs) == self.v 46 | 47 | self.videos = [] 48 | for video in videos: 49 | vid = np.array([cv2.resize(img, (self.w, self.h)) for img in video]) 50 | vid = torch.tensor(vid, device='cpu') 51 | vid = (vid / 255).permute(0, 3, 1, 2) 52 | vid = torchf.pad(vid, pad_info) 53 | self.videos.append(vid) 54 | print(f"Dataset: generate {len(self)} patches for training, pad {pad_info} to videos") 55 | 56 | def __len__(self): 57 | return len(self.patch_wh_start) 58 | 59 | def __getitem__(self, item): 60 | w_start, h_start = self.patch_wh_start[item] 61 | view_idx = self.view_index[item] 62 | pose = self.poses[view_idx] 63 | intrin = get_new_intrin(self.intrins[view_idx], h_start, w_start).float() 64 | crops = self.videos[view_idx][..., h_start: h_start + self.patch_h_size, w_start: w_start + self.patch_w_size] 65 | cfg = deepcopy(self.loss_configs[view_idx]) 66 | return 
w_start, h_start, pose, intrin, crops, cfg 67 | 68 | 69 | def train(args): 70 | device = 'cuda:0' 71 | if args.gpu_num <= 0: 72 | device = 'cpu' 73 | print(f"Using CPU for training") 74 | 75 | expname = args.expname + args.expname_postfix 76 | print(f"Training: {expname}") 77 | datadir = os.path.join(args.prefix, args.datadir) 78 | expdir = os.path.join(args.prefix, args.expdir) 79 | videos, FPS, poses, intrins, bds, render_poses, render_intrins = \ 80 | load_mv_videos(basedir=datadir, 81 | factor=args.factor, 82 | bd_factor=(args.near_factor, args.far_factor), 83 | recenter=True) 84 | 85 | H, W = videos[0][0].shape[0:2] 86 | V = len(videos) 87 | 88 | print('Loaded llff', V, H, W, poses.shape, intrins.shape, render_poses.shape, bds.shape) 89 | test_view = args.test_view_idx 90 | test_view = list(map(int, test_view.split(','))) if len(test_view) > 0 else [] 91 | train_view = sorted(list(set(range(V)) - set(test_view))) 92 | # filter out test view 93 | videos = [videos[train_i] for train_i in train_view] 94 | poses = poses[train_view] 95 | intrins = intrins[train_view] 96 | print(f'Training view: {train_view}') 97 | 98 | ref_pose = poses_avg(poses)[:, :4] 99 | ref_extrin = pose2extrin_np(ref_pose) 100 | ref_intrin = intrins[0] 101 | ref_near, ref_far = bds.min(), bds.max() 102 | 103 | # Resolve pyramid related configs, controlled by (pyr_stage, pyr_factor, N_iters) 104 | # or (pyr_minimal_dim, pyr_factor, pyr_num_epoch) 105 | if args.pyr_minimal_dim < 0: 106 | # store the iter_num when starting the stage 107 | pyr_stages = list(map(int, args.pyr_stage.split(','))) if len(args.pyr_stage) > 0 else [] 108 | pyr_stages = np.array([0] + pyr_stages + [args.N_iters]) # one default stage 109 | pyr_num_epoch = pyr_stages[1:] - pyr_stages[:-1] 110 | pyr_factors = [args.pyr_factor ** i for i in list(range(len(pyr_num_epoch)))[::-1]] 111 | pyr_hw = [(int(H * f), int(W * f)) for f in pyr_factors] 112 | else: 113 | num_stage = int(np.log(args.pyr_minimal_dim / min(H, W)) / np.log(args.pyr_factor)) + 1 114 | pyr_factors = [args.pyr_factor ** i for i in list(range(num_stage))[::-1]] 115 | pyr_hw = [(int(H * f), int(W * f)) for f in pyr_factors] 116 | pyr_num_epoch = [args.pyr_num_epoch] * num_stage 117 | print("Pyramid info: ") 118 | for leveli, (f_, hw_, num_step_) in enumerate(zip(pyr_factors, pyr_hw, pyr_num_epoch)): 119 | print(f" level {leveli}: factor {f_} [{hw_[0]} x {hw_[1]}] run for {num_step_} iterations") 120 | # end of pyramid information 121 | 122 | # Summary writers 123 | writer = SummaryWriter(os.path.join(expdir, expname)) 124 | 125 | # Create log dir and copy the config file 126 | file_path = os.path.join(expdir, expname, f"source_{datetime.now().timestamp():.0f}") 127 | os.makedirs(file_path, exist_ok=True) 128 | f = os.path.join(file_path, 'args.txt') 129 | with open(f, 'w') as file: 130 | for arg in sorted(vars(args)): 131 | attr = getattr(args, arg) 132 | file.write('{} = {}\n'.format(arg, attr)) 133 | if len(args.config) > 0: 134 | f = os.path.join(file_path, 'config.txt') 135 | with open(f, 'w') as file: 136 | file.write(open(args.config, 'r').read()) 137 | if len(args.config1) > 0: 138 | f = os.path.join(file_path, 'config1.txt') 139 | with open(f, 'w') as file: 140 | file.write(open(args.config1, 'r').read()) 141 | files_copy = [fs for fs in os.listdir(".") if ".py" in fs] 142 | for fc in files_copy: 143 | shutil.copyfile(f"./{fc}", os.path.join(file_path, fc)) 144 | 145 | # Create nerf model 146 | if args.model_type == "MPMesh": 147 | nerf = MPMeshVid(args, H, W, ref_extrin, 
ref_intrin, ref_near, ref_far) 148 | else: 149 | raise RuntimeError(f"Unrecognized model type {args.model_type}") 150 | 151 | nerf = DataParallelCPU(nerf) if device == 'cpu' else nn.DataParallel(nerf, list(range(args.gpu_num))) 152 | nerf.to(device) 153 | render_extrins = pose2extrin_np(render_poses) 154 | render_extrins = torch.tensor(render_extrins).float() 155 | render_intrins = torch.tensor(render_intrins).float() 156 | 157 | poses = torch.tensor(poses) 158 | intrins = torch.tensor(intrins) 159 | 160 | # figuring out the loss config 161 | loss_config_other = { 162 | "loss_name": args.loss_name, 163 | "patch_size": args.swd_patch_size, 164 | "patcht_size": args.swd_patcht_size, 165 | "stride": args.swd_stride, 166 | "stridet": args.swd_stridet, 167 | "alpha": args.swd_alpha, 168 | "rou": args.swd_rou, 169 | "scaling": args.swd_scaling, 170 | "dist_fn": args.swd_dist_fn, 171 | "macro_block": args.swd_macro_block, 172 | "factor": args.swd_factor, 173 | } 174 | loss_config_ref = { 175 | "loss_name": args.loss_name_ref, 176 | "loss_gain": args.swd_loss_gain_ref, 177 | "patch_size": args.swd_patch_size_ref, 178 | "patcht_size": args.swd_patcht_size_ref, 179 | "stride": args.swd_stride_ref, 180 | "stridet": args.swd_stridet_ref, 181 | "alpha": args.swd_alpha_ref, 182 | "rou": args.swd_rou_ref, 183 | "scaling": args.swd_scaling_ref, 184 | "dist_fn": args.swd_dist_fn_ref, 185 | "macro_block": args.swd_macro_block, 186 | "factor": args.swd_factor_ref, 187 | } 188 | loss_cfgs = [loss_config_other] * V 189 | ref_idxs = list(map(int, args.loss_ref_idx.split(','))) 190 | for ref_idx in ref_idxs: 191 | loss_cfgs[ref_idx] = loss_config_ref 192 | loss_cfgs = [loss_cfgs[i] for i in train_view] 193 | 194 | epoch_total_step = 0 195 | iter_total_step = 0 196 | 197 | ########################## 198 | # load from checkpoint 199 | ckpts = [os.path.join(expdir, expname, f) 200 | for f in sorted(os.listdir(os.path.join(expdir, expname))) if 'tar' in f] 201 | print('Found ckpts', ckpts) 202 | 203 | if len(args.init_from) > 0: 204 | ckpt_path = os.path.join(args.prefix, args.init_from) 205 | assert os.path.exists(ckpt_path), f"Trying to load from {ckpt_path} but it doesn't exist" 206 | print('Reloading from', ckpt_path) 207 | ckpt = torch.load(ckpt_path) 208 | 209 | state_dict = ckpt['network_state_dict'] 210 | nerf.module.init_from_mpi(state_dict) 211 | nerf.to(device) 212 | 213 | # begin of run one iteration (one patch) 214 | def run_iter(stepi, optimizer_, datainfo_): 215 | datainfo_ = [d.to(device) if torch.is_tensor(d) else d for d in datainfo_] 216 | h_starts, w_starts, b_pose, b_intrin, b_rgbs, loss_cfg = datainfo_ 217 | if args.fp16: 218 | b_rgbs = b_rgbs.half() 219 | b_extrin = pose2extrin_torch(b_pose) 220 | patch_h, patch_w = b_rgbs.shape[-2:] 221 | 222 | if args.add_intrin_noise: 223 | dxy = torch.rand(2).type_as(b_intrin) - 0.5 # half pixel 224 | b_intrin = b_intrin.clone() 225 | b_intrin[:, :2, 2] += dxy 226 | 227 | nerf.train() 228 | rgb, extra = nerf(patch_h, patch_w, b_extrin, b_intrin, res=b_rgbs, losscfg=loss_cfg) 229 | 230 | swd_loss = extra.pop("swd").mean() 231 | # define extra losses here 232 | args_var = vars(args) 233 | extra_losses = {} 234 | for k, v in extra.items(): 235 | if args_var[f"{k}_loss_weight"] > 0: 236 | extra_losses[k] = extra[k].mean() * args_var[f"{k}_loss_weight"] 237 | 238 | loss = swd_loss 239 | for v in extra_losses.values(): 240 | loss = loss + v 241 | 242 | optimizer_.zero_grad() 243 | loss.backward() 244 | optimizer_.step() 245 | 246 | if (stepi + 1) % args.i_img == 
0: 247 | writer.add_scalar('aloss/swd', swd_loss.item(), stepi) 248 | for k, v in extra.items(): 249 | writer.add_scalar(f'{k}', float(v.mean()), stepi) 250 | for name, newlr in name_lrates: 251 | writer.add_scalar(f'lr/{name}', newlr, stepi) 252 | 253 | if (stepi + 1) % args.i_print == 0: 254 | epoch_tqdm.set_description(f"[TRAIN] Iter: {stepi} Loss: {loss.item():.4f} SWD: {swd_loss.item():.4f}", 255 | "|".join([f"{k}: {v.item():.4f}" for k, v in extra_losses.items()])) 256 | 257 | # end of run one iteration 258 | 259 | # ########################## 260 | # start training 261 | # ########################## 262 | print('Begin') 263 | for pyr_i, (train_factor, hw, num_epoch) in enumerate(zip(pyr_factors, pyr_hw, pyr_num_epoch)): 264 | nerf.module.lod(train_factor) 265 | optimizer = nerf.module.get_optimizer(step=0) 266 | torch.cuda.empty_cache() 267 | # generate dataset and optimizer 268 | dataset = MVVidPatchDataset(hw, videos, 269 | (args.patch_h_size, args.patch_w_size), 270 | (args.patch_h_stride, args.patch_w_stride), 271 | poses, intrins, loss_configs=loss_cfgs) 272 | dataloader = DataLoader(dataset, 1, shuffle=True) 273 | epoch_tqdm = trange(num_epoch) 274 | for epoch_i in epoch_tqdm: 275 | for iter_i, datainfo in enumerate(dataloader): 276 | 277 | if hasattr(nerf.module, "update_step"): 278 | nerf.module.update_step(epoch_total_step) 279 | 280 | # update learning rate 281 | name_lrates = nerf.module.get_lrate(epoch_i) 282 | 283 | if args.lrate_adaptive: 284 | name_lrates = [(n_, lr_ / len(dataset)) for n_, lr_ in name_lrates] 285 | 286 | for (lrname, new_lrate), param_group in zip(name_lrates, optimizer.param_groups): 287 | param_group['lr'] = new_lrate 288 | 289 | # train for one interation 290 | run_iter(iter_total_step, optimizer, datainfo) 291 | 292 | iter_total_step += 1 293 | 294 | # saving after epoch 295 | if (epoch_total_step + 1) % args.i_weights == 0: 296 | save_path = os.path.join(expdir, expname, f'l{pyr_i}_epoch_{epoch_i:04d}.tar') 297 | save_dict = { 298 | 'epoch_i': epoch_i, 299 | 'epoch_total_step': epoch_total_step, 300 | 'iter_total_step': iter_total_step, 301 | 'pyr_i': pyr_i, 302 | 'train_factor': train_factor, 303 | 'hw': hw, 304 | 'network_state_dict': nerf.module.state_dict(), 305 | } 306 | torch.save(save_dict, save_path) 307 | 308 | if (epoch_total_step + 1) % args.i_video == 0: 309 | moviebase = os.path.join(expdir, expname, f'l{pyr_i}_{epoch_i:04d}_') 310 | if hasattr(nerf.module, "save_mesh"): 311 | prefix = os.path.join(expdir, expname, f"mesh_l{pyr_i}_{epoch_i:04d}") 312 | nerf.module.save_mesh(prefix) 313 | 314 | if hasattr(nerf.module, "save_texture"): 315 | prefix = os.path.join(expdir, expname, f"texture_l{pyr_i}_{epoch_i:04d}") 316 | nerf.module.save_texture(prefix) 317 | 318 | print('render poses shape', render_extrins.shape, render_intrins.shape) 319 | with torch.no_grad(): 320 | nerf.eval() 321 | 322 | rgbs = [] 323 | for ri in range(len(render_extrins)): 324 | r_pose = render_extrins[ri:ri + 1] 325 | r_intrin = render_intrins[ri:ri + 1] 326 | 327 | rgb, extra = nerf(H, W, r_pose, r_intrin, ts=[ri % args.mpv_frm_num]) 328 | rgb = rgb[0].permute(1, 2, 0).cpu().numpy() 329 | rgbs.append(rgb) 330 | 331 | rgbs = np.array(rgbs) 332 | imageio.mimwrite(moviebase + '_rgb.mp4', to8b(rgbs), fps=FPS, quality=8) 333 | 334 | epoch_total_step += 1 335 | 336 | 337 | if __name__ == '__main__': 338 | parser = config_parser() 339 | args = parser.parse_args() 340 | np.random.seed(args.seed) 341 | torch.manual_seed(args.seed) 342 | 
torch.cuda.manual_seed_all(args.seed) 343 | 344 | train(args) 345 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.utils import save_image 8 | from torchvision.transforms import GaussianBlur 9 | from pytorch3d.structures import Meshes 10 | from pytorch3d.renderer import rasterize_meshes 11 | from pytorch3d.renderer.mesh.rasterizer import Fragments 12 | import cv2 13 | 14 | 15 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 16 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.tensor([10.]).type_as(x)) 17 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 18 | 19 | 20 | class SimpleRasterizer(nn.Module): 21 | def __init__(self, raster_settings=None, adaptive_layernum=True): 22 | """ 23 | max_faces_per_bin = int(max(10000, meshes._F / 5)) 24 | """ 25 | super().__init__() 26 | if raster_settings is None: 27 | raster_settings = RasterizationSettings() 28 | self.raster_settings = raster_settings 29 | self.adaptive_layer_num = adaptive_layernum 30 | 31 | def forward(self, vertices, faces): 32 | """ 33 | Args: 34 | vertices: B, N, 3 35 | faces: B, N, 3 36 | """ 37 | raster_settings = self.raster_settings 38 | 39 | # By default, turn on clip_barycentric_coords if blur_radius > 0. 40 | # When blur_radius > 0, a face can be matched to a pixel that is outside the 41 | # face, resulting in negative barycentric coordinates. 42 | clip_barycentric_coords = raster_settings.clip_barycentric_coords 43 | if clip_barycentric_coords is None: 44 | clip_barycentric_coords = raster_settings.blur_radius > 0.0 45 | 46 | # If not specified, infer perspective_correct and z_clip_value from the camera 47 | perspective_correct = True 48 | z_clip = raster_settings.z_clip_value 49 | # z_clip should be set to >0 value if there are some meshes comming near the camera 50 | 51 | fragment = rasterize_meshes( 52 | Meshes(vertices, faces), 53 | image_size=raster_settings.image_size, 54 | blur_radius=raster_settings.blur_radius, 55 | faces_per_pixel=raster_settings.faces_per_pixel, 56 | bin_size=raster_settings.bin_size, 57 | max_faces_per_bin=raster_settings.max_faces_per_bin, 58 | clip_barycentric_coords=clip_barycentric_coords, 59 | perspective_correct=perspective_correct, 60 | cull_backfaces=raster_settings.cull_backfaces, 61 | z_clip_value=z_clip, 62 | cull_to_frustum=raster_settings.cull_to_frustum, 63 | ) 64 | if self.adaptive_layer_num: 65 | with torch.no_grad(): 66 | pix_to_face = fragment[0] 67 | pix_to_face_max = pix_to_face.reshape(-1, pix_to_face.shape[-1]).max(dim=0)[0] 68 | num_layer = torch.count_nonzero(pix_to_face_max > 0).item() 69 | fragment = [frag[:, :, :, :num_layer] for frag in fragment] 70 | return Fragments(*fragment) # pix_to_face, zbuf, bary_coords, dists 71 | 72 | 73 | def frag2uv(frag: Fragments, uvs, uvfaces): 74 | """ 75 | return MPI mask, uv coordinate 76 | """ 77 | pixel_to_face, depths, bary_coords = frag.pix_to_face, frag.zbuf, frag.bary_coords 78 | # currently the batching is not supported 79 | mask = pixel_to_face.reshape(-1) >= 0 80 | mask_flat = mask.reshape(-1) 81 | faces_ma_dyn = pixel_to_face.reshape(-1)[mask_flat] 82 | uv_indices = uvfaces[faces_ma_dyn] 83 | uvs = uvs[uv_indices] # N, 3, n_feat 84 | bary_coords_ma = bary_coords.reshape(-1, 3)[mask_flat, :] # N, 3 85 | uvs = 
(bary_coords_ma[..., None] * uvs).sum(dim=-2) 86 | return mask, uvs 87 | 88 | 89 | class ParamsWithGradGain(nn.Module): 90 | def __init__(self, param, grad_gain=1.): 91 | super(ParamsWithGradGain, self).__init__() 92 | if grad_gain == 0: 93 | self.register_buffer("param", param) 94 | else: 95 | self.register_parameter("param", nn.Parameter(param, requires_grad=True)) 96 | 97 | def grad_gain_fn(grad): 98 | return grad * grad_gain 99 | 100 | if grad_gain != 1: 101 | self.param.register_hook(grad_gain_fn) 102 | 103 | def forward(self): 104 | return self.param 105 | 106 | 107 | # class InputAdaptor(nn.Module): 108 | # def __init__(self, H, W): 109 | # super().__init__() 110 | # flows = torch 111 | # 112 | # def forward(self, rgb): 113 | 114 | 115 | def generate_patchinfo(H_, W_, patch_size_, patch_stride_): 116 | patch_h_size, patch_w_size = patch_size_ 117 | patch_h_stride, patch_w_stride = patch_stride_ 118 | 119 | # generate patch information 120 | patch_h_start = np.arange(0, H_ - patch_h_size + patch_h_stride, patch_h_stride) 121 | patch_w_start = np.arange(0, W_ - patch_w_size + patch_w_stride, patch_w_stride) 122 | 123 | patch_wh_start = np.meshgrid(patch_h_start, patch_w_start) 124 | patch_wh_start = np.stack(patch_wh_start[::-1], axis=-1).reshape(-1, 2)[None, ...] 125 | 126 | patch_wh_start = patch_wh_start.reshape(-1, 2) 127 | patch_wh_start = torch.tensor(patch_wh_start) 128 | 129 | H_pad = patch_h_start.max() + patch_h_size - H_ 130 | W_pad = patch_w_start.max() + patch_w_size - W_ 131 | assert patch_h_stride > H_pad >= 0 and patch_w_stride > W_pad >= 0, "bug occurs!" 132 | 133 | pad_info = [0, W_pad, 0, H_pad] 134 | return patch_wh_start, pad_info 135 | 136 | 137 | def xyz2uv_stereographic(xyz: torch.Tensor, normalized=False): 138 | """ 139 | xyz: tensor of size (B, 3) 140 | """ 141 | if not normalized: 142 | xyz = xyz / xyz.norm(dim=-1, keepdim=True) 143 | x, y, z = torch.split(xyz, 1, dim=-1) 144 | z = torch.clamp_max(z, 0.99) 145 | denorm = torch.reciprocal(-z + 1) 146 | u, v = x * denorm, y * denorm 147 | return torch.cat([u, v], dim=-1) 148 | 149 | 150 | def uv2xyz_stereographic(uv: torch.Tensor): 151 | u, v = torch.split(uv, 1, dim=-1) 152 | u2v2 = u ** 2 + v ** 2 153 | x = u * 2 / (u2v2 + 1) 154 | y = v * 2 / (u2v2 + 1) 155 | z = (u2v2 - 1) / (u2v2 + 1) 156 | return torch.cat([x, y, z], dim=-1) 157 | 158 | 159 | def get_rays_np(H, W, K, c2w): 160 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 161 | pixelpoints = np.stack([i, j, np.ones_like(i)], -1)[..., np.newaxis] 162 | localpoints = np.linalg.inv(K) @ pixelpoints 163 | 164 | rays_d = (c2w[:3, :3] @ localpoints)[..., 0] 165 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) 166 | return rays_o, rays_d 167 | 168 | 169 | def get_rays_tensor(H, W, K, c2w): 170 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H)) 171 | i = i.t() 172 | j = j.t() 173 | 174 | pixelpoints = torch.stack([i, j, torch.ones_like(i)], -1).unsqueeze(-1) 175 | localpoints = torch.matmul(torch.inverse(K), pixelpoints) 176 | 177 | rays_d = torch.matmul(c2w[:3, :3], localpoints)[..., 0] 178 | rays_o = c2w[:3, -1].expand(rays_d.shape) 179 | return rays_o, rays_d 180 | 181 | 182 | def get_rays_tensor_batches(H, W, K, c2w): 183 | i, j = torch.meshgrid([torch.linspace(0, W - 1, W, device=K.device), 184 | torch.linspace(0, H - 1, H, device=K.device)]) 185 | i = i.t() 186 | j = j.t() 187 | 188 | pixelpoints = torch.stack([i, j, torch.ones_like(i)], -1)[None, ..., None] 189 | 
localpoints = torch.matmul(torch.inverse(K)[:, None, None], pixelpoints) 190 | 191 | rays_d = torch.matmul(c2w[:, None, None, :3, :3], localpoints)[..., 0] 192 | rays_o = c2w[:, :3, -1].expand(rays_d.shape) 193 | return rays_o, rays_d 194 | 195 | 196 | def get_new_intrin(old_intrin, new_h_start, new_w_start): 197 | new_intrin = old_intrin.clone() if isinstance(old_intrin, torch.Tensor) else old_intrin.copy() 198 | new_intrin[..., 0, 2] -= new_w_start 199 | new_intrin[..., 1, 2] -= new_h_start 200 | return new_intrin 201 | 202 | 203 | def pose2extrin_np(pose: np.ndarray): 204 | if pose.shape[-2] == 3: 205 | bottom = pose[..., :1, :].copy() 206 | bottom[..., :] = [0, 0, 0, 1] 207 | pose = np.concatenate([pose, bottom], axis=-2) 208 | return np.linalg.inv(pose) 209 | 210 | 211 | def pose2extrin_torch(pose): 212 | """ 213 | pose to extrin or extrin to pose (equivalent) 214 | """ 215 | if pose.shape[-2] == 3: 216 | bottom = pose[..., :1, :].detach().clone() 217 | bottom[..., :] = torch.tensor([0, 0, 0, 1.]) 218 | pose = torch.cat([pose, bottom], dim=-2) 219 | return torch.inverse(pose) 220 | 221 | 222 | def raw2poses(rot_raw, tran_raw, intrin_raw): 223 | x = rot_raw[..., 0] 224 | x = x / torch.norm(x, dim=-1, keepdim=True) 225 | z = torch.cross(x, rot_raw[..., 1]) 226 | z = z / torch.norm(z, dim=-1, keepdim=True) 227 | y = torch.cross(z, x) 228 | rot = torch.stack([x, y, z], dim=-1) 229 | pose = torch.cat([rot, tran_raw[..., None]], dim=-1) 230 | bottom = torch.tensor([0, 0, 1]).type_as(intrin_raw).reshape(-1, 1, 3).expand(len(intrin_raw), -1, -1) 231 | intrinsic = torch.cat([intrin_raw, bottom], dim=1) 232 | return pose, intrinsic 233 | 234 | 235 | def get_batched_rays_tensor(H, W, Ks, c2ws): 236 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H)) 237 | i = i.t() 238 | j = j.t() 239 | 240 | pixelpoints = torch.stack([i, j, torch.ones_like(i)], -1)[None, ..., None] 241 | localpoints = torch.matmul(torch.inverse(Ks)[:, None, None, ...], pixelpoints) 242 | 243 | rays_d = torch.matmul(c2ws[:, None, None, :3, :3], localpoints)[..., 0] 244 | rays_o = c2ws[:, None, None, :3, -1].expand(rays_d.shape) 245 | return torch.stack([rays_o, rays_d], dim=1) 246 | 247 | 248 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 249 | # Get pdf 250 | weights = weights + 1e-5 # prevent nans 251 | pdf = weights / torch.sum(weights, -1, keepdim=True) 252 | cdf = torch.cumsum(pdf, -1) 253 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 254 | 255 | # Take uniform samples 256 | if det: 257 | u = torch.linspace(0., 1., steps=N_samples) 258 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 259 | else: 260 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 261 | 262 | # Pytest, overwrite u with numpy's fixed random numbers 263 | if pytest: 264 | np.random.seed(0) 265 | new_shape = list(cdf.shape[:-1]) + [N_samples] 266 | if det: 267 | u = np.linspace(0., 1., N_samples) 268 | u = np.broadcast_to(u, new_shape) 269 | else: 270 | u = np.random.rand(*new_shape) 271 | u = torch.Tensor(u) 272 | 273 | # Invert CDF 274 | u = u.contiguous() 275 | inds = torch.searchsorted(cdf, u, right=True) 276 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 277 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 278 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 279 | 280 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 281 | # bins_g = tf.gather(bins, inds_g, axis=-1, 
batch_dims=len(inds_g.shape)-2) 282 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 283 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 284 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 285 | 286 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 287 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 288 | t = (u - cdf_g[..., 0]) / denom 289 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 290 | 291 | return samples 292 | 293 | 294 | def gaussian(img, kernel_size): 295 | return GaussianBlur(kernel_size)(img) 296 | 297 | 298 | def dilate(alpha: torch.Tensor, kernelsz=3, dilate=1): 299 | """ 300 | alpha: B x L x H x W 301 | """ 302 | padding = (dilate * (kernelsz - 1) + 1) // 2 303 | batchsz, layernum, hei, wid = alpha.shape 304 | alphaunfold = torch.nn.Unfold(kernelsz, dilation=dilate, padding=padding, stride=1)(alpha.reshape(-1, 1, hei, wid)) 305 | alphaunfold = alphaunfold.max(dim=1)[0] 306 | return alphaunfold.reshape_as(alpha) 307 | 308 | 309 | def erode(alpha: torch.Tensor, kernelsz=3, dilate=1): 310 | """ 311 | alpha: B x L x H x W 312 | """ 313 | padding = (dilate * (kernelsz - 1) + 1) // 2 314 | batchsz, layernum, hei, wid = alpha.shape 315 | alphaunfold = torch.nn.Unfold(kernelsz, dilation=dilate, padding=padding, stride=1)(alpha.reshape(-1, 1, hei, wid)) 316 | alphaunfold = alphaunfold.min(dim=1)[0] 317 | return alphaunfold.reshape_as(alpha) 318 | 319 | 320 | class DataParallelCPU: 321 | def __init__(self, module: nn.Module): 322 | self.module = module 323 | 324 | def to(self, device): 325 | self.module.to(device) 326 | 327 | def train(self): 328 | self.module.train() 329 | 330 | def eval(self): 331 | self.module.eval() 332 | 333 | def __call__(self, *args, **kwargs): 334 | return self.module(*args, **kwargs) 335 | 336 | 337 | def compute_loopable_mask(vid, eps=15 / 255, factor=2): 338 | ori_size = vid[0].shape[:2] 339 | 340 | vid0 = cv2.resize(vid[0], None, None, 1 / factor, 1 / factor) 341 | rises = np.zeros_like(vid0) > 0 342 | falls = np.zeros_like(vid0) > 0 343 | minval = vid0 344 | maxval = vid0 345 | for im in vid[1:]: 346 | im_down = cv2.resize(im, None, None, 1 / factor, 1 / factor) 347 | minval = np.minimum(minval, im_down) 348 | maxval = np.maximum(maxval, im_down) 349 | rises = np.logical_or(im_down - minval > eps, rises) 350 | falls = np.logical_or(maxval - im_down > eps, falls) 351 | 352 | unchangging = np.logical_and(np.logical_not(rises), np.logical_not(falls)) 353 | unchangging = np.all(unchangging, axis=-1) 354 | unloopable = np.logical_xor(rises, falls) 355 | unloopable = np.any(unloopable, axis=-1) 356 | loopable = np.logical_not(np.logical_or(unchangging, unloopable)) 357 | 358 | # loopable = cv2.erode(loopable.astype(np.uint8), np.ones((3, 3))) 359 | # loopable = cv2.dilate(loopable.astype(np.uint8), np.ones((3, 3))) 360 | label = np.stack([loopable, unloopable.astype(np.uint8), unchangging.astype(np.uint8)], axis=-1) * 255 361 | label_smooth = cv2.GaussianBlur(label, (5, 5), 0) 362 | label_smooth = cv2.resize(label_smooth.astype(np.float32), ori_size[::-1], None) 363 | loopable_smooth = label_smooth.argmax(axis=-1) == 0 364 | return loopable_smooth 365 | 366 | 367 | def save_obj_multimaterial(file, vertices, faces_list, uvs, uvfaces_list, mtls_list): 368 | with open(file, 'w') as f: 369 | for vertice in vertices: 370 | f.write(f"v {vertice[0]} {vertice[1]} {vertice[2]}\n") 371 | for uv in uvs: 372 | f.write(f"vt {uv[0]} {uv[1]}\n") 373 | 374 | for mtl, faces, 
uvfaces in zip(mtls_list, faces_list, uvfaces_list): 375 | faces1 = faces + 1 376 | uvfaces1 = uvfaces + 1 377 | f.write(f"usemtl {mtl}\n") 378 | f.write(f"s off\n") 379 | for face, uvface in zip(faces1, uvfaces1): 380 | f.write(f"f {face[0]}/{uvface[0]} {face[1]}/{uvface[1]} {face[2]}/{uvface[2]}\n") 381 | 382 | f.write("\n") 383 | 384 | 385 | def save_obj_with_vcolor(file, verts_colors, faces, uvs, uvfaces): 386 | with open(file, 'w') as f: 387 | for pos_color in verts_colors: 388 | f.write(f"v {pos_color[0]} {pos_color[1]} {pos_color[2]} {pos_color[3]} {pos_color[4]} {pos_color[5]}\n") 389 | for uv in uvs: 390 | f.write(f"vt {uv[0]} {uv[1]}\n") 391 | 392 | faces1 = faces + 1 393 | uvfaces1 = uvfaces + 1 394 | for face, uvface in zip(faces1, uvfaces1): 395 | f.write(f"f {face[0]}/{uvface[0]} {face[1]}/{uvface[1]} {face[2]}/{uvface[2]}\n") 396 | 397 | f.write("\n") 398 | 399 | 400 | # Mesh utility 401 | 402 | 403 | def normalize_uv(uv, h, w): 404 | uv[:, 1] = -uv[:, 1] 405 | uv = uv * 0.5 + 0.5 406 | uv = uv * np.array([w - 1, h - 1]) / np.array([w, h]) + 0.5 / np.array([w, h]) 407 | return uv 408 | 409 | 410 | def cull_unused(v, f): 411 | id_unique = np.unique(f) 412 | v_unique = v[id_unique] 413 | id_old2new = np.ones(len(v)).astype(id_unique.dtype) * -1 414 | id_old2new[id_unique] = np.arange(len(v_unique)) 415 | newf = id_old2new[f] 416 | return v_unique, newf 417 | 418 | 419 | def save_obj(file, verts, faces, uvs, uvfaces, rm_unused=True): 420 | if rm_unused: 421 | verts, faces = cull_unused(verts, faces) 422 | uvs, uvfaces = cull_unused(uvs, uvfaces) 423 | 424 | with open(file, 'w') as f: 425 | for pos_color in verts: 426 | f.write(f"v {pos_color[0]} {pos_color[1]} {pos_color[2]}\n") 427 | for uv in uvs: 428 | f.write(f"vt {uv[0]} {uv[1]}\n") 429 | 430 | faces1 = faces + 1 431 | uvfaces1 = uvfaces + 1 432 | for face, uvface in zip(faces1, uvfaces1): 433 | f.write(f"f {face[0]}/{uvface[0]} {face[1]}/{uvface[1]} {face[2]}/{uvface[2]}\n") 434 | 435 | f.write("\n") 436 | 437 | -------------------------------------------------------------------------------- /utils_mpi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as torchf 3 | import numpy as np 4 | from typing import Union, Sequence, Tuple 5 | import torch.nn as nn 6 | 7 | 8 | class Feat2RGBMLP_alpha(nn.Module): # alpha is view-independent 9 | def __init__(self, feat_cnl, view_cn): 10 | super().__init__() 11 | self.mlp = nn.Sequential( 12 | nn.Linear(feat_cnl + view_cn - 1, 48), nn.ReLU(), 13 | nn.Linear(48, 3) 14 | ) 15 | 16 | def forward(self, x): 17 | return torch.cat([self.mlp(x[..., 1:]), x[..., :1]], dim=-1) 18 | 19 | 20 | class NeX_RGBA(nn.Module): # alpha is view-independent 21 | def __init__(self, feat_cnl, view_cn): 22 | assert feat_cnl % 4 == 0 23 | super().__init__() 24 | self.feat_cnl, self.view_cnl = feat_cnl, view_cn 25 | self.mlp = nn.Sequential( 26 | nn.Linear(view_cn, 64), nn.ReLU(), 27 | nn.Linear(64, feat_cnl - 4) 28 | ) 29 | 30 | def forward(self, x): 31 | basis = self.mlp(x[:, self.feat_cnl:]).reshape(-1, self.feat_cnl // 4 - 1, 4) 32 | return (basis * x[:, 4:self.feat_cnl].reshape(-1, self)).sum(dim=-2) + x[:, :4] 33 | 34 | 35 | class NeX_RGB(nn.Module): # alpha is view dependent 36 | def __init__(self, feat_cnl, view_cn): 37 | super().__init__() 38 | self.feat_cnl, self.view_cnl = feat_cnl, view_cn 39 | self.mlp = nn.Sequential( 40 | nn.Linear(view_cn, 64), nn.ReLU(), 41 | nn.Linear(64, 3 * (feat_cnl - 1)) 42 | ) 43 | 44 | def 
forward(self, x): 45 | basis = self.mlp(x[:, self.feat_cnl:]).reshape(-1, self.feat_cnl - 1, 4) 46 | rgb = (basis * x[:, 1:self.feat_cnl, None]).sum(dim=-2) 47 | return torch.cat([rgb, x[..., :1]], dim=-1) 48 | 49 | 50 | class SphericalHarmoic_RGB(nn.Module): # alpha is view-independent 51 | def __init__(self, feat_cnl, view_cn): 52 | super().__init__() 53 | self.sh_dim = feat_cnl // 3 54 | self.feat_cnl = feat_cnl 55 | self.view_cnl = view_cn 56 | 57 | def forward(self, x): 58 | feat, view = torch.split(x, [self.feat_cnl, self.view_cnl], -1) 59 | sh_base = eval_sh_bases(self.sh_dim, view[..., :3]) 60 | rgb = torch.sum(sh_base.reshape(-1, 1, self.sh_dim) * feat[..., :-1].reshape(-1, 3, self.sh_dim), dim=-1) 61 | return torch.cat([rgb, feat[..., -1:]], dim=-1) 62 | 63 | 64 | class SphericalHarmoic_RGBA(nn.Module): # alpha is view-independent 65 | def __init__(self, feat_cnl, view_cn): 66 | super().__init__() 67 | self.sh_dim = 9 68 | self.feat_cnl = feat_cnl 69 | self.view_cnl = view_cn 70 | 71 | def forward(self, x): 72 | feat, view = torch.split(x, [self.feat_cnl, self.view_cnl], -1) 73 | sh_base = eval_sh_bases(self.sh_dim, view[..., :3]) 74 | rgba = torch.sum(sh_base.reshape(1, 1, -1) * feat.reshape(-1, 4, self.sh_dim), dim=-1) 75 | return rgba 76 | 77 | 78 | # geometric utils: generating geometry 79 | # ##################################### 80 | def gen_mpi_vertices(H, W, intrin, num_vert_h, num_vert_w, planedepth): 81 | verts = torch.meshgrid( 82 | [torch.linspace(0, H - 1, num_vert_h), torch.linspace(0, W - 1, num_vert_w)]) 83 | verts = torch.stack(verts[::-1], dim=-1).reshape(1, -1, 2) 84 | # num_plane, H*W, 2 85 | verts = (verts - intrin[None, None, :2, 2]) * planedepth[:, None, None].type_as(verts) 86 | verts /= intrin[None, None, [0, 1], [0, 1]] 87 | zs = planedepth[:, None, None].expand_as(verts[..., :1]) 88 | verts = torch.cat([verts.reshape(-1, 2), zs.reshape(-1, 1)], dim=-1) 89 | return verts 90 | 91 | 92 | def overcompose(alpha, content): 93 | """ 94 | compose mpi back (-1) to front (0) 95 | alpha: [B, H, W, 32] 96 | content: [B, H, W, 32, C] 97 | """ 98 | batchsz, num_plane, height, width, _ = content.shape 99 | 100 | blendweight = torch.cumprod((- alpha + 1)[..., :-1], dim=-1) # B x H x W x LayerNum-1 101 | blendweight = torch.cat([ 102 | alpha[..., :1], 103 | alpha[..., 1:] * blendweight 104 | ], dim=-1) 105 | 106 | rgb = (content * blendweight.unsqueeze(-1)).sum(dim=-2) 107 | return rgb, blendweight 108 | 109 | 110 | def overcomposeNto0(mpi: torch.Tensor, blendweight=None, ret_mask=False, blend_content=None) \ 111 | -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 112 | """ 113 | compose mpi back to front 114 | mpi: [B, 32, 4, H, W] 115 | blendweight: [B, 32, H, W] for reduce reduntant computation 116 | blendContent: [B, layernum, cnl, H, W], if not None, 117 | return: image of shape [B, 4, H, W] 118 | [optional: ] mask of shape [B, H, W], soft mask that 119 | """ 120 | batchsz, num_plane, _, height, width = mpi.shape 121 | alpha = mpi[:, :, -1, ...] # alpha.shape == B x LayerNum x H x W 122 | if blendweight is None: 123 | blendweight = torch.cat([torch.cumprod(- torch.flip(alpha, dims=[1]) + 1, dim=1).flip(dims=[1])[:, 1:], 124 | torch.ones([batchsz, 1, height, width]).type_as(alpha)], dim=1) 125 | renderw = alpha * blendweight 126 | 127 | content = mpi[:, :, :3, ...] 
if blend_content is None else blend_content 128 | rgb = (content * renderw.unsqueeze(2)).sum(dim=1) 129 | if ret_mask: 130 | return rgb, blendweight 131 | else: 132 | return rgb 133 | 134 | 135 | def estimate_disparity_np(mpi: np.ndarray, min_depth=1, max_depth=100): 136 | """Compute disparity map from a set of MPI layers. 137 | mpi: np.ndarray or torch.Tensor 138 | From reference view. 139 | 140 | Args: 141 | layers: [..., L, H, W, C+1] MPI layers, back to front. 142 | depths: [..., L] depths for each layer. 143 | 144 | Returns: 145 | [..., H, W, 1] Single-channel disparity map from reference viewpoint. 146 | """ 147 | num_plane, height, width, chnl = mpi.shape 148 | disparities = np.linspace(1. / max_depth, 1. / min_depth, num_plane) 149 | disparities = disparities.reshape(-1, 1, 1, 1) 150 | 151 | alpha = mpi[..., -1:] 152 | alpha = alpha * np.concatenate([np.cumprod(1 - alpha[::-1], axis=0)[::-1][1:], 153 | np.ones([1, height, width, 1])], axis=0) 154 | disparity = (alpha * disparities).sum(axis=0) 155 | # Weighted sum of per-layer disparities: 156 | return disparity.squeeze(-1) 157 | 158 | 159 | def warp_homography(h, w, homos: torch.Tensor, images: torch.Tensor) -> torch.Tensor: 160 | """ 161 | apply differentiable homography 162 | h, w: the target size 163 | homos: [B x D x 3 x 3] 164 | images: [B x D x 4 x H x W] 165 | """ 166 | batchsz, planenum, cnl, hei, wid = images.shape 167 | y, x = torch.meshgrid([torch.arange(h), torch.arange(w)]) 168 | x, y = x.type_as(images), y.type_as(images) 169 | one = torch.ones_like(x) 170 | grid = torch.stack([x, y, one], dim=-1) # grid: B x D x H x W x 3 171 | new_grid = homos.unsqueeze(-3).unsqueeze(-3) @ grid.unsqueeze(-1) 172 | new_grid = (new_grid.squeeze(-1) / new_grid[..., 2:3, 0])[..., 0:2] # grid: B x D x H x W x 2 173 | new_grid = new_grid / torch.tensor([wid / 2, hei / 2]).type_as(new_grid) - 1. 174 | warpped = torchf.grid_sample(images.reshape(batchsz * planenum, cnl, hei, wid), 175 | new_grid.reshape(batchsz * planenum, h, w, 2), align_corners=True) 176 | return warpped.reshape(batchsz, planenum, cnl, h, w) 177 | 178 | 179 | def warp_homography_withdepth(homos: torch.Tensor, images: torch.Tensor, depth: torch.Tensor) \ 180 | -> Tuple[torch.Tensor, torch.Tensor]: 181 | """ 182 | Please note that homographies here are not scale invariant. make sure that rotation matrix has 1 det. R.det() == 1. 183 | apply differentiable homography 184 | homos: [B x D x 3 x 3] 185 | images: [B x D x 4 x H x W] 186 | depth: [B x D] or [B x D x 1] (depth in *ref space*) 187 | ret: 188 | the warpped mpi 189 | the warpped depth 190 | """ 191 | batchsz, planenum, cnl, hei, wid = images.shape 192 | y, x = torch.meshgrid([torch.arange(hei), torch.arange(wid)]) 193 | x, y = x.type_as(images), y.type_as(images) 194 | one = torch.ones_like(x) 195 | grid = torch.stack([x, y, one], dim=-1).reshape(1, 1, hei, wid, 3, 1) 196 | if depth.dim() == 4: 197 | depth = depth.reshape(batchsz, planenum, 1, hei, wid) 198 | else: 199 | depth = depth.reshape(batchsz, planenum, 1, 1, 1) 200 | 201 | new_grid = homos.unsqueeze(-3).unsqueeze(-3) @ grid 202 | new_depth = depth.reshape(batchsz, planenum, 1, 1) / new_grid[..., -1, 0] 203 | new_grid = (new_grid.squeeze(-1) / new_grid[..., 2:3, 0])[..., 0:2] # grid: B x D x H x W x 2 204 | new_grid = new_grid / torch.tensor([wid / 2, hei / 2]).type_as(new_grid) - 1. 
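# the homography-warped pixel coordinates are normalized into the [-1, 1] range that
# torchf.grid_sample (align_corners=True) expects, so the call below resamples every
# plane of `images` at its warped location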
205 | warpped = torchf.grid_sample(images.reshape(batchsz * planenum, cnl, hei, wid), 206 | new_grid.reshape(batchsz * planenum, hei, wid, 2), align_corners=True) 207 | return warpped.reshape(batchsz, planenum, cnl, hei, wid), new_depth 208 | 209 | 210 | def make_depths(num_plane, min_depth, max_depth): 211 | return torch.reciprocal(torch.linspace(1. / max_depth, 1. / min_depth, num_plane, dtype=torch.float32)) 212 | 213 | 214 | def estimate_disparity_torch(mpi: torch.Tensor, depthes: torch.Tensor, blendweight=None, retbw=False): 215 | """Compute disparity map from a set of MPI layers. 216 | mpi: tensor of shape B x LayerNum x 4 x H x W 217 | depthes: tensor of shape [B x LayerNum] or [B x LayerNum x H x W] (means different depth for each pixel] 218 | blendweight: optional blendweight that to reduce reduntante computation 219 | return: tensor of shape B x H x W 220 | """ 221 | assert (mpi.dim() == 5) 222 | batchsz, num_plane, _, height, width = mpi.shape 223 | disparities = torch.reciprocal(depthes) 224 | if disparities.dim() != 4: 225 | disparities = disparities.reshape(batchsz, num_plane, 1, 1).type_as(mpi) 226 | 227 | alpha = mpi[:, :, -1, ...] # alpha.shape == B x LayerNum x H x W 228 | if blendweight is None: 229 | blendweight = torch.cat([torch.cumprod(- torch.flip(alpha, dims=[1]) + 1, dim=1).flip(dims=[1])[:, 1:], 230 | torch.ones([batchsz, 1, height, width]).type_as(alpha)], dim=1) 231 | renderweight = alpha * blendweight 232 | disparity = (renderweight * disparities).sum(dim=1) 233 | 234 | if retbw: 235 | return disparity, blendweight 236 | else: 237 | return disparity 238 | 239 | 240 | def compute_homography(src_extrin_4x4: torch.Tensor, src_intrin: torch.Tensor, 241 | tar_extrin_4x4: torch.Tensor, tar_intrin: torch.Tensor, 242 | normal: torch.Tensor, distances: torch.Tensor) -> torch.Tensor: 243 | """ 244 | compute homography matrix, detail pls see https://en.wikipedia.org/wiki/Homography_(computer_vision) 245 | explanation: assume P, P1, P2 be coordinate of point in plane in world, ref, tar space 246 | P1 = R1 @ P + t1 P2 = R2 @ P + t2 247 | so P1 = R @ P2 + t where: 248 | R = R1 @ R2^T, t = t1 - R @ t2 249 | n^T @ P1 = d be plane equation in ref space, 250 | so in tar space: n'^T @ P2 = d' where: 251 | n' = R^T @ n, d' = d - n^T @ t 252 | 253 | so P1 = R @ P2 + d'^-1 t @ n'T @ P2 = (R + t @ n'^T @ R / (d - n^T @ t)) @ P2 254 | src_extrin/tar_extrin: [B, 3, 4] = [R | t] 255 | src_intrin/tar_intrin: [B, 3, 3] 256 | normal: [B, D, 3] normal of plane in *reference space* 257 | distances: [B, D] offset of plaen in *ref space* 258 | so the plane equation: n^T @ P1 = d ==> n'^T 259 | return: [B, D, 3, 3] 260 | """ 261 | batchsz, _, _ = src_extrin_4x4.shape 262 | pose = src_extrin_4x4 @ torch.inverse(tar_extrin_4x4) 263 | # rotation = R1 @ R2^T 264 | # translation = (t1 - R1 @ R2^T @ t2) 265 | rotation, translation = pose[..., :3, :3], pose[..., :3, 3:].squeeze(-1) 266 | distances_tar = -(normal @ translation.unsqueeze(-1)).squeeze(-1) + distances 267 | 268 | # [..., 3, 3] -> [..., D, 3, 3] 269 | # multiply extra rotation because normal is in reference space 270 | homo = rotation.unsqueeze(-3) + (translation.unsqueeze(-1) @ normal.unsqueeze(-2) @ rotation.unsqueeze(-3)) \ 271 | / distances_tar.unsqueeze(-1).unsqueeze(-1) 272 | homo = src_intrin.unsqueeze(-3) @ homo @ torch.inverse(tar_intrin.unsqueeze(-3)) 273 | return homo 274 | 275 | 276 | def render_newview(mpi: torch.Tensor, srcextrin: torch.Tensor, tarextrin: torch.Tensor, 277 | srcintrin: torch.Tensor, tarintrin: torch.Tensor, 
278 | depths: torch.Tensor, ret_mask=False, ret_dispmap=False) \ 279 | -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], 280 | Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: 281 | """ 282 | mpi: [B, 32, 4, H, W] 283 | srcpose&tarpose: [B, 3, 4] 284 | depthes: tensor of shape [Bx LayerNum] 285 | intrin: [B, 3, 3] 286 | ret: ref_view_images[, mask][, disparitys] 287 | """ 288 | batchsz, planenum, _, hei, wid = mpi.shape 289 | 290 | planenormal = torch.tensor([0, 0, 1]).reshape(1, 3).repeat(batchsz, 1).type_as(mpi) 291 | distance = depths.reshape(batchsz, planenum) 292 | with torch.no_grad(): 293 | # switching the tar/src pose since we have extrinsic but compute_homography uses poses 294 | # srcextrin = torch.tensor([1, 0, 0, 0, # for debug usage 295 | # 0, 1, 0, 0, 296 | # 0, 0, 1, 0]).reshape(1, 3, 4).type_as(intrin) 297 | # tarextrin = torch.tensor([np.cos(0.3), -np.sin(0.3), 0, 0, 298 | # np.sin(0.3), np.cos(0.3), 0, 0, 299 | # 0, 0, 1, 1.5]).reshape(1, 3, 4).type_as(intrin) 300 | homos = compute_homography(srcextrin, srcintrin, tarextrin, tarintrin, 301 | planenormal, distance) 302 | if not ret_dispmap: 303 | mpi_warp = warp_homography(homos, mpi) 304 | return overcomposeNto0(mpi_warp, ret_mask=ret_mask) 305 | else: 306 | mpi_warp, mpi_depth = warp_homography_withdepth(homos, mpi, distance) 307 | disparitys = estimate_disparity_torch(mpi_warp, mpi_depth) 308 | return overcomposeNto0(mpi_warp, ret_mask=ret_mask), disparitys 309 | 310 | 311 | def warp_flow(content: torch.Tensor, flow: torch.Tensor, offset=None, pad_mode="zeros", mode="bilinear"): 312 | """ 313 | content: [..., cnl, H, W] 314 | flow: [..., 2, H, W] 315 | """ 316 | assert content.dim() == flow.dim() 317 | orishape = content.shape 318 | cnl, hei, wid = content.shape[-3:] 319 | mpi = content.reshape(-1, cnl, hei, wid) 320 | flow = flow.reshape(-1, 2, hei, wid).permute(0, 2, 3, 1) 321 | 322 | if offset is None: 323 | y, x = torch.meshgrid([torch.arange(hei), torch.arange(wid)]) 324 | x, y = x.type_as(mpi), y.type_as(mpi) 325 | offset = torch.stack([x, y], dim=-1) 326 | grid = offset.reshape(1, hei, wid, 2) + flow 327 | normanator = torch.tensor([(wid - 1) / 2, (hei - 1) / 2]).reshape(1, 1, 1, 2).type_as(grid) 328 | warpped = torchf.grid_sample(mpi, grid / normanator - 1., padding_mode=pad_mode, mode=mode, align_corners=True) 329 | return warpped.reshape(orishape) 330 | 331 | 332 | # spherical hamoric related, copy from svox2 333 | 334 | SH_C0 = 0.28209479177387814 335 | SH_C1 = 0.4886025119029199 336 | SH_C2 = [ 337 | 1.0925484305920792, 338 | -1.0925484305920792, 339 | 0.31539156525252005, 340 | -1.0925484305920792, 341 | 0.5462742152960396 342 | ] 343 | SH_C3 = [ 344 | -0.5900435899266435, 345 | 2.890611442640554, 346 | -0.4570457994644658, 347 | 0.3731763325901154, 348 | -0.4570457994644658, 349 | 1.445305721320277, 350 | -0.5900435899266435 351 | ] 352 | SH_C4 = [ 353 | 2.5033429417967046, 354 | -1.7701307697799304, 355 | 0.9461746957575601, 356 | -0.6690465435572892, 357 | 0.10578554691520431, 358 | -0.6690465435572892, 359 | 0.47308734787878004, 360 | -1.7701307697799304, 361 | 0.6258357354491761, 362 | ] 363 | 364 | 365 | def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): 366 | """ 367 | Evaluate spherical harmonics bases at unit directions, 368 | without taking linear combination. 369 | At each point, the final result may the be 370 | obtained through simple multiplication. 371 | 372 | :param basis_dim: int SH basis dim. 
Currently, 1-25 square numbers supported 373 | :param dirs: torch.Tensor (..., 3) unit directions 374 | 375 | :return: torch.Tensor (..., basis_dim) 376 | """ 377 | result = torch.empty((*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device) 378 | result[..., 0] = SH_C0 379 | if basis_dim > 1: 380 | x, y, z = dirs.unbind(-1) 381 | result[..., 1] = -SH_C1 * y 382 | result[..., 2] = SH_C1 * z 383 | result[..., 3] = -SH_C1 * x 384 | if basis_dim > 4: 385 | xx, yy, zz = x * x, y * y, z * z 386 | xy, yz, xz = x * y, y * z, x * z 387 | result[..., 4] = SH_C2[0] * xy 388 | result[..., 5] = SH_C2[1] * yz 389 | result[..., 6] = SH_C2[2] * (2.0 * zz - xx - yy) 390 | result[..., 7] = SH_C2[3] * xz 391 | result[..., 8] = SH_C2[4] * (xx - yy) 392 | 393 | if basis_dim > 9: 394 | result[..., 9] = SH_C3[0] * y * (3 * xx - yy) 395 | result[..., 10] = SH_C3[1] * xy * z 396 | result[..., 11] = SH_C3[2] * y * (4 * zz - xx - yy) 397 | result[..., 12] = SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy) 398 | result[..., 13] = SH_C3[4] * x * (4 * zz - xx - yy) 399 | result[..., 14] = SH_C3[5] * z * (xx - yy) 400 | result[..., 15] = SH_C3[6] * x * (xx - 3 * yy) 401 | 402 | if basis_dim > 16: 403 | result[..., 16] = SH_C4[0] * xy * (xx - yy) 404 | result[..., 17] = SH_C4[1] * yz * (3 * xx - yy) 405 | result[..., 18] = SH_C4[2] * xy * (7 * zz - 1) 406 | result[..., 19] = SH_C4[3] * yz * (7 * zz - 3) 407 | result[..., 20] = SH_C4[4] * (zz * (35 * zz - 30) + 3) 408 | result[..., 21] = SH_C4[5] * xz * (7 * zz - 3) 409 | result[..., 22] = SH_C4[6] * (xx - yy) * (7 * zz - 1) 410 | result[..., 23] = SH_C4[7] * xz * (xx - 3 * yy) 411 | result[..., 24] = SH_C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 412 | return result 413 | --------------------------------------------------------------------------------
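A minimal usage sketch for `eval_sh_bases` from `utils_mpi.py`. All concrete values here (the degree-2 basis size of 9, the batch of random directions, and the random per-channel coefficients) are illustrative assumptions; the final weighted sum simply mirrors how `SphericalHarmoic_RGB.forward` combines the basis with its features:
```
import torch
from utils_mpi import eval_sh_bases

# basis_dim must be a square number (1, 4, 9, 16 or 25); 9 = SH bands up to degree 2
dirs = torch.randn(1024, 3)
dirs = dirs / dirs.norm(dim=-1, keepdim=True)   # unit viewing directions
sh = eval_sh_bases(9, dirs)                     # (1024, 9) basis values

# hypothetical per-channel SH coefficients; a weighted sum over the basis gives RGB,
# which is essentially what SphericalHarmoic_RGB.forward does with its feature planes
coeffs = torch.randn(1024, 3, 9)
rgb = (sh[:, None, :] * coeffs).sum(dim=-1)     # (1024, 3)
```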