├── datasets ├── __init__.py ├── lists │ ├── tnt │ │ ├── advanced.txt │ │ └── intermediate.txt │ ├── blendedmvs │ │ ├── val.txt │ │ └── low_res_all.txt │ └── dtu │ │ ├── val.txt │ │ ├── test.txt │ │ └── train.txt ├── evaluations │ └── dtu_parallel │ │ ├── reducePts_haa.m │ │ ├── MaxDistCP.m │ │ ├── BaseEval2Obj_web.m │ │ ├── BaseEvalMain_web.m │ │ ├── ComputeStat_web.m │ │ ├── PointCompareMain.m │ │ └── plyread.m ├── data_io.py ├── tnt.py ├── blendedmvs.py └── dtu.py ├── .gitignore ├── models ├── utils │ ├── __init__.py │ ├── opts.py │ └── utils.py ├── __init__.py ├── filter.py ├── loss.py ├── geomvsnet.py └── submodules.py ├── .github └── imgs │ ├── geomvsnet-video-cover.png │ └── mvs-demo-video-cover.png ├── requirements.txt ├── scripts ├── data_path.sh ├── dtu │ ├── test_dtu.sh │ ├── train_dtu.sh │ ├── train_dtu_raw.sh │ ├── fusion_dtu.sh │ └── matlab_quan_dtu.sh ├── blend │ └── train_blend.sh └── tnt │ ├── fusion_tnt.sh │ └── test_tnt.sh ├── outputs └── visual.ipynb ├── test.py ├── LICENSE ├── fusions └── dtu │ ├── _open3d.py │ ├── gipuma.py │ └── pcd.py ├── README.md └── train.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from models.utils.utils import * -------------------------------------------------------------------------------- /datasets/lists/tnt/advanced.txt: -------------------------------------------------------------------------------- 1 | Auditorium 2 | Ballroom 3 | Courtroom 4 | Museum 5 | Palace 6 | Temple -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.geomvsnet import GeoMVSNet 2 | from models.loss import geomvsnet_loss -------------------------------------------------------------------------------- /datasets/lists/tnt/intermediate.txt: -------------------------------------------------------------------------------- 1 | Family 2 | Horse 3 | Francis 4 | Lighthouse 5 | M60 6 | Panther 7 | Playground 8 | Train -------------------------------------------------------------------------------- /.github/imgs/geomvsnet-video-cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doubleZ0108/GeoMVSNet/HEAD/.github/imgs/geomvsnet-video-cover.png -------------------------------------------------------------------------------- /.github/imgs/mvs-demo-video-cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doubleZ0108/GeoMVSNet/HEAD/.github/imgs/mvs-demo-video-cover.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.0 2 | torchvision 3 | opencv-python 4 | numpy==1.18.1 5 | pillow 6 | scipy 7 | tensorboardX 8 | plyfile 9 | open3d 10 | jupyter 11 | notebook -------------------------------------------------------------------------------- /datasets/lists/blendedmvs/val.txt: 
-------------------------------------------------------------------------------- 1 | 5b7a3890fc8fcf6781e2593a 2 | 5c189f2326173c3a09ed7ef3 3 | 5b950c71608de421b1e7318f 4 | 5a6400933d809f1d8200af15 5 | 59d2657f82ca7774b1ec081d 6 | 5ba19a8a360c7c30c1c169df 7 | 59817e4a1bd4b175e7038d19 -------------------------------------------------------------------------------- /datasets/lists/dtu/val.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan5 3 | scan17 4 | scan21 5 | scan28 6 | scan35 7 | scan37 8 | scan38 9 | scan40 10 | scan43 11 | scan56 12 | scan59 13 | scan66 14 | scan67 15 | scan82 16 | scan86 17 | scan106 18 | scan117 -------------------------------------------------------------------------------- /datasets/lists/dtu/test.txt: -------------------------------------------------------------------------------- 1 | scan1 2 | scan4 3 | scan9 4 | scan10 5 | scan11 6 | scan12 7 | scan13 8 | scan15 9 | scan23 10 | scan24 11 | scan29 12 | scan32 13 | scan33 14 | scan34 15 | scan48 16 | scan49 17 | scan62 18 | scan75 19 | scan77 20 | scan110 21 | scan114 22 | scan118 -------------------------------------------------------------------------------- /scripts/data_path.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # DTU 4 | DTU_TRAIN_ROOT="[/path/to/]dtu" 5 | DTU_TEST_ROOT="[/path/to/]dtu-test" 6 | DTU_QUANTITATIVE_ROOT="[/path/to/]dtu-evaluation" 7 | 8 | # Tanks and Temples 9 | TNT_ROOT="[/path/to/]tnt" 10 | 11 | # BlendedMVS 12 | BLENDEDMVS_ROOT="[/path/to/]blendmvs" -------------------------------------------------------------------------------- /scripts/dtu/test_dtu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="geomvsnet" 5 | BESTEPOCH="geomvsnet_release" 6 | 7 | LOG_DIR="./checkpoints/dtu/"$THISNAME 8 | DTU_CKPT_FILE=$LOG_DIR"/model_"$BESTEPOCH".ckpt" 9 | DTU_OUT_DIR="./outputs/dtu/"$THISNAME 10 | 11 | CUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \ 12 | --which_dataset="dtu" --loadckpt=$DTU_CKPT_FILE --batch_size=1 \ 13 | --outdir=$DTU_OUT_DIR --logdir=$LOG_DIR --nolog \ 14 | --testpath=$DTU_TEST_ROOT --testlist="datasets/lists/dtu/test.txt" \ 15 | \ 16 | --data_scale="raw" --n_views="5" -------------------------------------------------------------------------------- /scripts/dtu/train_dtu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="geomvsnet" 5 | 6 | LOG_DIR="./checkpoints/dtu/"$THISNAME 7 | if [ ! -d $LOG_DIR ]; then 8 | mkdir -p $LOG_DIR 9 | fi 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \ 12 | --which_dataset="dtu" --epochs=16 --logdir=$LOG_DIR \ 13 | --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \ 14 | --trainlist="datasets/lists/dtu/train.txt" --testlist="datasets/lists/dtu/test.txt" \ 15 | \ 16 | --data_scale="mid" --n_views="5" --batch_size=4 --lr=0.002 --robust_train \ 17 | --lrepochs="1,3,5,7,9,11,13,15:1.5" -------------------------------------------------------------------------------- /scripts/blend/train_blend.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="geomvsnet" 5 | 6 | LOG_DIR="./checkpoints/blend/"$THISNAME 7 | if [ ! 
-d $LOG_DIR ]; then 8 | mkdir -p $LOG_DIR 9 | fi 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \ 12 | --which_dataset="blendedmvs" --epochs=16 --logdir=$LOG_DIR \ 13 | --trainpath=$BLENDEDMVS_ROOT --testpath=$BLENDEDMVS_ROOT \ 14 | --trainlist="datasets/lists/blendedmvs/low_res_all.txt" --testlist="datasets/lists/blendedmvs/val.txt" \ 15 | \ 16 | --n_views="7" --batch_size=2 --lr=0.001 --robust_train \ 17 | --lr_scheduler="onecycle" -------------------------------------------------------------------------------- /scripts/dtu/train_dtu_raw.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="geomvsnet_raw" 5 | 6 | LOG_DIR="./checkpoints/dtu/"$THISNAME 7 | if [ ! -d $LOG_DIR ]; then 8 | mkdir -p $LOG_DIR 9 | fi 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \ 12 | --which_dataset="dtu" --epochs=16 --logdir=$LOG_DIR \ 13 | --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \ 14 | --trainlist="datasets/lists/dtu/train.txt" --testlist="datasets/lists/dtu/test.txt" \ 15 | \ 16 | --data_scale="raw" --n_views="5" --batch_size=1 --lr=0.0005 --robust_train \ 17 | --lrepochs="1,3,5,7,9,11,13,15:1.5" -------------------------------------------------------------------------------- /scripts/tnt/fusion_tnt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="blend/geomvsnet" 5 | 6 | LOG_DIR="./checkpoints/tnt/"$THISNAME 7 | TNT_OUT_DIR="./outputs/tnt/"$THISNAME 8 | 9 | # Intermediate 10 | python3 fusions/tnt/dypcd.py ${@} \ 11 | --root_dir=$TNT_ROOT --list_file="datasets/lists/tnt/intermediate.txt" --split="intermediate" \ 12 | --out_dir=$TNT_OUT_DIR --ply_path=$TNT_OUT_DIR"/dypcd_fusion_plys" \ 13 | --img_mode="resize" --cam_mode="origin" --single_processor 14 | 15 | # Advanced 16 | python3 fusions/tnt/dypcd.py ${@} \ 17 | --root_dir=$TNT_ROOT --list_file="datasets/lists/tnt/advanced.txt" --split="advanced" \ 18 | --out_dir=$TNT_OUT_DIR --ply_path=$TNT_OUT_DIR"/dypcd_fusion_plys" \ 19 | --img_mode="resize" --cam_mode="origin" --single_processor -------------------------------------------------------------------------------- /datasets/lists/dtu/train.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 | scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 -------------------------------------------------------------------------------- 
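Note on the learning-rate schedule used by the DTU training scripts above: the schedule is passed as a single `--lrepochs` string of the form "milestone_epochs:rate" (e.g. "1,3,5,7,9,11,13,15:1.5"), which `models/utils/opts.py` later documents as "epoch ids to downscale lr and the downscale rate". The sketch below shows one way such a string could be parsed; `parse_lrepochs` is a hypothetical helper for illustration, not a function from this repository, and it assumes the learning rate is divided by the rate at every listed epoch, as the help text suggests. The BlendedMVS script instead selects `--lr_scheduler="onecycle"` and omits `--lrepochs` entirely.

def parse_lrepochs(lrepochs: str):
    """Hypothetical helper: split '1,3,5,7,9,11,13,15:1.5' into milestone epochs and a decay factor."""
    milestones_str, rate_str = lrepochs.split(':')
    milestones = [int(e) for e in milestones_str.split(',')]
    decay = 1.0 / float(rate_str)   # assumed: lr is divided by `rate` at each milestone epoch
    return milestones, decay

milestones, decay = parse_lrepochs("1,3,5,7,9,11,13,15:1.5")
print(milestones)          # [1, 3, 5, 7, 9, 11, 13, 15]
print(round(decay, 3))     # 0.667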
/scripts/tnt/test_tnt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="blend/geomvsnet" 5 | BESTEPOCH="15" 6 | 7 | LOG_DIR="./checkpoints/"$THISNAME 8 | CKPT_FILE=$LOG_DIR"/model_"$BESTEPOCH".ckpt" 9 | TNT_OUT_DIR="./outputs/tnt/"$THISNAME 10 | 11 | # Intermediate 12 | CUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \ 13 | --which_dataset="tnt" --loadckpt=$CKPT_FILE --batch_size=1 \ 14 | --outdir=$TNT_OUT_DIR --logdir=$LOG_DIR --nolog \ 15 | --testpath=$TNT_ROOT --testlist="datasets/lists/tnt/intermediate.txt" --split="intermediate" \ 16 | \ 17 | --n_views="11" --img_mode="resize" --cam_mode="origin" 18 | 19 | # Advanced 20 | CUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \ 21 | --which_dataset="tnt" --loadckpt=$CKPT_FILE --batch_size=1 \ 22 | --outdir=$TNT_OUT_DIR --logdir=$LOG_DIR --nolog \ 23 | --testpath=$TNT_ROOT --testlist="datasets/lists/tnt/advanced.txt" --split="advanced" \ 24 | \ 25 | --n_views="11" --img_mode="resize" --cam_mode="origin" -------------------------------------------------------------------------------- /datasets/evaluations/dtu_parallel/reducePts_haa.m: -------------------------------------------------------------------------------- 1 | function [ptsOut,indexSet] = reducePts_haa(pts, dst) 2 | 3 | %Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance 4 | % between points is 'dst'. Writen by abd, edited by haa, then by raje 5 | 6 | nPoints=size(pts,2); 7 | 8 | indexSet=true(nPoints,1); 9 | RandOrd=randperm(nPoints); 10 | 11 | %tic 12 | NS = KDTreeSearcher(pts'); 13 | %toc 14 | 15 | % search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big 16 | Chunks=1:min(4e6,nPoints-1):nPoints; 17 | Chunks(end)=nPoints; 18 | 19 | for cChunk=1:(length(Chunks)-1) 20 | Range=Chunks(cChunk):Chunks(cChunk+1); 21 | 22 | idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); 23 | 24 | for i = 1:size(idx,1) 25 | id =RandOrd(i-1+Chunks(cChunk)); 26 | if (indexSet(id)) 27 | indexSet(idx{i}) = 0; 28 | indexSet(id) = 1; 29 | end 30 | end 31 | end 32 | 33 | ptsOut = pts(:,indexSet); 34 | 35 | disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); -------------------------------------------------------------------------------- /models/filter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Basic implementation of Frequency Domain Filtering strategy (Sec 3.2 in the paper). 
3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import torch 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def frequency_domain_filter(depth, rho_ratio): 13 | """ 14 | large rho_ratio -> more information filtered 15 | """ 16 | f = torch.fft.fft2(depth) 17 | fshift = torch.fft.fftshift(f) 18 | 19 | b, h, w = depth.shape 20 | k_h, k_w = h/rho_ratio, w/rho_ratio 21 | 22 | fshift[:,:int(h/2-k_h/2),:] = 0 23 | fshift[:,int(h/2+k_h/2):,:] = 0 24 | fshift[:,:,:int(w/2-k_w/2)] = 0 25 | fshift[:,:,int(w/2+k_w/2):] = 0 26 | 27 | ishift = torch.fft.ifftshift(fshift) 28 | idepth = torch.fft.ifft2(ishift) 29 | depth_filtered = torch.abs(idepth) 30 | 31 | return depth_filtered 32 | 33 | 34 | def visual_fft_fig(fshift): 35 | fft_fig = torch.abs(20 * torch.log(fshift)) 36 | plt.figure(figsize=(10, 10)) 37 | plt.subplot(121) 38 | plt.imshow(fft_fig[0,:,:], cmap = 'gray') -------------------------------------------------------------------------------- /scripts/dtu/fusion_dtu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | THISNAME="geomvsnet" 5 | FUSION_METHOD="open3d" 6 | 7 | LOG_DIR="./checkpoints/dtu/"$THISNAME 8 | DTU_OUT_DIR="./outputs/dtu/"$THISNAME 9 | 10 | if [ $FUSION_METHOD = "pcd" ] ; then 11 | python3 fusions/dtu/pcd.py ${@} \ 12 | --testpath=$DTU_TEST_ROOT --testlist="datasets/lists/dtu/test.txt" \ 13 | --outdir=$DTU_OUT_DIR --logdir=$LOG_DIR --nolog \ 14 | --num_worker=1 \ 15 | \ 16 | --thres_view=4 --conf=0.5 \ 17 | \ 18 | --plydir=$DTU_OUT_DIR"/pcd_fusion_plys/" 19 | 20 | elif [ $FUSION_METHOD = "gipuma" ] ; then 21 | # source [/path/to/]anaconda3/etc/profile.d/conda.sh 22 | # conda activate fusibile 23 | CUDA_VISIBLE_DEVICES=0 python2 fusions/dtu/gipuma.py \ 24 | --root_dir=$DTU_TEST_ROOT --list_file="datasets/lists/dtu/test.txt" \ 25 | --fusibile_exe_path="fusions/fusibile" --out_folder="fusibile_fused" \ 26 | --depth_folder=$DTU_OUT_DIR \ 27 | --downsample_factor=1 \ 28 | \ 29 | --prob_threshold=0.5 --disp_threshold=0.25 --num_consistent=3 \ 30 | \ 31 | --plydir=$DTU_OUT_DIR"/gipuma_fusion_plys/" 32 | 33 | elif [ $FUSION_METHOD = "open3d" ] ; then 34 | CUDA_VISIBLE_DEVICES=0 python fusions/dtu/_open3d.py --device="cuda" \ 35 | --root_path=$DTU_TEST_ROOT \ 36 | --depth_path=$DTU_OUT_DIR \ 37 | --data_list="datasets/lists/dtu/test.txt" \ 38 | \ 39 | --prob_thresh=0.3 --dist_thresh=0.2 --num_consist=4 \ 40 | \ 41 | --ply_path=$DTU_OUT_DIR"/open3d_fusion_plys/" 42 | 43 | fi 44 | -------------------------------------------------------------------------------- /datasets/evaluations/dtu_parallel/MaxDistCP.m: -------------------------------------------------------------------------------- 1 | function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist) 2 | 3 | Dist=ones(1,size(Qfrom,2))*MaxDist; 4 | 5 | Range=floor((BB(2,:)-BB(1,:))/MaxDist); 6 | 7 | tic 8 | Done=0; 9 | LookAt=zeros(1,size(Qfrom,2)); 10 | for x=0:Range(1), 11 | for y=0:Range(2), 12 | for z=0:Range(3), 13 | 14 | Low=BB(1,:)+[x y z]*MaxDist; 15 | High=Low+MaxDist; 16 | 17 | idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... 18 | Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... 
25 | Qto(1,:)3)] 49 | end -------------------------------------------------------------------------------- /scripts/dtu/matlab_quan_dtu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source scripts/data_path.sh 3 | 4 | OUTNAME="geomvsnet" 5 | 6 | FUSIONMETHOD="open3d" 7 | 8 | # Evaluation 9 | echo "<<<<<<<<<< start parallel evaluation" 10 | METHOD='mvsnet' 11 | PLYPATH='../../../outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_fusion_plys/' 12 | RESULTPATH='../../../outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/' 13 | LOGPATH='outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/'$OUTNAME'.log' 14 | 15 | mkdir -p 'outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/' 16 | 17 | set_array=(1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118) 18 | 19 | num_at_once=2 # 1 2 4 5 7 11 22 20 | times=`expr $((${#set_array[*]} / $num_at_once))` 21 | remain=`expr $((${#set_array[*]} - $num_at_once * $times))` 22 | this_group_num=0 23 | pos=0 24 | 25 | for ((t=0; t<$times; t++)) 26 | do 27 | if [ "$t" -ge `expr $(($times-$remain))` ] ; then 28 | this_group_num=`expr $(($num_at_once + 1))` 29 | else 30 | this_group_num=$num_at_once 31 | fi 32 | 33 | for set in "${set_array[@]:pos:this_group_num}" 34 | do 35 | matlab -nodesktop -nosplash -r "cd datasets/evaluations/dtu_parallel; dataPath='$DTU_QUANTITATIVE_ROOT'; plyPath='$PLYPATH'; resultsPath='$RESULTPATH'; method_string='$METHOD'; thisset='$set'; BaseEvalMain_web" & 36 | done 37 | wait 38 | 39 | pos=`expr $(($pos + $this_group_num))` 40 | 41 | done 42 | wait 43 | 44 | 45 | SET=[1,4,9,10,11,12,13,15,23,24,29,32,33,34,48,49,62,75,77,110,114,118] 46 | 47 | matlab -nodesktop -nosplash -r "cd datasets/evaluations/dtu_parallel; resultsPath='$RESULTPATH'; method_string='$METHOD'; set='$SET'; ComputeStat_web" > $LOGPATH -------------------------------------------------------------------------------- /datasets/evaluations/dtu_parallel/BaseEval2Obj_web.m: -------------------------------------------------------------------------------- 1 | function BaseEval2Obj_web(BaseEval,method_string,outputPath) 2 | 3 | if(nargin<3) 4 | outputPath='./'; 5 | end 6 | 7 | % tresshold for coloring alpha channel in the range of 0-10 mm 8 | dist_tresshold=10; 9 | 10 | cSet=BaseEval.cSet; 11 | 12 | Qdata=BaseEval.Qdata; 13 | alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold; 14 | 15 | fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+'); 16 | 17 | for cP=1:size(Qdata,2) 18 | if(BaseEval.DataInMask(cP)) 19 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 20 | else 21 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis) 22 | end 23 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]); 24 | end 25 | fclose(fid); 26 | 27 | disp('Data2Stl saved as obj') 28 | 29 | Qstl=BaseEval.Qstl; 30 | fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+'); 31 | 32 | alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold; 33 | 34 | for cP=1:size(Qstl,2) 35 | if(BaseEval.StlAbovePlane(cP)) 36 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 37 | else 38 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) 39 | end 40 | fprintf(fid,'v %f %f %f %f %f 
%f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); 41 | end 42 | fclose(fid); 43 | 44 | disp('Stl2Data saved as obj') -------------------------------------------------------------------------------- /datasets/evaluations/dtu_parallel/BaseEvalMain_web.m: -------------------------------------------------------------------------------- 1 | format compact 2 | 3 | representation_string='Points'; %mvs representation 'Points' or 'Surfaces' 4 | 5 | switch representation_string 6 | case 'Points' 7 | eval_string='_Eval_'; %results naming 8 | settings_string=''; 9 | end 10 | 11 | 12 | dst=0.2; %Min dist between points when reducing 13 | 14 | % start this evaluation 15 | cSet = str2num(thisset) 16 | 17 | %input data name 18 | DataInName = [plyPath sprintf('%s%03d.ply', lower(method_string), cSet)] 19 | 20 | 21 | 22 | %results name 23 | EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] 24 | 25 | %check if file is already computed 26 | if(~exist(EvalName,'file')) 27 | disp(DataInName); 28 | 29 | time=clock;time(4:5), drawnow 30 | 31 | tic 32 | Mesh = plyread(DataInName); 33 | Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]'; 34 | toc 35 | 36 | BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); 37 | 38 | disp('Saving results'), drawnow 39 | toc 40 | save(EvalName,'BaseEval'); 41 | toc 42 | 43 | % write obj-file of evaluation 44 | % BaseEval2Obj_web(BaseEval,method_string, resultsPath) 45 | % toc 46 | time=clock;time(4:5), drawnow 47 | 48 | BaseEval.MaxDist=20; %outlier threshold of 20 mm 49 | 50 | BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane 51 | BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); 37 | MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); 38 | Midx2=find(ObsMask(MidxA)); 39 | 40 | BaseEval.DataInMask(1:size(Qv,2))=false; 41 | BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask 42 | 43 | BaseEval.cSet=cSet; 44 | BaseEval.Margin=Margin; %Margin of masks 45 | BaseEval.dst=dst; %Min dist between points when reducing 46 | BaseEval.Qdata=Qdata; %Input data points 47 | BaseEval.Ddata=Ddata; %distance from data to stl 48 | BaseEval.Qstl=Qstl; %Input stl points 49 | BaseEval.Dstl=Dstl; %Distance from the stl to data 50 | 51 | load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') 52 | BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' 53 | BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' 54 | BaseEval.Time=clock; %Time when computation is finished -------------------------------------------------------------------------------- /datasets/data_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: I/O functions for depth maps and camera files. 
3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import sys, re 8 | import numpy as np 9 | 10 | 11 | def read_pfm(filename): 12 | file = open(filename, 'rb') 13 | color = None 14 | width = None 15 | height = None 16 | scale = None 17 | endian = None 18 | 19 | header = file.readline().decode('utf-8').rstrip() 20 | if header == 'PF': 21 | color = True 22 | elif header == 'Pf': 23 | color = False 24 | else: 25 | raise Exception('Not a PFM file.') 26 | 27 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 28 | if dim_match: 29 | width, height = map(int, dim_match.groups()) 30 | else: 31 | raise Exception('Malformed PFM header.') 32 | 33 | scale = float(file.readline().rstrip()) 34 | if scale < 0: # little-endian 35 | endian = '<' 36 | scale = -scale 37 | else: 38 | endian = '>' # big-endian 39 | 40 | data = np.fromfile(file, endian + 'f') 41 | shape = (height, width, 3) if color else (height, width) 42 | 43 | data = np.reshape(data, shape) 44 | data = np.flipud(data) 45 | file.close() 46 | return data, scale 47 | 48 | 49 | def save_pfm(filename, image, scale=1): 50 | file = open(filename, "wb") 51 | color = None 52 | 53 | image = np.flipud(image) 54 | 55 | if image.dtype.name != 'float32': 56 | raise Exception('Image dtype must be float32.') 57 | 58 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 59 | color = True 60 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 61 | color = False 62 | else: 63 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 64 | 65 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 66 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 67 | 68 | endian = image.dtype.byteorder 69 | 70 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 71 | scale = -scale 72 | 73 | file.write(('%f\n' % scale).encode('utf-8')) 74 | 75 | image.tofile(file) 76 | file.close() 77 | 78 | 79 | def write_cam(file, cam): 80 | f = open(file, "w") 81 | f.write('extrinsic\n') 82 | for i in range(0, 4): 83 | for j in range(0, 4): 84 | f.write(str(cam[0][i][j]) + ' ') 85 | f.write('\n') 86 | f.write('\n') 87 | 88 | f.write('intrinsic\n') 89 | for i in range(0, 3): 90 | for j in range(0, 3): 91 | f.write(str(cam[1][i][j]) + ' ') 92 | f.write('\n') 93 | 94 | f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') 95 | 96 | f.close() -------------------------------------------------------------------------------- /datasets/lists/blendedmvs/low_res_all.txt: -------------------------------------------------------------------------------- 1 | 5c1f33f1d33e1f2e4aa6dda4 2 | 5bfe5ae0fe0ea555e6a969ca 3 | 5bff3c5cfe0ea555e6bcbf3a 4 | 58eaf1513353456af3a1682a 5 | 5bfc9d5aec61ca1dd69132a2 6 | 5bf18642c50e6f7f8bdbd492 7 | 5bf26cbbd43923194854b270 8 | 5bf17c0fd439231948355385 9 | 5be3ae47f44e235bdbbc9771 10 | 5be3a5fb8cfdd56947f6b67c 11 | 5bbb6eb2ea1cfa39f1af7e0c 12 | 5ba75d79d76ffa2c86cf2f05 13 | 5bb7a08aea1cfa39f1a947ab 14 | 5b864d850d072a699b32f4ae 15 | 5b6eff8b67b396324c5b2672 16 | 5b6e716d67b396324c2d77cb 17 | 5b69cc0cb44b61786eb959bf 18 | 5b62647143840965efc0dbde 19 | 5b60fa0c764f146feef84df0 20 | 5b558a928bbfb62204e77ba2 21 | 5b271079e0878c3816dacca4 22 | 5b08286b2775267d5b0634ba 23 | 5afacb69ab00705d0cefdd5b 24 | 5af28cea59bc705737003253 25 | 5af02e904c8216544b4ab5a2 26 | 
5aa515e613d42d091d29d300 27 | 5c34529873a8df509ae57b58 28 | 5c34300a73a8df509add216d 29 | 5c1af2e2bee9a723c963d019 30 | 5c1892f726173c3a09ea9aeb 31 | 5c0d13b795da9479e12e2ee9 32 | 5c062d84a96e33018ff6f0a6 33 | 5bfd0f32ec61ca1dd69dc77b 34 | 5bf21799d43923194842c001 35 | 5bf3a82cd439231948877aed 36 | 5bf03590d4392319481971dc 37 | 5beb6e66abd34c35e18e66b9 38 | 5be883a4f98cee15019d5b83 39 | 5be47bf9b18881428d8fbc1d 40 | 5bcf979a6d5f586b95c258cd 41 | 5bce7ac9ca24970bce4934b6 42 | 5bb8a49aea1cfa39f1aa7f75 43 | 5b78e57afc8fcf6781d0c3ba 44 | 5b21e18c58e2823a67a10dd8 45 | 5b22269758e2823a67a3bd03 46 | 5b192eb2170cf166458ff886 47 | 5ae2e9c5fe405c5076abc6b2 48 | 5adc6bd52430a05ecb2ffb85 49 | 5ab8b8e029f5351f7f2ccf59 50 | 5abc2506b53b042ead637d86 51 | 5ab85f1dac4291329b17cb50 52 | 5a969eea91dfc339a9a3ad2c 53 | 5a8aa0fab18050187cbe060e 54 | 5a7d3db14989e929563eb153 55 | 5a69c47d0d5d0a7f3b2e9752 56 | 5a618c72784780334bc1972d 57 | 5a6464143d809f1d8208c43c 58 | 5a588a8193ac3d233f77fbca 59 | 5a57542f333d180827dfc132 60 | 5a572fd9fc597b0478a81d14 61 | 5a563183425d0f5186314855 62 | 5a4a38dad38c8a075495b5d2 63 | 5a48d4b2c7dab83a7d7b9851 64 | 5a489fb1c7dab83a7d7b1070 65 | 5a48ba95c7dab83a7d7b44ed 66 | 5a3ca9cb270f0e3f14d0eddb 67 | 5a3cb4e4270f0e3f14d12f43 68 | 5a3f4aba5889373fbbc5d3b5 69 | 5a0271884e62597cdee0d0eb 70 | 59e864b2a9e91f2c5529325f 71 | 599aa591d5b41f366fed0d58 72 | 59350ca084b7f26bf5ce6eb8 73 | 59338e76772c3e6384afbb15 74 | 5c20ca3a0843bc542d94e3e2 75 | 5c1dbf200843bc542d8ef8c4 76 | 5c1b1500bee9a723c96c3e78 77 | 5bea87f4abd34c35e1860ab5 78 | 5c2b3ed5e611832e8aed46bf 79 | 57f8d9bbe73f6760f10e916a 80 | 5bf7d63575c26f32dbf7413b 81 | 5be4ab93870d330ff2dce134 82 | 5bd43b4ba6b28b1ee86b92dd 83 | 5bccd6beca24970bce448134 84 | 5bc5f0e896b66a2cd8f9bd36 85 | 5b908d3dc6ab78485f3d24a9 86 | 5b2c67b5e0878c381608b8d8 87 | 5b4933abf2b5f44e95de482a 88 | 5b3b353d8d46a939f93524b9 89 | 5acf8ca0f3d8a750097e4b15 90 | 5ab8713ba3799a1d138bd69a 91 | 5aa235f64a17b335eeaf9609 92 | 5aa0f9d7a9efce63548c69a1 93 | 5a8315f624b8e938486e0bd8 94 | 5a48c4e9c7dab83a7d7b5cc7 95 | 59ecfd02e225f6492d20fcc9 96 | 59f87d0bfa6280566fb38c9a 97 | 59f363a8b45be22330016cad 98 | 59f70ab1e5c5d366af29bf3e 99 | 59e75a2ca9e91f2c5526005d 100 | 5947719bf1b45630bd096665 101 | 5947b62af1b45630bd0c2a02 102 | 59056e6760bb961de55f3501 103 | 58f7f7299f5b5647873cb110 104 | 58cf4771d0f5fb221defe6da 105 | 58d36897f387231e6c929903 106 | 58c4bb4f4a69c55606122be4 107 | 5b7a3890fc8fcf6781e2593a 108 | 5c189f2326173c3a09ed7ef3 109 | 5b950c71608de421b1e7318f 110 | 5a6400933d809f1d8200af15 111 | 59d2657f82ca7774b1ec081d 112 | 5ba19a8a360c7c30c1c169df 113 | 59817e4a1bd4b175e7038d19 -------------------------------------------------------------------------------- /models/utils/opts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Options settings & configurations for GeoMVSNet. 
3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import argparse 8 | 9 | def get_opts(): 10 | parser = argparse.ArgumentParser(description="args") 11 | 12 | # global settings 13 | parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'val']) 14 | parser.add_argument('--which_dataset', default='dtu', choices=['dtu', 'tnt', 'blendedmvs'], help='which dataset for using') 15 | 16 | parser.add_argument('--n_views', type=int, default=5, help='num of view') 17 | parser.add_argument('--levels', type=int, default=4, help='num of stages') 18 | parser.add_argument('--hypo_plane_num_stages', type=str, default="8,8,4,4", help='num of hypothesis planes for each stage') 19 | parser.add_argument('--depth_interal_ratio_stages', type=str, default="0.5,0.5,0.5,1", help='depth interals for each stage') 20 | parser.add_argument("--feat_base_channel", type=int, default=8, help='channel num for base feature') 21 | parser.add_argument("--reg_base_channel", type=int, default=8, help='channel num for regularization') 22 | parser.add_argument('--group_cor_dim_stages', type=str, default="8,8,4,4", help='group correlation dim') 23 | 24 | parser.add_argument('--batch_size', type=int, default=1, help='batch size for training') 25 | parser.add_argument('--data_scale', type=str, choices=['mid', 'raw'], help='use mid or raw resolution') 26 | parser.add_argument('--trainpath', help='data path for training') 27 | parser.add_argument('--testpath', help='data path for testing') 28 | parser.add_argument('--trainlist', help='data list for training') 29 | parser.add_argument('--testlist', help='data list for testing') 30 | 31 | 32 | # training config 33 | parser.add_argument('--stage_lw', type=str, default="1,1,1,1", help='loss weight for different stages') 34 | 35 | parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train') 36 | parser.add_argument('--lr_scheduler', type=str, default='MS', help='scheduler for learning rate') 37 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 38 | parser.add_argument('--lrepochs', type=str, default="1,3,5,7,9,11,13,15:1.5", help='epoch ids to downscale lr and the downscale rate') 39 | parser.add_argument('--wd', type=float, default=0.0, help='weight decay') 40 | 41 | parser.add_argument('--summary_freq', type=int, default=100, help='print and summary frequency') 42 | parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') 43 | parser.add_argument('--eval_freq', type=int, default=1, help='eval frequency') 44 | 45 | parser.add_argument('--robust_train', action='store_true',help='robust training') 46 | 47 | 48 | # testing config 49 | parser.add_argument('--split', type=str, choices=['intermediate', 'advanced'], help='intermediate|advanced for tanksandtemples') 50 | parser.add_argument('--img_mode', type=str, default='resize', choices=['resize', 'crop'], help='image resolution matching strategy for TNT dataset') 51 | parser.add_argument('--cam_mode', type=str, default='origin', choices=['origin', 'short_range'], help='camera parameter strategy for TNT dataset') 52 | 53 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') 54 | parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs') 55 | parser.add_argument('--nolog', action='store_true', help='do not log into .log file') 56 | 
parser.add_argument('--notensorboard', action='store_true', help='do not log into tensorboard') 57 | parser.add_argument('--save_conf_all_stages', action='store_true', help='save confidence maps for all stages') 58 | parser.add_argument('--outdir', default='./outputs', help='output dir') 59 | parser.add_argument('--resume', action='store_true', help='continue to train the model') 60 | 61 | 62 | # pytorch config 63 | parser.add_argument('--device', default='cuda', help='device to use') 64 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 65 | parser.add_argument('--pin_m', action='store_true', help='data loader pin memory') 66 | parser.add_argument("--local_rank", type=int, default=0) 67 | 68 | return parser.parse_args() -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Loss Functions (Sec 3.4 in the paper). 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import torch 8 | 9 | 10 | def geomvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs): 11 | 12 | stage_lw = kwargs.get("stage_lw", [1, 1, 1, 1]) 13 | depth_values = kwargs.get("depth_values") 14 | depth_min, depth_max = depth_values[:,0], depth_values[:,-1] 15 | 16 | total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) 17 | pw_loss_stages = [] 18 | dds_loss_stages = [] 19 | for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]): 20 | 21 | depth = stage_inputs['depth_filtered'] 22 | prob_volume = stage_inputs['prob_volume'] 23 | depth_value = stage_inputs['depth_hypo'] 24 | 25 | depth_gt = depth_gt_ms[stage_key] 26 | mask = mask_ms[stage_key] > 0.5 27 | 28 | 29 | # pw loss 30 | pw_loss = pixel_wise_loss(prob_volume, depth_gt, mask, depth_value) 31 | pw_loss_stages.append(pw_loss) 32 | 33 | # dds loss 34 | dds_loss = depth_distribution_similarity_loss(depth, depth_gt, mask, depth_min, depth_max) 35 | dds_loss_stages.append(dds_loss) 36 | 37 | # total loss 38 | lam1, lam2 = 0.8, 0.2 39 | total_loss = total_loss + stage_lw[stage_idx] * (lam1 * pw_loss + lam2 * dds_loss) 40 | 41 | 42 | depth_pred = stage_inputs['depth'] 43 | depth_gt = depth_gt_ms[stage_key] 44 | epe = cal_metrics(depth_pred, depth_gt, mask, depth_min, depth_max) 45 | 46 | return total_loss, epe, pw_loss_stages, dds_loss_stages 47 | 48 | 49 | 50 | def pixel_wise_loss(prob_volume, depth_gt, mask, depth_value): 51 | mask_true = mask 52 | valid_pixel_num = torch.sum(mask_true, dim=[1,2])+1e-12 53 | 54 | shape = depth_gt.shape 55 | 56 | depth_num = depth_value.shape[1] 57 | depth_value_mat = depth_value 58 | 59 | gt_index_image = torch.argmin(torch.abs(depth_value_mat-depth_gt.unsqueeze(1)), dim=1) 60 | 61 | gt_index_image = torch.mul(mask_true, gt_index_image.type(torch.float)) 62 | gt_index_image = torch.round(gt_index_image).type(torch.long).unsqueeze(1) 63 | 64 | gt_index_volume = torch.zeros(shape[0], depth_num, shape[1], shape[2]).type(mask_true.type()).scatter_(1, gt_index_image, 1) 65 | cross_entropy_image = -torch.sum(gt_index_volume * torch.log(prob_volume+1e-12), dim=1).squeeze(1) 66 | masked_cross_entropy_image = torch.mul(mask_true, cross_entropy_image) 67 | masked_cross_entropy = torch.sum(masked_cross_entropy_image, dim=[1, 2]) 68 | 69 | masked_cross_entropy = 
torch.mean(masked_cross_entropy / valid_pixel_num) 70 | 71 | pw_loss = masked_cross_entropy 72 | return pw_loss 73 | 74 | 75 | def depth_distribution_similarity_loss(depth, depth_gt, mask, depth_min, depth_max): 76 | depth_norm = depth * 128 / (depth_max - depth_min)[:,None,None] 77 | depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None] 78 | 79 | M_bins = 48 80 | kl_min = torch.min(torch.min(depth_gt), depth.mean()-3.*depth.std()) 81 | kl_max = torch.max(torch.max(depth_gt), depth.mean()+3.*depth.std()) 82 | bins = torch.linspace(kl_min, kl_max, steps=M_bins) 83 | 84 | kl_divs = [] 85 | for i in range(len(bins) - 1): 86 | bin_mask = (depth_gt >= bins[i]) & (depth_gt < bins[i+1]) 87 | merged_mask = mask & bin_mask 88 | 89 | if merged_mask.sum() > 0: 90 | p = depth_norm[merged_mask] 91 | q = depth_gt_norm[merged_mask] 92 | kl_div = torch.nn.functional.kl_div(torch.log(p)-torch.log(q), p, reduction='batchmean') 93 | kl_div = torch.log(kl_div) 94 | kl_divs.append(kl_div) 95 | 96 | dds_loss = sum(kl_divs) 97 | return dds_loss 98 | 99 | 100 | def cal_metrics(depth_pred, depth_gt, mask, depth_min, depth_max): 101 | depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None] 102 | depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None] 103 | 104 | abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask]) 105 | epe = abs_err.mean() 106 | err1= (abs_err<=1).float().mean()*100 107 | err3 = (abs_err<=3).float().mean()*100 108 | 109 | return epe # err1, err3 -------------------------------------------------------------------------------- /outputs/visual.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "- @Description: Juputer notebook for visualizing depth maps.\n", 8 | "- @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n", 9 | "- @Affiliation: Peking University (PKU)\n", 10 | "- @LastEditDate: 2023-09-07" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "ExecutionIndicator": { 18 | "show": true 19 | }, 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import sys, os\n", 25 | "sys.path.append('../')\n", 26 | "import numpy as np\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import re\n", 29 | "\n", 30 | "\n", 31 | "def read_pfm(filename):\n", 32 | " file = open(filename, 'rb')\n", 33 | " color = None\n", 34 | " width = None\n", 35 | " height = None\n", 36 | " scale = None\n", 37 | " endian = None\n", 38 | "\n", 39 | " header = file.readline().decode('utf-8').rstrip()\n", 40 | " if header == 'PF':\n", 41 | " color = True\n", 42 | " elif header == 'Pf':\n", 43 | " color = False\n", 44 | " else:\n", 45 | " raise Exception('Not a PFM file.')\n", 46 | "\n", 47 | " dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n", 48 | " if dim_match:\n", 49 | " width, height = map(int, dim_match.groups())\n", 50 | " else:\n", 51 | " raise Exception('Malformed PFM header.')\n", 52 | "\n", 53 | " scale = float(file.readline().rstrip())\n", 54 | " if scale < 0: # little-endian\n", 55 | " endian = '<'\n", 56 | " scale = -scale\n", 57 | " else:\n", 58 | " endian = '>' # big-endian\n", 59 | "\n", 60 | " data = np.fromfile(file, endian + 'f')\n", 61 | " shape = (height, width, 3) if color else (height, width)\n", 62 | "\n", 63 | " data = np.reshape(data, shape)\n", 64 | " data = np.flipud(data)\n", 65 | " file.close()\n", 66 | " return data, 
scale\n", 67 | "\n", 68 | "\n", 69 | "def read_depth(filename):\n", 70 | " depth = read_pfm(filename)[0]\n", 71 | " return np.array(depth, dtype=np.float32)\n", 72 | "\n", 73 | "\n", 74 | "assert False" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## DTU" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "ExecutionIndicator": { 89 | "show": true 90 | }, 91 | "tags": [] 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "exp_name = 'dtu/geomvsnet'\n", 96 | "depth_name = \"00000009.pfm\"\n", 97 | "\n", 98 | "scans = os.listdir(os.path.join(exp_name))\n", 99 | "scans = list(filter(lambda x: x.startswith(\"scan\"), scans))\n", 100 | "scans.sort(key=lambda x: int(x[4:]))\n", 101 | "for scan in scans:\n", 102 | " depth_filename = os.path.join(exp_name, scan, \"depth_est\", depth_name)\n", 103 | " if not os.path.exists(depth_filename): continue\n", 104 | " depth = read_depth(depth_filename)\n", 105 | "\n", 106 | " confidence_filename = os.path.join(exp_name, scan, \"confidence\", depth_name)\n", 107 | " confidence = read_depth(confidence_filename)\n", 108 | "\n", 109 | " print(scan, depth_name)\n", 110 | "\n", 111 | " plt.figure(figsize=(12, 12))\n", 112 | " plt.subplot(1, 2, 1)\n", 113 | " plt.xticks([]), plt.yticks([]), plt.axis('off')\n", 114 | " plt.imshow(depth, 'viridis', vmin=500, vmax=830)\n", 115 | "\n", 116 | " plt.subplot(1, 2, 2)\n", 117 | " plt.xticks([]), plt.yticks([]), plt.axis('off')\n", 118 | " plt.imshow(confidence, 'viridis')\n", 119 | " plt.show()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## TNT" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "exp_name = './tnt/blend/geomvsnet/'\n", 138 | "depth_name = \"00000009.pfm\"\n", 139 | "\n", 140 | "with open(\"../datasets/lists/tnt/intermediate.txt\") as f:\n", 141 | " scans_i = [line.rstrip() for line in f.readlines()]\n", 142 | "\n", 143 | "with open(\"../datasets/lists/tnt/advanced.txt\") as f:\n", 144 | " scans_a = [line.rstrip() for line in f.readlines()]\n", 145 | "\n", 146 | "scans = scans_i + scans_a\n", 147 | "\n", 148 | "for scan in scans:\n", 149 | "\n", 150 | " depth_filename = os.path.join(exp_name, scan, \"depth_est\", depth_name)\n", 151 | " if not os.path.exists(depth_filename): continue\n", 152 | " depth = read_depth(depth_filename)\n", 153 | "\n", 154 | " print(scan, depth_name, depth.shape)\n", 155 | "\n", 156 | " plt.figure(figsize=(12, 12))\n", 157 | " plt.xticks([]), plt.yticks([]), plt.axis('off')\n", 158 | " plt.imshow(depth, 'viridis', vmin=0, vmax=10)\n", 159 | "\n", 160 | " plt.show()" 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.6.12" 181 | }, 182 | "vscode": { 183 | "interpreter": { 184 | "hash": "d253918f84404206ad3cf9c22ee3709ef6e34cbea610b0ac9787033d60da5e03" 185 | } 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 4 190 | } 191 | 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Main process of network testing. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import os, time, sys, gc, cv2, logging 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.data import DataLoader 14 | 15 | from datasets.data_io import * 16 | from datasets.dtu import DTUDataset 17 | from datasets.tnt import TNTDataset 18 | 19 | from models.geomvsnet import GeoMVSNet 20 | from models.utils import * 21 | from models.utils.opts import get_opts 22 | 23 | 24 | cudnn.benchmark = True 25 | 26 | args = get_opts() 27 | 28 | 29 | def test(): 30 | total_time = 0 31 | with torch.no_grad(): 32 | for batch_idx, sample in enumerate(TestImgLoader): 33 | sample_cuda = tocuda(sample) 34 | start_time = time.time() 35 | 36 | # @Note GeoMVSNet main 37 | outputs = model( 38 | sample_cuda["imgs"], 39 | sample_cuda["proj_matrices"], sample_cuda["intrinsics_matrices"], 40 | sample_cuda["depth_values"], 41 | sample["filename"] 42 | ) 43 | 44 | end_time = time.time() 45 | total_time += end_time - start_time 46 | outputs = tensor2numpy(outputs) 47 | del sample_cuda 48 | 49 | filenames = sample["filename"] 50 | cams = sample["proj_matrices"]["stage{}".format(args.levels)].numpy() 51 | imgs = sample["imgs"] 52 | logger.info('Iter {}/{}, Time:{:.3f} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape)) 53 | 54 | 55 | for filename, cam, img, depth_est, photometric_confidence in zip(filenames, cams, imgs, outputs["depth"], outputs["photometric_confidence"]): 56 | img = img[0].numpy() # ref view 57 | cam = cam[0] # ref cam 58 | 59 | depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm')) 60 | confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm')) 61 | cam_filename = os.path.join(args.outdir, filename.format('cams', '_cam.txt')) 62 | img_filename = os.path.join(args.outdir, filename.format('images', '.jpg')) 63 | os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) 64 | os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True) 65 | if args.which_dataset == 'dtu': 66 | os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True) 67 | os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True) 68 | 69 | # save depth maps 70 | save_pfm(depth_filename, depth_est) 71 | 72 | # save confidence maps 73 | confidence_list = [outputs['stage{}'.format(i)]['photometric_confidence'].squeeze(0) for i in range(1,5)] 74 | photometric_confidence = confidence_list[-1] 75 | if not args.save_conf_all_stages: 76 | save_pfm(confidence_filename, photometric_confidence) 77 | else: 78 | for stage_idx, photometric_confidence in enumerate(confidence_list): 79 | if stage_idx != args.levels - 1: 80 | confidence_filename = os.path.join(args.outdir, filename.format('confidence', "_stage"+str(stage_idx)+'.pfm')) 81 | else: 82 | confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm')) 83 | save_pfm(confidence_filename, photometric_confidence) 84 | 85 | # save cams, img 86 | if args.which_dataset == 'dtu': 87 | write_cam(cam_filename, cam) 88 | img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8) 89 | 
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 90 | cv2.imwrite(img_filename, img_bgr) 91 | 92 | torch.cuda.empty_cache() 93 | gc.collect() 94 | return total_time, len(TestImgLoader) 95 | 96 | 97 | def initLogger(): 98 | logger = logging.getLogger() 99 | logger.setLevel(logging.INFO) 100 | curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time())) 101 | 102 | if args.which_dataset == 'tnt': 103 | logfile = os.path.join(args.logdir, 'TNT-test-' + curTime + '.log') 104 | else: 105 | logfile = os.path.join(args.logdir, 'test-' + curTime + '.log') 106 | 107 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") 108 | if not args.nolog: 109 | fileHandler = logging.FileHandler(logfile, mode='a') 110 | fileHandler.setFormatter(formatter) 111 | logger.addHandler(fileHandler) 112 | consoleHandler = logging.StreamHandler(sys.stdout) 113 | consoleHandler.setFormatter(formatter) 114 | logger.addHandler(consoleHandler) 115 | logger.info("Logger initialized.") 116 | logger.info("Writing logs to file: {}".format(logfile)) 117 | logger.info("Current time: {}".format(curTime)) 118 | 119 | settings_str = "All settings:\n" 120 | for k,v in vars(args).items(): 121 | settings_str += '{0}: {1}\n'.format(k,v) 122 | logger.info(settings_str) 123 | 124 | return logger 125 | 126 | 127 | if __name__ == '__main__': 128 | logger = initLogger() 129 | 130 | # dataset, dataloader 131 | if args.which_dataset == 'dtu': 132 | test_dataset = DTUDataset(args.testpath, args.testlist, "test", args.n_views, max_wh=(1600, 1200)) 133 | elif args.which_dataset == 'tnt': 134 | test_dataset = TNTDataset(args.testpath, args.testlist, split=args.split, n_views=args.n_views, img_wh=(-1, 1024), cam_mode=args.cam_mode, img_mode=args.img_mode) 135 | 136 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) 137 | 138 | # @Note GeoMVSNet model 139 | model = GeoMVSNet( 140 | levels=args.levels, 141 | hypo_plane_num_stages=[int(n) for n in args.hypo_plane_num_stages.split(",")], 142 | depth_interal_ratio_stages=[float(ir) for ir in args.depth_interal_ratio_stages.split(",")], 143 | feat_base_channel=args.feat_base_channel, 144 | reg_base_channel=args.reg_base_channel, 145 | group_cor_dim_stages=[int(n) for n in args.group_cor_dim_stages.split(",")], 146 | ) 147 | 148 | logger.info("loading model {}".format(args.loadckpt)) 149 | state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) 150 | model.load_state_dict(state_dict['model'], strict=False) 151 | 152 | model.cuda() 153 | model.eval() 154 | 155 | test() -------------------------------------------------------------------------------- /datasets/tnt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Data preprocessing and organization for Tanks and Temples dataset. 
3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import os 8 | import cv2 9 | import numpy as np 10 | from PIL import Image 11 | 12 | from torch.utils.data import Dataset 13 | 14 | from datasets.data_io import * 15 | 16 | 17 | class TNTDataset(Dataset): 18 | def __init__(self, root_dir, list_file, split, n_views, **kwargs): 19 | super(TNTDataset, self).__init__() 20 | 21 | self.root_dir = root_dir 22 | self.list_file = list_file 23 | self.split = split 24 | self.n_views = n_views 25 | 26 | self.cam_mode = kwargs.get("cam_mode", "origin") # origin / short_range 27 | if self.cam_mode == 'short_range': assert self.split == "intermediate" 28 | self.img_mode = kwargs.get("img_mode", "resize") # resize / crop 29 | 30 | self.total_depths = 192 31 | self.depth_interval_table = { 32 | # intermediate 33 | 'Family': 2.5e-3, 'Francis': 1e-2, 'Horse': 1.5e-3, 'Lighthouse': 1.5e-2, 'M60': 5e-3, 'Panther': 5e-3, 'Playground': 7e-3, 'Train': 5e-3, 34 | # advanced 35 | 'Auditorium': 3e-2, 'Ballroom': 2e-2, 'Courtroom': 2e-2, 'Museum': 2e-2, 'Palace': 1e-2, 'Temple': 1e-2 36 | } 37 | self.img_wh = kwargs.get("img_wh", (-1, 1024)) 38 | 39 | self.metas = self.build_metas() 40 | 41 | 42 | def build_metas(self): 43 | metas = [] 44 | 45 | with open(os.path.join(self.list_file)) as f: 46 | scans = [line.rstrip() for line in f.readlines()] 47 | 48 | for scan in scans: 49 | with open(os.path.join(self.root_dir, self.split, scan, 'pair.txt')) as f: 50 | num_viewpoint = int(f.readline()) 51 | for view_idx in range(num_viewpoint): 52 | ref_view = int(f.readline().rstrip()) 53 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 54 | if len(src_views) != 0: 55 | metas += [(scan, -1, ref_view, src_views)] 56 | return metas 57 | 58 | 59 | def read_cam_file(self, filename): 60 | with open(filename) as f: 61 | lines = [line.rstrip() for line in f.readlines()] 62 | # extrinsics: line [1,5), 4x4 matrix 63 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') 64 | extrinsics = extrinsics.reshape((4, 4)) 65 | # intrinsics: line [7-10), 3x3 matrix 66 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') 67 | intrinsics = intrinsics.reshape((3, 3)) 68 | 69 | depth_min = float(lines[11].split()[0]) 70 | depth_max = float(lines[11].split()[-1]) 71 | 72 | return intrinsics, extrinsics, depth_min, depth_max 73 | 74 | 75 | def read_img(self, filename): 76 | img = Image.open(filename) 77 | np_img = np.array(img, dtype=np.float32) / 255. 
78 | return np_img 79 | 80 | 81 | def scale_tnt_input(self, intrinsics, img): 82 | if self.img_mode == "crop": 83 | intrinsics[1,2] = intrinsics[1,2] - 28 # 1080 -> 1024 84 | img = img[28:1080-28, :, :] 85 | elif self.img_mode == "resize": 86 | height, width = img.shape[:2] 87 | 88 | max_w, max_h = self.img_wh[0], self.img_wh[1] 89 | if max_w == -1: 90 | max_w = width 91 | 92 | img = cv2.resize(img, (max_w, max_h)) 93 | 94 | scale_w = 1.0 * max_w / width 95 | intrinsics[0, :] *= scale_w 96 | scale_h = 1.0 * max_h / height 97 | intrinsics[1, :] *= scale_h 98 | 99 | return intrinsics, img 100 | 101 | 102 | def __len__(self): 103 | return len(self.metas) 104 | 105 | 106 | def __getitem__(self, idx): 107 | scan, _, ref_view, src_views = self.metas[idx] 108 | view_ids = [ref_view] + src_views[:self.n_views-1] 109 | 110 | imgs = [] 111 | depth_min = None 112 | depth_max = None 113 | 114 | proj_matrices_0 = [] 115 | proj_matrices_1 = [] 116 | proj_matrices_2 = [] 117 | proj_matrices_3 = [] 118 | 119 | for i, vid in enumerate(view_ids): 120 | img_filename = os.path.join(self.root_dir, self.split, scan, f'images/{vid:08d}.jpg') 121 | if self.cam_mode == 'short_range': 122 | # can only use for Intermediate 123 | proj_mat_filename = os.path.join(self.root_dir, self.split, scan, f'cams_{scan.lower()}/{vid:08d}_cam.txt') 124 | elif self.cam_mode == 'origin': 125 | proj_mat_filename = os.path.join(self.root_dir, self.split, scan, f'cams/{vid:08d}_cam.txt') 126 | 127 | img = self.read_img(img_filename) 128 | 129 | intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) 130 | intrinsics, img = self.scale_tnt_input(intrinsics, img) 131 | imgs.append(img.transpose(2,0,1)) 132 | 133 | proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 134 | proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 135 | proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 136 | proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 137 | 138 | intrinsics[:2,:] *= 0.125 139 | proj_mat_0[0,:4,:4] = extrinsics.copy() 140 | proj_mat_0[1,:3,:3] = intrinsics.copy() 141 | int_mat_0 = intrinsics.copy() 142 | 143 | intrinsics[:2,:] *= 2 144 | proj_mat_1[0,:4,:4] = extrinsics.copy() 145 | proj_mat_1[1,:3,:3] = intrinsics.copy() 146 | int_mat_1 = intrinsics.copy() 147 | 148 | intrinsics[:2,:] *= 2 149 | proj_mat_2[0,:4,:4] = extrinsics.copy() 150 | proj_mat_2[1,:3,:3] = intrinsics.copy() 151 | int_mat_2 = intrinsics.copy() 152 | 153 | intrinsics[:2,:] *= 2 154 | proj_mat_3[0,:4,:4] = extrinsics.copy() 155 | proj_mat_3[1,:3,:3] = intrinsics.copy() 156 | int_mat_3 = intrinsics.copy() 157 | 158 | proj_matrices_0.append(proj_mat_0) 159 | proj_matrices_1.append(proj_mat_1) 160 | proj_matrices_2.append(proj_mat_2) 161 | proj_matrices_3.append(proj_mat_3) 162 | 163 | # reference view 164 | if i == 0: 165 | depth_min = depth_min_ 166 | if self.cam_mode == 'short_range': 167 | depth_max = depth_min + self.total_depths * self.depth_interval_table[scan] 168 | elif self.cam_mode == 'origin': 169 | depth_max = depth_max_ 170 | 171 | proj={} 172 | proj['stage1'] = np.stack(proj_matrices_0) 173 | proj['stage2'] = np.stack(proj_matrices_1) 174 | proj['stage3'] = np.stack(proj_matrices_2) 175 | proj['stage4'] = np.stack(proj_matrices_3) 176 | 177 | intrinsics_matrices = { 178 | "stage1": int_mat_0, 179 | "stage2": int_mat_1, 180 | "stage3": int_mat_2, 181 | "stage4": int_mat_3 182 | } 183 | 184 | sample = { 185 | "imgs": imgs, 186 | "proj_matrices": proj, 187 | "intrinsics_matrices": intrinsics_matrices, 188 | 
"depth_values": np.array([depth_min, depth_max], dtype=np.float32), 189 | "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" 190 | } 191 | 192 | return sample -------------------------------------------------------------------------------- /models/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Some useful utils. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import random 8 | import numpy as np 9 | 10 | import torch 11 | import torchvision.utils as vutils 12 | 13 | 14 | # torch.no_grad warpper for functions 15 | def make_nograd_func(func): 16 | def wrapper(*f_args, **f_kwargs): 17 | with torch.no_grad(): 18 | ret = func(*f_args, **f_kwargs) 19 | return ret 20 | 21 | return wrapper 22 | 23 | 24 | # convert a function into recursive style to handle nested dict/list/tuple variables 25 | def make_recursive_func(func): 26 | def wrapper(vars): 27 | if isinstance(vars, list): 28 | return [wrapper(x) for x in vars] 29 | elif isinstance(vars, tuple): 30 | return tuple([wrapper(x) for x in vars]) 31 | elif isinstance(vars, dict): 32 | return {k: wrapper(v) for k, v in vars.items()} 33 | else: 34 | return func(vars) 35 | 36 | return wrapper 37 | 38 | 39 | @make_recursive_func 40 | def tensor2float(vars): 41 | if isinstance(vars, float): 42 | return vars 43 | elif isinstance(vars, torch.Tensor): 44 | return vars.data.item() 45 | else: 46 | raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) 47 | 48 | 49 | @make_recursive_func 50 | def tensor2numpy(vars): 51 | if isinstance(vars, np.ndarray): 52 | return vars 53 | elif isinstance(vars, torch.Tensor): 54 | return vars.detach().cpu().numpy().copy() 55 | else: 56 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 57 | 58 | 59 | @make_recursive_func 60 | def tocuda(vars): 61 | if isinstance(vars, torch.Tensor): 62 | return vars.to(torch.device("cuda")) 63 | elif isinstance(vars, str): 64 | return vars 65 | else: 66 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 67 | 68 | 69 | def tb_save_scalars(logger, mode, scalar_dict, global_step): 70 | scalar_dict = tensor2float(scalar_dict) 71 | for key, value in scalar_dict.items(): 72 | if not isinstance(value, (list, tuple)): 73 | name = '{}/{}'.format(mode, key) 74 | logger.add_scalar(name, value, global_step) 75 | else: 76 | for idx in range(len(value)): 77 | name = '{}/{}_{}'.format(mode, key, idx) 78 | logger.add_scalar(name, value[idx], global_step) 79 | 80 | 81 | def tb_save_images(logger, mode, images_dict, global_step): 82 | images_dict = tensor2numpy(images_dict) 83 | 84 | def preprocess(name, img): 85 | if not (len(img.shape) == 3 or len(img.shape) == 4): 86 | raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) 87 | if len(img.shape) == 3: 88 | img = img[:, np.newaxis, :, :] 89 | img = torch.from_numpy(img[:1]) 90 | return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) 91 | 92 | for key, value in images_dict.items(): 93 | if not isinstance(value, (list, tuple)): 94 | name = '{}/{}'.format(mode, key) 95 | logger.add_image(name, preprocess(name, value), global_step) 96 | else: 97 | for idx in range(len(value)): 98 | name = '{}/{}_{}'.format(mode, key, idx) 99 | logger.add_image(name, preprocess(name, value[idx]), global_step) 100 | 
101 | 102 | class DictAverageMeter(object): 103 | def __init__(self): 104 | self.data = {} 105 | self.count = 0 106 | 107 | def update(self, new_input): 108 | self.count += 1 109 | if len(self.data) == 0: 110 | for k, v in new_input.items(): 111 | if not isinstance(v, float): 112 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 113 | self.data[k] = v 114 | else: 115 | for k, v in new_input.items(): 116 | if not isinstance(v, float): 117 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 118 | self.data[k] += v 119 | 120 | def mean(self): 121 | return {k: v / self.count for k, v in self.data.items()} 122 | 123 | 124 | # a wrapper to compute metrics for each image individually 125 | def compute_metrics_for_each_image(metric_func): 126 | def wrapper(depth_est, depth_gt, mask, *args): 127 | batch_size = depth_gt.shape[0] 128 | results = [] 129 | # compute result one by one 130 | for idx in range(batch_size): 131 | ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) 132 | results.append(ret) 133 | return torch.stack(results).mean() 134 | 135 | return wrapper 136 | 137 | 138 | @make_nograd_func 139 | @compute_metrics_for_each_image 140 | def Thres_metrics(depth_est, depth_gt, mask, thres): 141 | assert isinstance(thres, (int, float)) 142 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 143 | errors = torch.abs(depth_est - depth_gt) 144 | err_mask = errors > thres 145 | return torch.mean(err_mask.float()) 146 | 147 | 148 | # NOTE: please do not use this to build up training loss 149 | @make_nograd_func 150 | @compute_metrics_for_each_image 151 | def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None): 152 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 153 | error = (depth_est - depth_gt).abs() 154 | if thres is not None: 155 | error = error[(error >= float(thres[0])) & (error <= float(thres[1]))] 156 | if error.shape[0] == 0: 157 | return torch.tensor(0, device=error.device, dtype=error.dtype) 158 | return torch.mean(error) 159 | 160 | 161 | import torch.distributed as dist 162 | def synchronize(): 163 | """ 164 | Helper function to synchronize (barrier) among all processes when 165 | using distributed training 166 | """ 167 | if not dist.is_available(): 168 | return 169 | if not dist.is_initialized(): 170 | return 171 | world_size = dist.get_world_size() 172 | if world_size == 1: 173 | return 174 | dist.barrier() 175 | 176 | 177 | def get_world_size(): 178 | if not dist.is_available(): 179 | return 1 180 | if not dist.is_initialized(): 181 | return 1 182 | return dist.get_world_size() 183 | 184 | 185 | def reduce_scalar_outputs(scalar_outputs): 186 | world_size = get_world_size() 187 | if world_size < 2: 188 | return scalar_outputs 189 | with torch.no_grad(): 190 | names = [] 191 | scalars = [] 192 | for k in sorted(scalar_outputs.keys()): 193 | names.append(k) 194 | scalars.append(scalar_outputs[k]) 195 | scalars = torch.stack(scalars, dim=0) 196 | dist.reduce(scalars, dst=0) 197 | if dist.get_rank() == 0: 198 | # only main process gets accumulated, so only divide by 199 | # world_size in this case 200 | scalars /= world_size 201 | reduced_scalars = {k: v for k, v in zip(names, scalars)} 202 | 203 | return reduced_scalars 204 | 205 | 206 | import torch 207 | from bisect import bisect_right 208 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 209 | def __init__( 210 | self, 211 | optimizer, 212 | milestones, 213 | gamma=0.1, 214 | warmup_factor=1.0 / 3, 215 | warmup_iters=500, 216 | 
warmup_method="linear", 217 | last_epoch=-1, 218 | ): 219 | if not list(milestones) == sorted(milestones): 220 | raise ValueError( 221 | "Milestones should be a list of" " increasing integers. Got {}", 222 | milestones, 223 | ) 224 | 225 | if warmup_method not in ("constant", "linear"): 226 | raise ValueError( 227 | "Only 'constant' or 'linear' warmup_method accepted" 228 | "got {}".format(warmup_method) 229 | ) 230 | self.milestones = milestones 231 | self.gamma = gamma 232 | self.warmup_factor = warmup_factor 233 | self.warmup_iters = warmup_iters 234 | self.warmup_method = warmup_method 235 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 236 | 237 | def get_lr(self): 238 | warmup_factor = 1 239 | if self.last_epoch < self.warmup_iters: 240 | if self.warmup_method == "constant": 241 | warmup_factor = self.warmup_factor 242 | elif self.warmup_method == "linear": 243 | alpha = float(self.last_epoch) / self.warmup_iters 244 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 245 | return [ 246 | base_lr 247 | * warmup_factor 248 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 249 | for base_lr in self.base_lrs 250 | ] 251 | 252 | 253 | def set_random_seed(seed): 254 | random.seed(seed) 255 | np.random.seed(seed) 256 | torch.manual_seed(seed) 257 | torch.cuda.manual_seed_all(seed) -------------------------------------------------------------------------------- /datasets/blendedmvs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Data preprocessing and organization for BlendedMVS dataset. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import os 8 | import cv2 9 | import random 10 | import numpy as np 11 | from PIL import Image 12 | 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms as T 15 | 16 | from datasets.data_io import * 17 | 18 | 19 | def motion_blur(img: np.ndarray, max_kernel_size=3): 20 | # Either vertial, hozirontal or diagonal blur 21 | mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) 22 | ksize = np.random.randint(0, (max_kernel_size + 1) / 2) * 2 + 1 # make sure is odd 23 | center = int((ksize - 1) / 2) 24 | kernel = np.zeros((ksize, ksize)) 25 | if mode == 'h': 26 | kernel[center, :] = 1. 27 | elif mode == 'v': 28 | kernel[:, center] = 1. 29 | elif mode == 'diag_down': 30 | kernel = np.eye(ksize) 31 | elif mode == 'diag_up': 32 | kernel = np.flip(np.eye(ksize), 0) 33 | var = ksize * ksize / 16. 34 | grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) 35 | gaussian = np.exp(-(np.square(grid - center) + np.square(grid.T - center)) / (2. * var)) 36 | kernel *= gaussian 37 | kernel /= np.sum(kernel) 38 | img = cv2.filter2D(img, -1, kernel) 39 | return img 40 | 41 | 42 | class BlendedMVSDataset(Dataset): 43 | def __init__(self, root_dir, list_file, split, n_views, **kwargs): 44 | super(BlendedMVSDataset, self).__init__() 45 | 46 | self.levels = 4 47 | self.root_dir = root_dir 48 | self.list_file = list_file 49 | self.split = split 50 | self.n_views = n_views 51 | 52 | assert self.split in ['train', 'val', 'all'] 53 | 54 | self.scale_factors = {} 55 | self.scale_factor = 0 56 | 57 | self.img_wh = kwargs.get("img_wh", (768, 576)) 58 | assert self.img_wh[0]%32==0 and self.img_wh[1]%32==0, \ 59 | 'img_wh must both be multiples of 2^5!' 
60 | 61 | self.robust_train = kwargs.get("robust_train", True) 62 | self.augment = kwargs.get("augment", True) 63 | if self.augment: 64 | self.color_augment = T.ColorJitter(brightness=0.25, contrast=(0.3, 1.5)) 65 | 66 | self.metas = self.build_metas() 67 | 68 | 69 | def build_metas(self): 70 | metas = [] 71 | with open(self.list_file) as f: 72 | self.scans = [line.rstrip() for line in f.readlines()] 73 | for scan in self.scans: 74 | with open(os.path.join(self.root_dir, scan, "cams/pair.txt")) as f: 75 | num_viewpoint = int(f.readline()) 76 | for _ in range(num_viewpoint): 77 | ref_view = int(f.readline().rstrip()) 78 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 79 | if len(src_views) >= self.n_views-1: 80 | metas += [(scan, ref_view, src_views)] 81 | return metas 82 | 83 | 84 | def read_cam_file(self, scan, filename): 85 | with open(filename) as f: 86 | lines = f.readlines() 87 | lines = [line.rstrip() for line in lines] 88 | # extrinsics: line [1,5), 4x4 matrix 89 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 90 | # intrinsics: line [7-10), 3x3 matrix 91 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 92 | depth_min = float(lines[11].split()[0]) 93 | depth_max = float(lines[11].split()[-1]) 94 | 95 | if scan not in self.scale_factors: 96 | self.scale_factors[scan] = 100.0 / depth_min 97 | depth_min *= self.scale_factors[scan] 98 | depth_max *= self.scale_factors[scan] 99 | extrinsics[:3, 3] *= self.scale_factors[scan] 100 | 101 | return intrinsics, extrinsics, depth_min, depth_max 102 | 103 | 104 | def read_depth_mask(self, scan, filename, depth_min, depth_max, scale): 105 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) 106 | depth = (depth * self.scale_factors[scan]) * scale 107 | 108 | mask = (depth>=depth_min) & (depth<=depth_max) 109 | assert mask.sum() > 0 110 | mask = mask.astype(np.float32) 111 | if self.img_wh is not None: 112 | depth = cv2.resize(depth, self.img_wh, interpolation=cv2.INTER_NEAREST) 113 | h, w = depth.shape 114 | depth_ms = {} 115 | mask_ms = {} 116 | 117 | for i in range(self.levels): 118 | depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) 119 | mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) 120 | 121 | depth_ms[f"stage{self.levels-i}"] = depth_cur 122 | mask_ms[f"stage{self.levels-i}"] = mask_cur 123 | 124 | return depth_ms, mask_ms 125 | 126 | 127 | def read_img(self, filename): 128 | img = Image.open(filename) 129 | 130 | if self.augment: 131 | img = self.color_augment(img) 132 | img = motion_blur(np.array(img, dtype=np.float32)) 133 | 134 | np_img = np.array(img, dtype=np.float32) / 255. 
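# Added note: when self.augment is enabled, the photometric jitter is applied to
# the PIL image and the random motion blur to the 0-255 float array, so the
# value returned below is always an H x W x 3 float array scaled to [0, 1].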
135 | return np_img 136 | 137 | 138 | def __len__(self): 139 | return len(self.metas) 140 | 141 | 142 | def __getitem__(self, idx): 143 | meta = self.metas[idx] 144 | scan, ref_view, src_views = meta 145 | 146 | if self.robust_train: 147 | num_src_views = len(src_views) 148 | index = random.sample(range(num_src_views), self.n_views - 1) 149 | view_ids = [ref_view] + [src_views[i] for i in index] 150 | scale_ratio = random.uniform(0.8, 1.25) 151 | else: 152 | view_ids = [ref_view] + src_views[:self.n_views - 1] 153 | scale_ratio = 1 154 | 155 | imgs = [] 156 | mask = None 157 | depth = None 158 | depth_min = None 159 | depth_max = None 160 | 161 | proj={} 162 | proj_matrices_0 = [] 163 | proj_matrices_1 = [] 164 | proj_matrices_2 = [] 165 | proj_matrices_3 = [] 166 | 167 | for i, vid in enumerate(view_ids): 168 | img_filename = os.path.join(self.root_dir, '{}/blended_images/{:0>8}.jpg'.format(scan, vid)) 169 | depth_filename = os.path.join(self.root_dir, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid)) 170 | proj_mat_filename = os.path.join(self.root_dir, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) 171 | 172 | img = self.read_img(img_filename) 173 | imgs.append(img.transpose(2,0,1)) 174 | 175 | intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename) 176 | 177 | proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 178 | proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 179 | proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 180 | proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32) 181 | extrinsics[:3, 3] *= scale_ratio 182 | intrinsics[:2,:] *= 0.125 183 | proj_mat_0[0,:4,:4] = extrinsics.copy() 184 | proj_mat_0[1,:3,:3] = intrinsics.copy() 185 | int_mat_0 = intrinsics.copy() 186 | 187 | intrinsics[:2,:] *= 2 188 | proj_mat_1[0,:4,:4] = extrinsics.copy() 189 | proj_mat_1[1,:3,:3] = intrinsics.copy() 190 | int_mat_1 = intrinsics.copy() 191 | 192 | intrinsics[:2,:] *= 2 193 | proj_mat_2[0,:4,:4] = extrinsics.copy() 194 | proj_mat_2[1,:3,:3] = intrinsics.copy() 195 | int_mat_2 = intrinsics.copy() 196 | 197 | intrinsics[:2,:] *= 2 198 | proj_mat_3[0,:4,:4] = extrinsics.copy() 199 | proj_mat_3[1,:3,:3] = intrinsics.copy() 200 | int_mat_3 = intrinsics.copy() 201 | 202 | proj_matrices_0.append(proj_mat_0) 203 | proj_matrices_1.append(proj_mat_1) 204 | proj_matrices_2.append(proj_mat_2) 205 | proj_matrices_3.append(proj_mat_3) 206 | 207 | # reference view 208 | if i == 0: 209 | depth_min = depth_min_ * scale_ratio 210 | depth_max = depth_max_ * scale_ratio 211 | depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale_ratio) 212 | for l in range(self.levels): 213 | mask[f'stage{l+1}'] = mask[f'stage{l+1}'] 214 | depth[f'stage{l+1}'] = depth[f'stage{l+1}'] 215 | 216 | proj['stage1'] = np.stack(proj_matrices_0) 217 | proj['stage2'] = np.stack(proj_matrices_1) 218 | proj['stage3'] = np.stack(proj_matrices_2) 219 | proj['stage4'] = np.stack(proj_matrices_3) 220 | 221 | intrinsics_matrices = { 222 | "stage1": int_mat_0, 223 | "stage2": int_mat_1, 224 | "stage3": int_mat_2, 225 | "stage4": int_mat_3 226 | } 227 | 228 | sample = { 229 | "imgs": imgs, 230 | "proj_matrices": proj, 231 | "intrinsics_matrices": intrinsics_matrices, 232 | "depth": depth, 233 | "depth_values": np.array([depth_min, depth_max], dtype=np.float32), 234 | "mask": mask 235 | } 236 | 237 | return sample -------------------------------------------------------------------------------- /models/geomvsnet.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Main network architecture for GeoMVSNet. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from models.submodules import homo_warping, init_inverse_range, schedule_inverse_range, FPN, Reg2d 13 | from models.geometry import GeoFeatureFusion, GeoRegNet2d 14 | from models.filter import frequency_domain_filter 15 | 16 | 17 | class GeoMVSNet(nn.Module): 18 | def __init__(self, levels, hypo_plane_num_stages, depth_interal_ratio_stages, 19 | feat_base_channel, reg_base_channel, group_cor_dim_stages): 20 | super(GeoMVSNet, self).__init__() 21 | 22 | self.levels = levels 23 | self.hypo_plane_num_stages = hypo_plane_num_stages 24 | self.depth_interal_ratio_stages = depth_interal_ratio_stages 25 | 26 | self.StageNet = StageNet() 27 | 28 | # feature settings 29 | self.FeatureNet = FPN(base_channels=feat_base_channel) 30 | self.coarest_separate_flag = True 31 | if self.coarest_separate_flag: 32 | self.CoarestFeatureNet = FPN(base_channels=feat_base_channel) 33 | self.GeoFeatureFusionNet = GeoFeatureFusion( 34 | convolutional_layer_encoding="z", mask_type="basic", add_origin_feat_flag=True) 35 | 36 | # cost regularization settings 37 | self.RegNet_stages = nn.ModuleList() 38 | self.group_cor_dim_stages = group_cor_dim_stages 39 | self.geo_reg_flag = True 40 | self.geo_reg_encodings = ['std', 'z', 'z', 'z'] # must use std in idx-0 41 | for stage_idx in range(self.levels): 42 | in_dim = group_cor_dim_stages[stage_idx] 43 | if self.geo_reg_flag: 44 | self.RegNet_stages.append(GeoRegNet2d(input_channel=in_dim, base_channel=reg_base_channel, convolutional_layer_encoding=self.geo_reg_encodings[stage_idx])) 45 | else: 46 | self.RegNet_stages.append(Reg2d(input_channel=in_dim, base_channel=reg_base_channel)) 47 | 48 | # frequency domain filter settings 49 | self.curriculum_learning_rho_ratios = [9, 4, 2, 1] 50 | 51 | 52 | def forward(self, imgs, proj_matrices, intrinsics_matrices, depth_values, filename=None): 53 | 54 | features = [] 55 | if self.coarest_separate_flag: 56 | coarsest_features = [] 57 | for nview_idx in range(len(imgs)): 58 | img = imgs[nview_idx] 59 | features.append(self.FeatureNet(img)) # B C H W 60 | if self.coarest_separate_flag: 61 | coarsest_features.append(self.CoarestFeatureNet(img)) 62 | 63 | # coarse-to-fine 64 | outputs = {} 65 | for stage_idx in range(self.levels): 66 | stage_name = "stage{}".format(stage_idx + 1) 67 | B, C, H, W = features[0][stage_name].shape 68 | proj_matrices_stage = proj_matrices[stage_name] 69 | intrinsics_matrices_stage = intrinsics_matrices[stage_name] 70 | 71 | # @Note features 72 | if stage_idx == 0: 73 | if self.coarest_separate_flag: 74 | features_stage = [feat[stage_name] for feat in coarsest_features] 75 | else: 76 | features_stage = [feat[stage_name] for feat in features] 77 | elif stage_idx >= 1: 78 | features_stage = [feat[stage_name] for feat in features] 79 | 80 | ref_img_stage = F.interpolate(imgs[0], size=None, scale_factor=1./2**(3-stage_idx), mode="bilinear", align_corners=False) 81 | depth_last = F.interpolate(depth_last.unsqueeze(1), size=None, scale_factor=2, mode="bilinear", align_corners=False) 82 | confidence_last = F.interpolate(confidence_last.unsqueeze(1), size=None, scale_factor=2, mode="bilinear", align_corners=False) 83 | 
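# Added note: for every stage after the coarsest one, the filtered depth and the
# photometric confidence estimated at the previous (half-resolution) stage are
# upsampled x2 above and passed, together with the reference image, to
# GeoFeatureFusionNet below, so the current-stage reference feature is
# conditioned on the geometry recovered so far.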
84 | # reference feature 85 | features_stage[0] = self.GeoFeatureFusionNet( 86 | ref_img_stage, depth_last, confidence_last, depth_values, 87 | stage_idx, features_stage[0], intrinsics_matrices_stage 88 | ) 89 | 90 | 91 | # @Note depth hypos 92 | if stage_idx == 0: 93 | depth_hypo = init_inverse_range(depth_values, self.hypo_plane_num_stages[stage_idx], img[0].device, img[0].dtype, H, W) 94 | else: 95 | inverse_min_depth, inverse_max_depth = outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach() 96 | depth_hypo = schedule_inverse_range(inverse_min_depth, inverse_max_depth, self.hypo_plane_num_stages[stage_idx], H, W) # B D H W 97 | 98 | 99 | # @Note cost regularization 100 | geo_reg_data = {} 101 | if self.geo_reg_flag: 102 | geo_reg_data['depth_values'] = depth_values 103 | if stage_idx >= 1 and self.geo_reg_encodings[stage_idx] == 'z': 104 | prob_volume_last = F.interpolate(prob_volume_last, size=None, scale_factor=2, mode="bilinear", align_corners=False) 105 | geo_reg_data["prob_volume_last"] = prob_volume_last 106 | 107 | outputs_stage = self.StageNet( 108 | stage_idx, features_stage, proj_matrices_stage, depth_hypo=depth_hypo, 109 | regnet=self.RegNet_stages[stage_idx], group_cor_dim=self.group_cor_dim_stages[stage_idx], 110 | depth_interal_ratio=self.depth_interal_ratio_stages[stage_idx], 111 | geo_reg_data=geo_reg_data 112 | ) 113 | 114 | 115 | # @Note frequency domain filter 116 | depth_est = outputs_stage['depth'] 117 | depth_est_filtered = frequency_domain_filter(depth_est, rho_ratio=self.curriculum_learning_rho_ratios[stage_idx]) 118 | outputs_stage['depth_filtered'] = depth_est_filtered 119 | depth_last = depth_est_filtered 120 | 121 | 122 | confidence_last = outputs_stage['photometric_confidence'] 123 | prob_volume_last = outputs_stage['prob_volume'] 124 | 125 | outputs[stage_name] = outputs_stage 126 | outputs.update(outputs_stage) 127 | 128 | return outputs 129 | 130 | 131 | class StageNet(nn.Module): 132 | def __init__(self, attn_temp=2): 133 | super(StageNet, self).__init__() 134 | self.attn_temp = attn_temp 135 | 136 | def forward(self, stage_idx, features, proj_matrices, depth_hypo, regnet, 137 | group_cor_dim, depth_interal_ratio, geo_reg_data=None): 138 | 139 | # @Note step1: feature extraction 140 | proj_matrices = torch.unbind(proj_matrices, 1) 141 | ref_feature, src_features = features[0], features[1:] 142 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 143 | B, D, H, W = depth_hypo.shape 144 | C = ref_feature.shape[1] 145 | 146 | 147 | # @Note step2: cost aggregation 148 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1) 149 | cor_weight_sum = 1e-8 150 | cor_feats = 0 151 | for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)): 152 | save_fn = None 153 | src_proj_new = src_proj[:, 0].clone() 154 | src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) 155 | ref_proj_new = ref_proj[:, 0].clone() 156 | ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) 157 | warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo) # B C D H W 158 | 159 | warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W) 160 | ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W) 161 | cor_feat = (warped_src * ref_volume).mean(2) # B G D H W 162 | del warped_src, src_proj, src_fea 163 | 164 | cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C) # B D H W 165 | 
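# Added note: cor_weight acts as a per-pixel, per-depth attention for this source
# view (a softmax over the D hypotheses of the summed group correlation, scaled
# by attn_temp and sqrt(C)); the running sums below realise the weighted average
#   cost_volume = sum_i(w_i * corr_i) / sum_i(w_i)
# so source views that match the reference well contribute more to the volume.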
cor_weight_sum += cor_weight # B D H W 166 | cor_feats += cor_weight.unsqueeze(1) * cor_feat # B C D H W 167 | del cor_weight, cor_feat 168 | 169 | cost_volume = cor_feats / cor_weight_sum.unsqueeze(1) # B C D H W 170 | del cor_weight_sum, src_features 171 | 172 | 173 | # @Note step3: cost regularization 174 | if geo_reg_data == {}: 175 | # basic 176 | cost_reg = regnet(cost_volume) 177 | else: 178 | # probability volume geometry embedding 179 | cost_reg = regnet(cost_volume, stage_idx, geo_reg_data) 180 | del cost_volume 181 | prob_volume = F.softmax(cost_reg, dim=1) # B D H W 182 | 183 | 184 | # @Note step4: depth regression 185 | prob_max_indices = prob_volume.max(1, keepdim=True)[1] # B 1 H W 186 | depth = torch.gather(depth_hypo, 1, prob_max_indices).squeeze(1) # B H W 187 | 188 | with torch.no_grad(): 189 | photometric_confidence = prob_volume.max(1)[0] # B H W 190 | photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=1, mode='bilinear', align_corners=True).squeeze(1) 191 | 192 | last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:] 193 | inverse_min_depth = 1/depth + depth_interal_ratio * last_depth_itv # B H W 194 | inverse_max_depth = 1/depth - depth_interal_ratio * last_depth_itv # B H W 195 | 196 | 197 | output_stage = { 198 | "depth": depth, 199 | "photometric_confidence": photometric_confidence, 200 | "depth_hypo": depth_hypo, 201 | "prob_volume": prob_volume, 202 | "inverse_min_depth": inverse_min_depth, 203 | "inverse_max_depth": inverse_max_depth, 204 | } 205 | return output_stage -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Some sub-modules for the network. 
3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class FPN(nn.Module): 13 | """FPN aligncorners downsample 4x""" 14 | def __init__(self, base_channels, gn=False): 15 | super(FPN, self).__init__() 16 | self.base_channels = base_channels 17 | 18 | self.conv0 = nn.Sequential( 19 | Conv2d(3, base_channels, 3, 1, padding=1, gn=gn), 20 | Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn), 21 | ) 22 | 23 | self.conv1 = nn.Sequential( 24 | Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn), 25 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn), 26 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn), 27 | ) 28 | 29 | self.conv2 = nn.Sequential( 30 | Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn), 31 | Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn), 32 | Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn), 33 | ) 34 | 35 | self.conv3 = nn.Sequential( 36 | Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn), 37 | Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn), 38 | Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn), 39 | ) 40 | 41 | self.out_channels = [8 * base_channels] 42 | final_chs = base_channels * 8 43 | 44 | self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True) 45 | self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) 46 | self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) 47 | 48 | self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False) 49 | self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False) 50 | self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) 51 | self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) 52 | 53 | self.out_channels.append(base_channels * 4) 54 | self.out_channels.append(base_channels * 2) 55 | self.out_channels.append(base_channels) 56 | 57 | def forward(self, x): 58 | conv0 = self.conv0(x) 59 | conv1 = self.conv1(conv0) 60 | conv2 = self.conv2(conv1) 61 | conv3 = self.conv3(conv2) 62 | 63 | intra_feat = conv3 64 | outputs = {} 65 | out1 = self.out1(intra_feat) 66 | 67 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2) 68 | out2 = self.out2(intra_feat) 69 | 70 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1) 71 | out3 = self.out3(intra_feat) 72 | 73 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0) 74 | out4 = self.out4(intra_feat) 75 | 76 | outputs["stage1"] = out1 77 | outputs["stage2"] = out2 78 | outputs["stage3"] = out3 79 | outputs["stage4"] = out4 80 | 81 | return outputs 82 | 83 | 84 | class Reg2d(nn.Module): 85 | def __init__(self, input_channel=128, base_channel=32): 86 | super(Reg2d, self).__init__() 87 | 88 | self.conv0 = ConvBnReLU3D(input_channel, base_channel, kernel_size=(1,3,3), pad=(0,1,1)) 89 | self.conv1 = ConvBnReLU3D(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1)) 90 | self.conv2 = ConvBnReLU3D(base_channel*2, base_channel*2) 91 | 92 | self.conv3 = ConvBnReLU3D(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), 
pad=(0,1,1)) 93 | self.conv4 = ConvBnReLU3D(base_channel*4, base_channel*4) 94 | 95 | self.conv5 = ConvBnReLU3D(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1)) 96 | self.conv6 = ConvBnReLU3D(base_channel*8, base_channel*8) 97 | 98 | self.conv7 = nn.Sequential( 99 | nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), 100 | nn.BatchNorm3d(base_channel*4), 101 | nn.ReLU(inplace=True)) 102 | 103 | self.conv9 = nn.Sequential( 104 | nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), 105 | nn.BatchNorm3d(base_channel*2), 106 | nn.ReLU(inplace=True)) 107 | 108 | self.conv11 = nn.Sequential( 109 | nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), 110 | nn.BatchNorm3d(base_channel), 111 | nn.ReLU(inplace=True)) 112 | 113 | self.prob = nn.Conv3d(8, 1, 1, stride=1, padding=0) 114 | 115 | def forward(self, x): 116 | conv0 = self.conv0(x) 117 | conv2 = self.conv2(self.conv1(conv0)) 118 | conv4 = self.conv4(self.conv3(conv2)) 119 | x = self.conv6(self.conv5(conv4)) 120 | x = conv4 + self.conv7(x) 121 | x = conv2 + self.conv9(x) 122 | x = conv0 + self.conv11(x) 123 | x = self.prob(x) 124 | 125 | return x.squeeze(1) 126 | 127 | 128 | def homo_warping(src_fea, src_proj, ref_proj, depth_values): 129 | # src_fea: [B, C, H, W] 130 | # src_proj: [B, 4, 4] 131 | # ref_proj: [B, 4, 4] 132 | # depth_values: [B, Ndepth] o [B, Ndepth, H, W] 133 | # out: [B, C, Ndepth, H, W] 134 | C = src_fea.shape[1] 135 | Hs,Ws = src_fea.shape[-2:] 136 | B,num_depth,Hr,Wr = depth_values.shape 137 | 138 | with torch.no_grad(): 139 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 140 | rot = proj[:, :3, :3] # [B,3,3] 141 | trans = proj[:, :3, 3:4] # [B,3,1] 142 | 143 | y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device), 144 | torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)]) 145 | y = y.reshape(Hr*Wr) 146 | x = x.reshape(Hr*Wr) 147 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 148 | xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1) # [B, 3, H*W] 149 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 150 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1) # [B, 3, Ndepth, H*W] 151 | proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1) # [B, 3, Ndepth, H*W] 152 | # FIXME divide 0 153 | temp = proj_xyz[:, 2:3, :, :] 154 | temp[temp==0] = 1e-9 155 | proj_xy = proj_xyz[:, :2, :, :] / temp # [B, 2, Ndepth, H*W] 156 | # proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 157 | 158 | proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1 159 | proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1 160 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 161 | grid = proj_xy 162 | if len(src_fea.shape)==4: 163 | warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True) 164 | warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr) 165 | elif len(src_fea.shape)==5: 166 | warped_src_fea = [] 167 | for d in range(src_fea.shape[2]): 168 | warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', 
padding_mode='zeros', align_corners=True)) 169 | warped_src_fea = torch.stack(warped_src_fea, dim=2) 170 | 171 | return warped_src_fea 172 | 173 | 174 | def init_inverse_range(cur_depth, ndepths, device, dtype, H, W): 175 | inverse_depth_min = 1. / cur_depth[:, 0] # (B,) 176 | inverse_depth_max = 1. / cur_depth[:, -1] 177 | itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W) / (ndepths - 1) # 1 D H W 178 | inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv 179 | 180 | return 1./inverse_depth_hypo 181 | 182 | 183 | def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W): 184 | # cur_depth_min, (B, H, W) 185 | # cur_depth_max: (B, H, W) 186 | itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2) / (ndepths - 1) # 1 D H W 187 | 188 | inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv # B D H W 189 | inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1) 190 | return 1./inverse_depth_hypo 191 | 192 | 193 | # -------------------------------------------------------------- 194 | 195 | 196 | def init_bn(module): 197 | if module.weight is not None: 198 | nn.init.ones_(module.weight) 199 | if module.bias is not None: 200 | nn.init.zeros_(module.bias) 201 | return 202 | 203 | 204 | def init_uniform(module, init_method): 205 | if module.weight is not None: 206 | if init_method == "kaiming": 207 | nn.init.kaiming_uniform_(module.weight) 208 | elif init_method == "xavier": 209 | nn.init.xavier_uniform_(module.weight) 210 | return 211 | 212 | 213 | class ConvBnReLU3D(nn.Module): 214 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 215 | super(ConvBnReLU3D, self).__init__() 216 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 217 | self.bn = nn.BatchNorm3d(out_channels) 218 | 219 | def forward(self, x): 220 | return F.relu(self.bn(self.conv(x)), inplace=True) 221 | 222 | 223 | class Conv2d(nn.Module): 224 | 225 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 226 | relu=True, bn_momentum=0.1, init_method="xavier", gn=False, group_channel=8, **kwargs): 227 | super(Conv2d, self).__init__() 228 | bn = not gn 229 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 230 | bias=(not bn), **kwargs) 231 | self.kernel_size = kernel_size 232 | self.stride = stride 233 | self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None 234 | self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None 235 | self.relu = relu 236 | 237 | def forward(self, x): 238 | x = self.conv(x) 239 | if self.bn is not None: 240 | x = self.bn(x) 241 | else: 242 | x = self.gn(x) 243 | if self.relu: 244 | x = F.relu(x, inplace=True) 245 | return x 246 | 247 | def init_weights(self, init_method): 248 | init_uniform(self.conv, init_method) 249 | if self.bn is not None: 250 | init_bn(self.bn) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | 
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /datasets/dtu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Data preprocessing and organization for DTU dataset. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import os 8 | import cv2 9 | import random 10 | import numpy as np 11 | from PIL import Image 12 | 13 | from torchvision import transforms 14 | from torch.utils.data import Dataset 15 | 16 | from datasets.data_io import * 17 | 18 | 19 | class DTUDataset(Dataset): 20 | def __init__(self, root_dir, list_file, mode, n_views, **kwargs): 21 | super(DTUDataset, self).__init__() 22 | 23 | self.root_dir = root_dir 24 | self.list_file = list_file 25 | self.mode = mode 26 | self.n_views = n_views 27 | 28 | assert self.mode in ["train", "val", "test"] 29 | 30 | self.total_depths = 192 31 | self.interval_scale = 1.06 32 | 33 | self.data_scale = kwargs.get("data_scale", "mid") # mid / raw 34 | self.robust_train = kwargs.get("robust_train", False) # True / False 35 | self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5) 36 | 37 | if self.mode == "test": 38 | self.max_wh = kwargs.get("max_wh", (1600, 1200)) 39 | 40 | self.metas = self.build_metas() 41 | 42 | 43 | def build_metas(self): 44 | metas = [] 45 | 46 | with open(os.path.join(self.list_file)) as f: 47 | scans = [line.rstrip() for line in f.readlines()] 48 | 49 | pair_file = "Cameras/pair.txt" 50 | for scan in scans: 51 | with open(os.path.join(self.root_dir, pair_file)) as f: 52 | num_viewpoint = int(f.readline()) 53 | 54 | # viewpoints (49) 55 | for _ in range(num_viewpoint): 56 | ref_view = int(f.readline().rstrip()) 57 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 58 | 59 | if self.mode == "train": 60 | # light conditions 0-6 61 | for light_idx in range(7): 62 | metas.append((scan, light_idx, ref_view, src_views)) 63 | elif self.mode in ["test", "val"]: 64 | if len(src_views) < self.n_views: 65 | print("{} < num_views:{}".format(len(src_views), self.n_views)) 66 | src_views += [src_views[0]] * (self.n_views - len(src_views)) 67 | metas.append((scan, 3, ref_view, src_views)) 68 | 69 | print("DTU Dataset in", self.mode, "mode metas:", 
len(metas)) 70 | return metas 71 | 72 | 73 | def __len__(self): 74 | return len(self.metas) 75 | 76 | 77 | def read_cam_file(self, filename): 78 | with open(filename) as f: 79 | lines = f.readlines() 80 | lines = [line.rstrip() for line in lines] 81 | # extrinsics: line [1,5), 4x4 matrix 82 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 83 | # intrinsics: line [7-10), 3x3 matrix 84 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 85 | 86 | if self.mode == "test": 87 | intrinsics[:2, :] /= 4.0 88 | 89 | # depth_min & depth_interval: line 11 90 | depth_min = float(lines[11].split()[0]) 91 | depth_interval = float(lines[11].split()[1]) 92 | 93 | if len(lines[11].split()) >= 3: 94 | num_depth = lines[11].split()[2] 95 | depth_max = depth_min + int(float(num_depth)) * depth_interval 96 | depth_interval = (depth_max - depth_min) / self.total_depths 97 | 98 | depth_interval *= self.interval_scale 99 | 100 | return intrinsics, extrinsics, depth_min, depth_interval 101 | 102 | 103 | def read_img(self, filename): 104 | img = Image.open(filename) 105 | if self.mode == "train" and self.robust_train: 106 | img = self.color_augment(img) 107 | # scale 0~255 to 0~1 108 | np_img = np.array(img, dtype=np.float32) / 255. 109 | return np_img 110 | 111 | 112 | def crop_img(self, img): 113 | raw_h, raw_w = img.shape[:2] 114 | start_h = (raw_h-1024)//2 115 | start_w = (raw_w-1280)//2 116 | return img[start_h:start_h+1024, start_w:start_w+1280, :] # (1024, 1280) 117 | 118 | 119 | def prepare_img(self, hr_img): 120 | h, w = hr_img.shape 121 | if self.data_scale == "mid": 122 | hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST) 123 | h, w = hr_img_ds.shape 124 | target_h, target_w = 512, 640 125 | start_h, start_w = (h - target_h)//2, (w - target_w)//2 126 | hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w] 127 | elif self.data_scale == "raw": 128 | hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2] # (1024, 1280) 129 | return hr_img_crop 130 | 131 | 132 | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64): 133 | h, w = img.shape[:2] 134 | if h > max_h or w > max_w: 135 | scale = 1.0 * max_h / h 136 | if scale * w > max_w: 137 | scale = 1.0 * max_w / w 138 | new_w, new_h = scale * w // base * base, scale * h // base * base 139 | else: 140 | new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base 141 | 142 | scale_w = 1.0 * new_w / w 143 | scale_h = 1.0 * new_h / h 144 | intrinsics[0, :] *= scale_w 145 | intrinsics[1, :] *= scale_h 146 | 147 | img = cv2.resize(img, (int(new_w), int(new_h))) 148 | 149 | return img, intrinsics 150 | 151 | 152 | def read_mask_hr(self, filename): 153 | img = Image.open(filename) 154 | np_img = np.array(img, dtype=np.float32) 155 | np_img = (np_img > 10).astype(np.float32) 156 | np_img = self.prepare_img(np_img) 157 | 158 | h, w = np_img.shape 159 | np_img_ms = { 160 | "stage1": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST), 161 | "stage2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST), 162 | "stage3": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST), 163 | "stage4": np_img, 164 | } 165 | return np_img_ms 166 | 167 | 168 | def read_depth_hr(self, filename, scale): 169 | depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale 170 | depth_lr = self.prepare_img(depth_hr) 171 | 172 | h, w = depth_lr.shape 173 | depth_lr_ms = { 
174 | "stage1": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST), 175 | "stage2": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST), 176 | "stage3": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST), 177 | "stage4": depth_lr, 178 | } 179 | return depth_lr_ms 180 | 181 | 182 | def __getitem__(self, idx): 183 | scan, light_idx, ref_view, src_views = self.metas[idx] 184 | 185 | if self.mode == "train" and self.robust_train: 186 | num_src_views = len(src_views) 187 | index = random.sample(range(num_src_views), self.n_views-1) 188 | view_ids = [ref_view] + [src_views[i] for i in index] 189 | scale_ratio = random.uniform(0.8, 1.25) 190 | else: 191 | view_ids = [ref_view] + src_views[:self.n_views-1] 192 | scale_ratio = 1 193 | 194 | imgs = [] 195 | mask = None 196 | depth_values = None 197 | proj_matrices = [] 198 | 199 | for i, vid in enumerate(view_ids): 200 | # @Note image & cam 201 | if self.mode in ["train", "val"]: 202 | if self.data_scale == "mid": 203 | img_filename = os.path.join(self.root_dir, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid+1, light_idx)) 204 | elif self.data_scale == "raw": 205 | img_filename = os.path.join(self.root_dir, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) 206 | proj_mat_filename = os.path.join(self.root_dir, 'Cameras/train/{:0>8}_cam.txt').format(vid) 207 | elif self.mode == "test": 208 | img_filename = os.path.join(self.root_dir, 'Rectified/{}/rect_{:0>3}_3_r5000.png'.format(scan, vid+1)) 209 | proj_mat_filename = os.path.join(self.root_dir, 'Cameras/{:0>8}_cam.txt'.format(vid)) 210 | 211 | img = self.read_img(img_filename) 212 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename) 213 | 214 | if self.mode in ["train", "val"]: 215 | if self.data_scale == "raw": 216 | img = self.crop_img(img) 217 | intrinsics[:2, :] *= 2.0 218 | if self.mode == "train" and self.robust_train: 219 | extrinsics[:3,3] *= scale_ratio 220 | elif self.mode == "test": 221 | img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_wh[0], self.max_wh[1]) 222 | 223 | imgs.append(img.transpose(2,0,1)) 224 | 225 | # reference view 226 | if i == 0: 227 | # @Note depth values 228 | diff = 0.5 if self.mode in ["test", "val"] else 0 229 | depth_max = depth_interval * (self.total_depths - diff) + depth_min 230 | depth_values = np.array([depth_min * scale_ratio, depth_max * scale_ratio], dtype=np.float32) 231 | 232 | # @Note depth & mask 233 | if self.mode in ["train", "val"]: 234 | depth_filename_hr = os.path.join(self.root_dir, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid)) 235 | depth = self.read_depth_hr(depth_filename_hr, scale_ratio) 236 | 237 | mask_filename_hr = os.path.join(self.root_dir, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid)) 238 | mask = self.read_mask_hr(mask_filename_hr) 239 | 240 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) 241 | proj_mat[0, :4, :4] = extrinsics 242 | proj_mat[1, :3, :3] = intrinsics 243 | proj_matrices.append(proj_mat) 244 | 245 | proj_matrices = np.stack(proj_matrices) 246 | intrinsics = np.stack(intrinsics) 247 | stage1_pjmats = proj_matrices.copy() 248 | stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0 249 | stage1_ins = intrinsics.copy() 250 | stage1_ins[:2, :] = intrinsics[:2, :] / 2.0 251 | stage3_pjmats = proj_matrices.copy() 252 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 253 | stage3_ins = intrinsics.copy() 254 | stage3_ins[:2, :] = 
intrinsics[:2, :] * 2.0 255 | stage4_pjmats = proj_matrices.copy() 256 | stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 257 | stage4_ins = intrinsics.copy() 258 | stage4_ins[:2, :] = intrinsics[:2, :] * 4.0 259 | proj_matrices = { 260 | "stage1": stage1_pjmats, 261 | "stage2": proj_matrices, 262 | "stage3": stage3_pjmats, 263 | "stage4": stage4_pjmats 264 | } 265 | intrinsics_matrices = { 266 | "stage1": stage1_ins, 267 | "stage2": intrinsics, 268 | "stage3": stage3_ins, 269 | "stage4": stage4_ins 270 | } 271 | 272 | sample = { 273 | "imgs": imgs, 274 | "proj_matrices": proj_matrices, 275 | "intrinsics_matrices": intrinsics_matrices, 276 | "depth_values": depth_values 277 | } 278 | if self.mode in ["train", "val"]: 279 | sample["depth"] = depth 280 | sample["mask"] = mask 281 | elif self.mode == "test": 282 | sample["filename"] = scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" 283 | 284 | return sample -------------------------------------------------------------------------------- /fusions/dtu/_open3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Point cloud fusion strategy for DTU dataset based on Open3D Library. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import torch 8 | import numpy as np 9 | import sys 10 | import argparse 11 | import errno, os 12 | import glob 13 | import os.path as osp 14 | import re 15 | import cv2 16 | from PIL import Image 17 | import gc 18 | import open3d as o3d 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | import numpy as np 23 | 24 | 25 | parser = argparse.ArgumentParser(description='Depth fusion with consistency check.') 26 | parser.add_argument('--root_path', type=str, default='[/path/to/]dtu-test-1200') 27 | parser.add_argument('--depth_path', type=str, default='') 28 | parser.add_argument('--data_list', type=str, default='') 29 | parser.add_argument('--ply_path', type=str, default='') 30 | parser.add_argument('--dist_thresh', type=float, default=0.001) 31 | parser.add_argument('--prob_thresh', type=float, default=0.6) 32 | parser.add_argument('--num_consist', type=int, default=10) 33 | parser.add_argument('--device', type=str, default='cpu') 34 | 35 | args = parser.parse_args() 36 | 37 | 38 | def homo_warping(src_fea, src_proj, ref_proj, depth_values): 39 | # src_fea: [B, C, H, W] 40 | # src_proj: [B, 4, 4] 41 | # ref_proj: [B, 4, 4] 42 | # depth_values: [B, Ndepth] o [B, Ndepth, H, W] 43 | # out: [B, C, Ndepth, H, W] 44 | batch, channels = src_fea.shape[0], src_fea.shape[1] 45 | height, width = src_fea.shape[2], src_fea.shape[3] 46 | 47 | with torch.no_grad(): 48 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 49 | rot = proj[:, :3, :3] # [B,3,3] 50 | trans = proj[:, :3, 3:4] # [B,3,1] 51 | 52 | y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 53 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]) 54 | y, x = y.contiguous(), x.contiguous() 55 | y, x = y.view(height * width), x.view(height * width) 56 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 57 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 58 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 59 | 60 | rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_values.view(-1, 1, 1, height*width) # [B, 3, 1, H*W] 61 | 62 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 63 | 
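# Added note: the homogeneous coordinates are next divided by their z component
# and normalised to the [-1, 1] range expected by F.grid_sample. Unlike the
# training-side copy in models/submodules.py, this fusion script does not guard
# the division against zero depth, so degenerate points would yield inf/nan.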
proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 64 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 65 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 66 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 67 | grid = proj_xy 68 | 69 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, height, width, 2), mode='bilinear', 70 | padding_mode='zeros') 71 | warped_src_fea = warped_src_fea.view(batch, channels, height, width) 72 | 73 | return warped_src_fea 74 | 75 | 76 | def generate_points_from_depth(depth, proj): 77 | ''' 78 | :param depth: (B, 1, H, W) 79 | :param proj: (B, 4, 4) 80 | :return: point_cloud (B, 3, H, W) 81 | ''' 82 | batch, height, width = depth.shape[0], depth.shape[2], depth.shape[3] 83 | inv_proj = torch.inverse(proj) 84 | 85 | rot = inv_proj[:, :3, :3] # [B,3,3] 86 | trans = inv_proj[:, :3, 3:4] # [B,3,1] 87 | 88 | y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth.device), 89 | torch.arange(0, width, dtype=torch.float32, device=depth.device)]) 90 | y, x = y.contiguous(), x.contiguous() 91 | y, x = y.view(height * width), x.view(height * width) 92 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 93 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 94 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 95 | rot_depth_xyz = rot_xyz * depth.view(batch, 1, -1) 96 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1) # [B, 3, H*W] 97 | proj_xyz = proj_xyz.view(batch, 3, height, width) 98 | 99 | return proj_xyz 100 | 101 | 102 | def mkdir_p(path): 103 | try: 104 | os.makedirs(path) 105 | except OSError as exc: 106 | if exc.errno == errno.EEXIST and os.path.isdir(path): 107 | pass 108 | else: 109 | raise 110 | 111 | 112 | def read_pfm(filename): 113 | file = open(filename, 'rb') 114 | color = None 115 | width = None 116 | height = None 117 | scale = None 118 | endian = None 119 | 120 | header = file.readline().decode('utf-8').rstrip() 121 | if header == 'PF': 122 | color = True 123 | elif header == 'Pf': 124 | color = False 125 | else: 126 | raise Exception('Not a PFM file.') 127 | 128 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 129 | if dim_match: 130 | width, height = map(int, dim_match.groups()) 131 | else: 132 | raise Exception('Malformed PFM header.') 133 | 134 | scale = float(file.readline().rstrip()) 135 | if scale < 0: # little-endian 136 | endian = '<' 137 | scale = -scale 138 | else: 139 | endian = '>' # big-endian 140 | 141 | data = np.fromfile(file, endian + 'f') 142 | shape = (height, width, 3) if color else (height, width) 143 | 144 | data = np.reshape(data, shape) 145 | data = np.flipud(data) 146 | file.close() 147 | return data, scale 148 | 149 | 150 | def write_pfm(file, image, scale=1): 151 | file = open(file, 'wb') 152 | color = None 153 | if image.dtype.name != 'float32': 154 | raise Exception('Image dtype must be float32.') 155 | 156 | image = np.flipud(image) 157 | 158 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 159 | color = True 160 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 161 | color = False 162 | else: 163 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 164 | 165 | file.write('PF\n'.encode() if color else 'Pf\n'.encode()) 166 | file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) 167 | 168 | endian = image.dtype.byteorder 169 | 170 | if endian 
== '<' or endian == '=' and sys.byteorder == 'little': 171 | scale = -scale 172 | 173 | file.write('%f\n'.encode() % scale) 174 | 175 | image_string = image.tostring() 176 | file.write(image_string) 177 | file.close() 178 | 179 | 180 | def write_ply(file, points): 181 | pcd = o3d.geometry.PointCloud() 182 | pcd.points = o3d.utility.Vector3dVector(points[:, :3]) 183 | pcd.colors = o3d.utility.Vector3dVector(points[:, 3:] / 255.) 184 | o3d.io.write_point_cloud(file, pcd, write_ascii=False) 185 | 186 | 187 | def filter_depth(ref_depth, src_depths, ref_proj, src_projs): 188 | ''' 189 | :param ref_depth: (1, 1, H, W) 190 | :param src_depths: (B, 1, H, W) 191 | :param ref_proj: (1, 4, 4) 192 | :param src_proj: (B, 4, 4) 193 | :return: ref_pc: (1, 3, H, W), aligned_pcs: (B, 3, H, W), dist: (B, 1, H, W) 194 | ''' 195 | 196 | ref_pc = generate_points_from_depth(ref_depth, ref_proj) 197 | src_pcs = generate_points_from_depth(src_depths, src_projs) 198 | 199 | aligned_pcs = homo_warping(src_pcs, src_projs, ref_proj, ref_depth) 200 | 201 | x_2 = (ref_pc[:, 0] - aligned_pcs[:, 0])**2 202 | y_2 = (ref_pc[:, 1] - aligned_pcs[:, 1])**2 203 | z_2 = (ref_pc[:, 2] - aligned_pcs[:, 2])**2 204 | dist = torch.sqrt(x_2 + y_2 + z_2).unsqueeze(1) 205 | 206 | return ref_pc, aligned_pcs, dist 207 | 208 | 209 | def parse_cameras(path): 210 | cam_txt = open(path).readlines() 211 | f = lambda xs: list(map(lambda x: list(map(float, x.strip().split())), xs)) 212 | 213 | extr_mat = f(cam_txt[1:5]) 214 | intr_mat = f(cam_txt[7:10]) 215 | 216 | extr_mat = np.array(extr_mat, np.float32) 217 | intr_mat = np.array(intr_mat, np.float32) 218 | 219 | return extr_mat, intr_mat 220 | 221 | 222 | def load_data(root_path, depth_path, scene_name, thresh): 223 | 224 | depths = [] 225 | projs = [] 226 | rgbs = [] 227 | 228 | for view in range(49): 229 | img_filename = "{}/{}/images/{:08d}.jpg".format(depth_path, scene_name, view) 230 | cam_filename = "{}/{}/cams/{:08d}_cam.txt".format(depth_path, scene_name, view) 231 | depth_filename = "{}/{}/depth_est/{:08d}.pfm".format(depth_path, scene_name, view) 232 | confidence_filename = "{}/{}/confidence/{:08d}.pfm".format(depth_path, scene_name, view) 233 | 234 | 235 | extr_mat, intr_mat = parse_cameras(cam_filename) 236 | proj_mat = np.eye(4) 237 | proj_mat[:3, :4] = np.dot(intr_mat[:3, :3], extr_mat[:3, :4]) 238 | projs.append(torch.from_numpy(proj_mat)) 239 | 240 | dep_map, _ = read_pfm(depth_filename) 241 | h, w = dep_map.shape 242 | conf_map, _ = read_pfm(confidence_filename) 243 | conf_map = cv2.resize(conf_map, (w, h), interpolation=cv2.INTER_LINEAR) 244 | 245 | dep_map = dep_map * (conf_map>thresh).astype(np.float32) 246 | depths.append(torch.from_numpy(dep_map).unsqueeze(0)) 247 | 248 | rgb = np.array(Image.open(img_filename)) 249 | rgbs.append(rgb) 250 | 251 | depths = torch.stack(depths).float() 252 | projs = torch.stack(projs).float() 253 | if args.device == 'cuda' and torch.cuda.is_available(): 254 | depths = depths.cuda() 255 | projs = projs.cuda() 256 | 257 | return depths, projs, rgbs 258 | 259 | 260 | def extract_points(pc, mask, rgb): 261 | pc = pc.cpu().numpy() 262 | mask = mask.cpu().numpy() 263 | 264 | mask = np.reshape(mask, (-1,)) 265 | pc = np.reshape(pc, (-1, 3)) 266 | rgb = np.reshape(rgb, (-1, 3)) 267 | 268 | points = pc[np.where(mask)] 269 | colors = rgb[np.where(mask)] 270 | 271 | points_with_color = np.concatenate([points, colors], axis=1) 272 | 273 | return points_with_color 274 | 275 | 276 | def open3d_filter(): 277 | with torch.no_grad(): 278 | 
mkdir_p(args.ply_path) 279 | all_scenes = open(args.data_list, 'r').readlines() 280 | all_scenes = list(map(str.strip, all_scenes)) 281 | 282 | for i, scene in enumerate(all_scenes): 283 | 284 | print("{}/{} {}:".format(i, len(all_scenes), scene), '------------------------') 285 | 286 | depths, projs, rgbs = load_data(args.root_path, args.depth_path, scene, args.prob_thresh) 287 | tot_frame = depths.shape[0] 288 | height, width = depths.shape[2], depths.shape[3] 289 | points = [] 290 | 291 | print('Scene: {} total: {} frames'.format(scene, tot_frame)) 292 | for i in range(tot_frame): 293 | pc_buff = torch.zeros((3, height, width), device=depths.device, dtype=depths.dtype) 294 | val_cnt = torch.zeros((1, height, width), device=depths.device, dtype=depths.dtype) 295 | j = 0 296 | batch_size = 20 297 | 298 | while True: 299 | ref_pc, pcs, dist = filter_depth(ref_depth=depths[i:i+1], src_depths=depths[j:min(j+batch_size, tot_frame)], 300 | ref_proj=projs[i:i+1], src_projs=projs[j:min(j+batch_size, tot_frame)]) 301 | masks = (dist < args.dist_thresh).float() 302 | masked_pc = pcs * masks 303 | pc_buff += masked_pc.sum(dim=0, keepdim=False) 304 | val_cnt += masks.sum(dim=0, keepdim=False) 305 | 306 | j += batch_size 307 | if j >= tot_frame: 308 | break 309 | 310 | final_mask = (val_cnt >= args.num_consist).squeeze(0) 311 | avg_points = torch.div(pc_buff, val_cnt).permute(1, 2, 0) 312 | 313 | final_pc = extract_points(avg_points, final_mask, rgbs[i]) 314 | points.append(final_pc) 315 | if i==0 or i==tot_frame-1: 316 | print('Processing {} {}/{} ...'.format(scene, i+1, tot_frame)) 317 | 318 | ply_id = int(scene[4:]) 319 | write_ply('{}/mvsnet{:03d}.ply'.format(args.ply_path, ply_id), np.concatenate(points, axis=0)) 320 | del points, depths, rgbs, projs 321 | 322 | gc.collect() 323 | 324 | print('Save {}/mvsnet{:03d}.ply successful.'.format(args.ply_path, ply_id)) 325 | 326 | 327 | if __name__ == '__main__': 328 | open3d_filter() 329 | -------------------------------------------------------------------------------- /fusions/dtu/gipuma.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Point cloud fusion strategy for DTU dataset: Gipuma (fusibile). 
3 | # Refer to: https://github.com/YoYo000/MVSNet/blob/master/mvsnet/depthfusion.py 4 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 5 | # @Affiliation: Peking University (PKU) 6 | # @LastEditDate: 2023-09-07 7 | 8 | from __future__ import print_function 9 | 10 | import os, re, sys, shutil 11 | from struct import * 12 | import numpy as np 13 | import argparse 14 | import cv2 15 | from tensorflow.python.lib.io import file_io 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--root_dir', type=str, default='[/path/to]/dtu-test-1200', help='root directory of dtu dataset') 20 | parser.add_argument('--list_file', type=str, default='datasets/lists/dtu/train.txt', help='file contains the scans list') 21 | 22 | parser.add_argument('--depth_folder', type=str, default = './outputs/') 23 | parser.add_argument('--out_folder', type=str, default = 'fusibile_fused') 24 | parser.add_argument('--plydir', type=str, default='') 25 | parser.add_argument('--quandir', type=str, default='') 26 | parser.add_argument('--fusibile_exe_path', type=str, default = 'fusion/fusibile') 27 | parser.add_argument('--prob_threshold', type=float, default = '0.8') 28 | parser.add_argument('--disp_threshold', type=float, default = '0.13') 29 | parser.add_argument('--num_consistent', type=float, default = '3') 30 | parser.add_argument('--downsample_factor', type=int, default='1') 31 | 32 | args = parser.parse_args() 33 | 34 | 35 | # preprocess ==================================== 36 | 37 | def load_cam(file, interval_scale=1): 38 | """ read camera txt file """ 39 | cam = np.zeros((2, 4, 4)) 40 | words = file.read().split() 41 | # read extrinsic 42 | for i in range(0, 4): 43 | for j in range(0, 4): 44 | extrinsic_index = 4 * i + j + 1 45 | cam[0][i][j] = words[extrinsic_index] 46 | 47 | # read intrinsic 48 | for i in range(0, 3): 49 | for j in range(0, 3): 50 | intrinsic_index = 3 * i + j + 18 51 | cam[1][i][j] = words[intrinsic_index] 52 | 53 | if len(words) == 29: 54 | cam[1][3][0] = words[27] 55 | cam[1][3][1] = float(words[28]) * interval_scale 56 | cam[1][3][2] = 1100 57 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] 58 | elif len(words) == 30: 59 | cam[1][3][0] = words[27] 60 | cam[1][3][1] = float(words[28]) * interval_scale 61 | cam[1][3][2] = words[29] 62 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] 63 | elif len(words) == 31: 64 | cam[1][3][0] = words[27] 65 | cam[1][3][1] = float(words[28]) * interval_scale 66 | cam[1][3][2] = words[29] 67 | cam[1][3][3] = words[30] 68 | else: 69 | cam[1][3][0] = 0 70 | cam[1][3][1] = 0 71 | cam[1][3][2] = 0 72 | cam[1][3][3] = 0 73 | 74 | return cam 75 | 76 | 77 | def load_pfm(file): 78 | color = None 79 | width = None 80 | height = None 81 | scale = None 82 | data_type = None 83 | header = file.readline().decode('UTF-8').rstrip() 84 | 85 | if header == 'PF': 86 | color = True 87 | elif header == 'Pf': 88 | color = False 89 | else: 90 | raise Exception('Not a PFM file.') 91 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8')) 92 | if dim_match: 93 | width, height = map(int, dim_match.groups()) 94 | else: 95 | raise Exception('Malformed PFM header.') 96 | # scale = float(file.readline().rstrip()) 97 | scale = float((file.readline()).decode('UTF-8').rstrip()) 98 | if scale < 0: # little-endian 99 | data_type = ' 0, 1, 0)) 231 | mask_image = np.reshape(mask_image, (image_shape[0], image_shape[1], 1)) 232 | mask_image = np.tile(mask_image, [1, 1, 3]) 233 | mask_image = np.float32(mask_image) 234 | 235 | 
normal_image = np.multiply(normal_image, mask_image) 236 | normal_image = np.float32(normal_image) 237 | 238 | write_gipuma_dmb(out_normal_path, normal_image) 239 | return 240 | 241 | 242 | def mvsnet_to_gipuma(scan_folder, scan, root_dir, gipuma_point_folder): 243 | 244 | image_folder = os.path.join(root_dir, 'Rectified', scan) 245 | cam_folder = os.path.join(root_dir, 'Cameras') 246 | depth_folder = os.path.join(scan_folder, 'depth_est') 247 | 248 | gipuma_cam_folder = os.path.join(gipuma_point_folder, 'cams') 249 | gipuma_image_folder = os.path.join(gipuma_point_folder, 'images') 250 | if not os.path.isdir(gipuma_point_folder): 251 | os.mkdir(gipuma_point_folder) 252 | if not os.path.isdir(gipuma_cam_folder): 253 | os.mkdir(gipuma_cam_folder) 254 | if not os.path.isdir(gipuma_image_folder): 255 | os.mkdir(gipuma_image_folder) 256 | 257 | # convert cameras 258 | for view in range(0,49): 259 | in_cam_file = os.path.join(cam_folder, "{:08d}_cam.txt".format(view)) 260 | out_cam_file = os.path.join(gipuma_cam_folder, "{:08d}.png.P".format(view)) 261 | mvsnet_to_gipuma_cam(in_cam_file, out_cam_file) 262 | 263 | # copy images to gipuma image folder 264 | for view in range(0,49): 265 | in_image_file = os.path.join(image_folder, "rect_{:03d}_3_r5000.png".format(view+1))# Our image start from 1 266 | out_image_file = os.path.join(gipuma_image_folder, "{:08d}.png".format(view)) 267 | # shutil.copy(in_image_file, out_image_file) 268 | 269 | in_image = cv2.imread(in_image_file) 270 | out_image = cv2.resize(in_image, None, fx=1.0/args.downsample_factor, fy=1.0/args.downsample_factor, interpolation=cv2.INTER_LINEAR) 271 | cv2.imwrite(out_image_file, out_image) 272 | 273 | # convert depth maps and fake normal maps 274 | gipuma_prefix = '2333__' 275 | for view in range(0,49): 276 | 277 | sub_depth_folder = os.path.join(gipuma_point_folder, gipuma_prefix+"{:08d}".format(view)) 278 | if not os.path.isdir(sub_depth_folder): 279 | os.mkdir(sub_depth_folder) 280 | in_depth_pfm = os.path.join(depth_folder, "{:08d}_prob_filtered.pfm".format(view)) 281 | out_depth_dmb = os.path.join(sub_depth_folder, 'disp.dmb') 282 | fake_normal_dmb = os.path.join(sub_depth_folder, 'normals.dmb') 283 | mvsnet_to_gipuma_dmb(in_depth_pfm, out_depth_dmb) 284 | fake_gipuma_normal(out_depth_dmb, fake_normal_dmb) 285 | 286 | 287 | def probability_filter(scan_folder, prob_threshold): 288 | depth_folder = os.path.join(scan_folder, 'depth_est') 289 | prob_folder = os.path.join(scan_folder, 'confidence') 290 | 291 | # convert cameras 292 | for view in range(0,49): 293 | init_depth_map_path = os.path.join(depth_folder, "{:08d}.pfm".format(view)) # New dataset outputs depth start from 0. 
294 | prob_map_path = os.path.join(prob_folder, "{:08d}.pfm".format(view)) # Same as above 295 | out_depth_map_path = os.path.join(depth_folder, "{:08d}_prob_filtered.pfm".format(view)) # Gipuma start from 0 296 | 297 | depth_map = load_pfm(open(init_depth_map_path)) 298 | prob_map = load_pfm(open(prob_map_path)) 299 | depth_map[prob_map < prob_threshold] = 0 300 | write_pfm(out_depth_map_path, depth_map) 301 | 302 | 303 | def depth_map_fusion(point_folder, fusibile_exe_path, disp_thresh, num_consistent): 304 | 305 | cam_folder = os.path.join(point_folder, 'cams') 306 | image_folder = os.path.join(point_folder, 'images') 307 | depth_min = 0.001 308 | depth_max = 100000 309 | normal_thresh = 360 310 | 311 | cmd = fusibile_exe_path 312 | cmd = cmd + ' -input_folder ' + point_folder + '/' 313 | cmd = cmd + ' -p_folder ' + cam_folder + '/' 314 | cmd = cmd + ' -images_folder ' + image_folder + '/' 315 | cmd = cmd + ' --depth_min=' + str(depth_min) 316 | cmd = cmd + ' --depth_max=' + str(depth_max) 317 | cmd = cmd + ' --normal_thresh=' + str(normal_thresh) 318 | cmd = cmd + ' --disp_thresh=' + str(disp_thresh) 319 | cmd = cmd + ' --num_consistent=' + str(num_consistent) 320 | print (cmd) 321 | os.system(cmd) 322 | 323 | return 324 | 325 | 326 | def collectPly(point_folder, scan_id): 327 | model_name = 'final3d_model.ply' 328 | model_dir = [item for item in os.listdir(point_folder) if item.startswith("consistencyCheck")][-1] 329 | 330 | old = os.path.join(point_folder, model_dir, model_name) 331 | fresh = os.path.join(args.plydir, "mvsnet") + scan_id.zfill(3) + ".ply" 332 | shutil.move(old, fresh) 333 | 334 | 335 | if __name__ == '__main__': 336 | 337 | root_dir = args.root_dir 338 | depth_folder = args.depth_folder 339 | out_folder = args.out_folder 340 | fusibile_exe_path = args.fusibile_exe_path 341 | prob_threshold = args.prob_threshold 342 | disp_threshold = args.disp_threshold 343 | num_consistent = args.num_consistent 344 | 345 | # Read test list 346 | testlist = args.list_file 347 | with open(testlist) as f: 348 | scans = f.readlines() 349 | scans = [line.rstrip() for line in scans] 350 | 351 | print("Start Gipuma(GPU) fusion!") 352 | 353 | if not os.path.isdir(args.plydir): 354 | os.mkdir(args.plydir) 355 | 356 | # Fusion 357 | for i, scan in enumerate(scans): 358 | print("{}/{} {}:".format(i, len(scans), scan), '------------------------') 359 | 360 | scan_folder = os.path.join(depth_folder, scan) 361 | fusibile_workspace = os.path.join(depth_folder, out_folder, scan) 362 | 363 | if not os.path.isdir(os.path.join(depth_folder, out_folder)): 364 | os.mkdir(os.path.join(depth_folder, out_folder)) 365 | 366 | if not os.path.isdir(fusibile_workspace): 367 | os.mkdir(fusibile_workspace) 368 | 369 | # probability filtering 370 | print ('filter depth map with probability map') 371 | probability_filter(scan_folder, prob_threshold) 372 | 373 | # convert to gipuma format 374 | print ('Convert mvsnet output to gipuma input') 375 | mvsnet_to_gipuma(scan_folder, scan, root_dir, fusibile_workspace) 376 | 377 | # depth map fusion with gipuma 378 | print ('Run depth map fusion & filter') 379 | depth_map_fusion(fusibile_workspace, fusibile_exe_path, disp_threshold, num_consistent) 380 | 381 | # collect .ply results to summary folder 382 | print('Collect {} ply'.format(scan)) 383 | collectPly(fusibile_workspace, scan[4:]) 384 | 385 | print("Gipuma(GPU) fusion done!") 386 | shutil.rmtree(os.path.join(depth_folder, out_folder)) 387 | print("fusibile_fused remove done!") 
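388 | 
389 | # Example invocation (a sketch only; the paths below are placeholders). In this repo the fusion is
390 | # normally launched through scripts/dtu/fusion_dtu.sh, but this script can also be called directly
391 | # with the arguments defined in the argparse block at the top of this file:
392 | #   python fusions/dtu/gipuma.py \
393 | #       --root_dir [/path/to]/dtu-test-1200 \
394 | #       --list_file datasets/lists/dtu/test.txt \
395 | #       --depth_folder ./outputs/dtu/geomvsnet/ \
396 | #       --plydir ./outputs/dtu/geomvsnet/gipuma_fusion_plys \
397 | #       --fusibile_exe_path [/path/to/]fusibile/fusibile \
398 | #       --prob_threshold 0.8 --disp_threshold 0.13 --num_consistent 3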
-------------------------------------------------------------------------------- /fusions/dtu/pcd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Point cloud fusion strategy for DTU dataset: Basic PCD. 3 | # Refer to: https://github.com/xy-guo/MVSNet_pytorch/blob/master/eval.py 4 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 5 | # @Affiliation: Peking University (PKU) 6 | # @LastEditDate: 2023-09-07 7 | 8 | import argparse, os, sys, cv2, re, logging, time 9 | import numpy as np 10 | from plyfile import PlyData, PlyElement 11 | from PIL import Image 12 | 13 | from multiprocessing import Pool 14 | from functools import partial 15 | import signal 16 | 17 | 18 | parser = argparse.ArgumentParser(description='filter, and fuse') 19 | 20 | parser.add_argument('--testpath', default='[/path/to]/dtu-test-1200', help='testing data dir for some scenes') 21 | parser.add_argument('--testlist', default="datasets/lists/dtu/test.txt", help='testing scene list') 22 | 23 | parser.add_argument('--outdir', default='./outputs/[exp_name]', help='output dir') 24 | parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs') 25 | parser.add_argument('--nolog', action='store_true', help='do not logging into .log file') 26 | parser.add_argument('--plydir', default='./outputs/[exp_name]/pcd_fusion_plys/', help='output dir') 27 | 28 | parser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker') 29 | 30 | parser.add_argument('--conf', type=float, default=0.9, help='prob confidence') 31 | parser.add_argument('--thres_view', type=int, default=5, help='threshold of num view') 32 | 33 | args = parser.parse_args() 34 | 35 | 36 | def read_pfm(filename): 37 | file = open(filename, 'rb') 38 | color = None 39 | width = None 40 | height = None 41 | scale = None 42 | endian = None 43 | 44 | header = file.readline().decode('utf-8').rstrip() 45 | if header == 'PF': 46 | color = True 47 | elif header == 'Pf': 48 | color = False 49 | else: 50 | raise Exception('Not a PFM file.') 51 | 52 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 53 | if dim_match: 54 | width, height = map(int, dim_match.groups()) 55 | else: 56 | raise Exception('Malformed PFM header.') 57 | 58 | scale = float(file.readline().rstrip()) 59 | if scale < 0: # little-endian 60 | endian = '<' 61 | scale = -scale 62 | else: 63 | endian = '>' # big-endian 64 | 65 | data = np.fromfile(file, endian + 'f') 66 | shape = (height, width, 3) if color else (height, width) 67 | 68 | data = np.reshape(data, shape) 69 | data = np.flipud(data) 70 | file.close() 71 | return data, scale 72 | 73 | 74 | def read_camera_parameters(filename): 75 | with open(filename) as f: 76 | lines = f.readlines() 77 | lines = [line.rstrip() for line in lines] 78 | # extrinsics: line [1,5), 4x4 matrix 79 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 80 | # intrinsics: line [7-10), 3x3 matrix 81 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 82 | return intrinsics, extrinsics 83 | 84 | 85 | def read_img(filename): 86 | img = Image.open(filename) 87 | # scale 0~255 to 0~1 88 | np_img = np.array(img, dtype=np.float32) / 255. 
89 | return np_img 90 | 91 | 92 | def read_mask(filename): 93 | return read_img(filename) > 0.5 94 | 95 | 96 | def save_mask(filename, mask): 97 | assert mask.dtype == np.bool 98 | mask = mask.astype(np.uint8) * 255 99 | Image.fromarray(mask).save(filename) 100 | 101 | 102 | def read_pair_file(filename): 103 | data = [] 104 | with open(filename) as f: 105 | num_viewpoint = int(f.readline()) 106 | # 49 viewpoints 107 | for view_idx in range(num_viewpoint): 108 | ref_view = int(f.readline().rstrip()) 109 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 110 | if len(src_views) > 0: 111 | data.append((ref_view, src_views)) 112 | return data 113 | 114 | 115 | # project the reference point cloud into the source view, then project back 116 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 117 | width, height = depth_ref.shape[1], depth_ref.shape[0] 118 | ## step1. project reference pixels to the source view 119 | # reference view x, y 120 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 121 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 122 | # reference 3D space 123 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 124 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 125 | # source 3D space 126 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 127 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 128 | # source view x, y 129 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 130 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 131 | 132 | ## step2. reproject the source view points with source view depth estimation 133 | # find the depth estimation of the source view 134 | x_src = xy_src[0].reshape([height, width]).astype(np.float32) 135 | y_src = xy_src[1].reshape([height, width]).astype(np.float32) 136 | sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) 137 | # mask = sampled_depth_src > 0 138 | 139 | # source 3D space 140 | # NOTE that we should use sampled source-view depth_here to project back 141 | xyz_src = np.matmul(np.linalg.inv(intrinsics_src), 142 | np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 143 | # reference 3D space 144 | xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), 145 | np.vstack((xyz_src, np.ones_like(x_ref))))[:3] 146 | # source view x, y, depth 147 | depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) 148 | K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) 149 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 150 | x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) 151 | y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) 152 | 153 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 154 | 155 | 156 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 157 | width, height = depth_ref.shape[1], depth_ref.shape[0] 158 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 159 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, 160 | depth_src, intrinsics_src, extrinsics_src) 161 | # check |p_reproj-p_1| < 1 162 | dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 163 | 164 | # check 
|d_reproj-d_1| / d_1 < 0.01 165 | depth_diff = np.abs(depth_reprojected - depth_ref) 166 | relative_depth_diff = depth_diff / depth_ref 167 | 168 | mask = np.logical_and(dist < 1, relative_depth_diff < 0.01) 169 | depth_reprojected[~mask] = 0 170 | 171 | return mask, depth_reprojected, x2d_src, y2d_src 172 | 173 | 174 | def filter_depth(pair_folder, scan_folder, out_folder, plyfilename): 175 | # the pair file 176 | pair_file = os.path.join(pair_folder, "pair.txt") 177 | # for the final point cloud 178 | vertexs = [] 179 | vertex_colors = [] 180 | 181 | pair_data = read_pair_file(pair_file) 182 | 183 | # for each reference view and the corresponding source views 184 | for ref_view, src_views in pair_data: 185 | # src_views = src_views[:args.num_view] 186 | # load the camera parameters 187 | ref_intrinsics, ref_extrinsics = read_camera_parameters( 188 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view))) 189 | # load the reference image 190 | ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view))) 191 | # load the estimated depth of the reference view 192 | ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] 193 | # load the photometric mask of the reference view 194 | confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] 195 | photo_mask = confidence > args.conf 196 | 197 | all_srcview_depth_ests = [] 198 | all_srcview_x = [] 199 | all_srcview_y = [] 200 | all_srcview_geomask = [] 201 | 202 | # compute the geometric mask 203 | geo_mask_sum = 0 204 | for src_view in src_views: 205 | # camera parameters of the source view 206 | src_intrinsics, src_extrinsics = read_camera_parameters( 207 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) 208 | # the estimated depth of the source view 209 | src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 210 | 211 | geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, 212 | src_depth_est, 213 | src_intrinsics, src_extrinsics) 214 | geo_mask_sum += geo_mask.astype(np.int32) 215 | all_srcview_depth_ests.append(depth_reprojected) 216 | all_srcview_x.append(x2d_src) 217 | all_srcview_y.append(y2d_src) 218 | all_srcview_geomask.append(geo_mask) 219 | 220 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 221 | # at least 3 source views matched 222 | geo_mask = geo_mask_sum >= args.thres_view 223 | final_mask = np.logical_and(photo_mask, geo_mask) 224 | 225 | os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 226 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 227 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 228 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 229 | 230 | logger.info("processing {}, ref-view{:0>2}, photo/geo/final-mask:{:.3f}/{:.3f}/{:.3f}".format(scan_folder, ref_view, 231 | photo_mask.mean(), 232 | geo_mask.mean(), final_mask.mean())) 233 | 234 | height, width = depth_est_averaged.shape[:2] 235 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 236 | # valid_points = np.logical_and(final_mask, ~used_mask[ref_view]) 237 | valid_points = final_mask 238 | logger.info("valid_points: {}".format(valid_points.mean())) 239 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 240 | 
#color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset 241 | color = ref_img[valid_points] 242 | 243 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 244 | np.vstack((x, y, np.ones_like(x))) * depth) 245 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 246 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 247 | vertexs.append(xyz_world.transpose((1, 0))) 248 | vertex_colors.append((color * 255).astype(np.uint8)) 249 | 250 | vertexs = np.concatenate(vertexs, axis=0) 251 | vertex_colors = np.concatenate(vertex_colors, axis=0) 252 | vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 253 | vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 254 | 255 | vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) 256 | for prop in vertexs.dtype.names: 257 | vertex_all[prop] = vertexs[prop] 258 | for prop in vertex_colors.dtype.names: 259 | vertex_all[prop] = vertex_colors[prop] 260 | 261 | el = PlyElement.describe(vertex_all, 'vertex') 262 | PlyData([el]).write(plyfilename) 263 | logger.info("saving the final model to " + plyfilename) 264 | 265 | 266 | def init_worker(): 267 | ''' 268 | Catch Ctrl+C signal to termiante workers 269 | ''' 270 | signal.signal(signal.SIGINT, signal.SIG_IGN) 271 | 272 | 273 | def pcd_filter_worker(scan): 274 | scan_id = int(scan[4:]) 275 | save_name = 'mvsnet{:0>3}.ply'.format(scan_id) 276 | 277 | pair_folder = os.path.join(args.testpath, "Cameras") 278 | scan_folder = os.path.join(args.outdir, scan) 279 | out_folder = os.path.join(args.outdir, scan) 280 | filter_depth(pair_folder, scan_folder, out_folder, os.path.join(args.plydir, save_name)) 281 | 282 | 283 | def pcd_filter(testlist, number_worker): 284 | 285 | partial_func = partial(pcd_filter_worker) 286 | 287 | p = Pool(number_worker, init_worker) 288 | try: 289 | p.map(partial_func, testlist) 290 | except KeyboardInterrupt: 291 | logger.info("....\nCaught KeyboardInterrupt, terminating workers") 292 | p.terminate() 293 | else: 294 | p.close() 295 | p.join() 296 | 297 | 298 | def initLogger(): 299 | logger = logging.getLogger() 300 | logger.setLevel(logging.INFO) 301 | curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time())) 302 | if not os.path.isdir(args.logdir): 303 | os.mkdir(args.logdir) 304 | logfile = os.path.join(args.logdir, 'fusion-' + curTime + '.log') 305 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") 306 | if not args.nolog: 307 | fileHandler = logging.FileHandler(logfile, mode='a') 308 | fileHandler.setFormatter(formatter) 309 | logger.addHandler(fileHandler) 310 | consoleHandler = logging.StreamHandler(sys.stdout) 311 | consoleHandler.setFormatter(formatter) 312 | logger.addHandler(consoleHandler) 313 | logger.info("Logger initialized.") 314 | logger.info("Writing logs to file: {}".format(logfile)) 315 | logger.info("Current time: {}".format(curTime)) 316 | 317 | return logger 318 | 319 | 320 | if __name__ == '__main__': 321 | 322 | logger = initLogger() 323 | 324 | if not os.path.isdir(args.plydir): 325 | os.mkdir(args.plydir) 326 | 327 | with open(args.testlist) as f: 328 | content = f.readlines() 329 | testlist = [line.rstrip() for line in content] 330 | 331 | pcd_filter(testlist, args.num_worker) -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 |

GeoMVSNet: Learning Multi-View Stereo With Geometry Perception (CVPR 2023)

Zhe Zhang, Rui Peng, Yuxi Hu, Ronggang Wang*
26 | 27 | 28 | ## 🔨 Setup 29 | 30 | ### 1.1 Requirements 31 | 32 | Use the following commands to build the `conda` environment. 33 | 34 | ```bash 35 | conda create -n geomvsnet python=3.8 36 | conda activate geomvsnet 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### 1.2 Datasets 41 | 42 | Download the following datasets and modify the corresponding local path in `scripts/data_path.sh`. 43 | 44 | #### DTU Dataset 45 | 46 | **Training data**. We use the same DTU training data as mentioned in MVSNet and CasMVSNet; please refer to [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depth raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) for data download. Optionally, you can download the [Rectified raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip) data if you want to train the model at raw image resolution. Unzip and organize them as: 47 | 48 | ``` 49 | dtu/ 50 | ├── Cameras 51 | ├── Depths 52 | ├── Depths_raw 53 | ├── Rectified 54 | └── Rectified_raw (optional) 55 | ``` 56 | 57 | **Testing data**. For convenience, we use the [DTU testing data](https://drive.google.com/file/d/1rX0EXlUL4prRxrRu2DgLJv2j7-tpUD4D/view?usp=sharing) processed by CVP-MVSNet. Also unzip and organize it as: 58 | 59 | ``` 60 | dtu-test/ 61 | ├── Cameras 62 | ├── Depths 63 | └── Rectified 64 | ``` 65 | 66 | > Please note that the images and lighting here are consistent with the original dataset. 67 | 68 | #### BlendedMVS Dataset 69 | 70 | Download the low image resolution version of the [BlendedMVS dataset](https://drive.google.com/file/d/1ilxls-VJNvJnB7IaFj7P0ehMPr7ikRCb/view) and unzip it as: 71 | 72 | ``` 73 | blendedmvs/ 74 | └── dataset_low_res 75 | ├── ... 76 | └── 5c34529873a8df509ae57b58 77 | ``` 78 | 79 | #### Tanks and Temples Dataset 80 | 81 | Download the intermediate and advanced subsets of the [Tanks and Temples dataset](https://drive.google.com/file/d/1YArOJaX9WVLJh4757uE8AEREYkgszrCo/view) and unzip them. If you want to use the short range version of camera parameters for the `Intermediate` subset, unzip `short_range_caemeras_for_mvsnet.zip` and move `cam_[]` to the corresponding scenarios. 82 | 83 | ``` 84 | tnt/ 85 | ├── advanced 86 | │ ├── ... 87 | │ └── Temple 88 | │ ├── cams 89 | │ ├── images 90 | │ ├── pair.txt 91 | │ └── Temple.log 92 | └── intermediate 93 | ├── ... 94 | └── Train 95 | ├── cams 96 | ├── cams_train 97 | ├── images 98 | ├── pair.txt 99 | └── Train.log 100 | ``` 101 | 102 | 103 | ## 🚂 Training 104 | 105 | You can train GeoMVSNet from scratch on the DTU and BlendedMVS datasets. After suitable configuration and training, you will get the training checkpoints in `checkpoints/[Dataset]/[THISNAME]`, with the following outputs in that folder: 106 | - `events.out.tfevents*`: you can use `tensorboard` to monitor the training process. 107 | - `model_[epoch].ckpt`: we save a checkpoint every `--save_freq` epochs. 108 | - `train-[TIME].log`: logs the detailed training messages; you can refer to the appropriate indicators to judge the quality of training. 109 | 110 | ### 2.1 DTU 111 | 112 | To train GeoMVSNet on the DTU dataset, refer to `scripts/dtu/train_dtu.sh` and specify `THISNAME`, `CUDA_VISIBLE_DEVICES`, `batch_size`, etc. to meet your demand. Then run: 113 | 114 | ```bash 115 | bash scripts/dtu/train_dtu.sh 116 | ``` 117 | 118 | The default training strategy we provide is the *distributed* training mode.
If you want to use the *general* training mode, you can refer to the following code. 119 | 120 |
121 | general training script 122 | 123 | ```bash 124 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train.py ${@} \ 125 | --which_dataset="dtu" --epochs=16 --logdir=$LOG_DIR \ 126 | --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \ 127 | --trainlist="datasets/lists/dtu/train.txt" --testlist="datasets/lists/dtu/test.txt" \ 128 | \ 129 | --data_scale="mid" --n_views="5" --batch_size=16 --lr=0.025 --robust_train \ 130 | --lrepochs="1,3,5,7,9,11,13,15:1.5" 131 | ``` 132 | 133 |
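If a run is interrupted, training can be resumed from the newest checkpoint in the log directory: `train.py` contains a `--resume` branch that picks the latest `model_*.ckpt` under `--logdir` and restores the model, optimizer, and epoch counter. Because the launch scripts forward extra arguments through `${@}`, a resume command can be as simple as the sketch below (assuming `--resume` is the boolean flag declared in `models/utils/opts.py` and that `THISNAME`, and therefore `LOG_DIR`, are unchanged from the interrupted run):

```bash
# Keep the same THISNAME / LOG_DIR as the interrupted run, then:
bash scripts/dtu/train_dtu.sh --resume
```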
134 | 135 | > Note that the two training strategies require different `batch_size` and `lr` settings to achieve the best training results. 136 | 137 | 138 | ### 2.2 BlendedMVS 139 | 140 | To train GeoMVSNet on the BlendedMVS dataset, refer to `scripts/blend/train_blend.sh`, and also specify `THISNAME`, `CUDA_VISIBLE_DEVICES`, `batch_size`, etc. to meet your demand. Then run: 141 | 142 | ```bash 143 | bash scripts/blend/train_blend.sh 144 | ``` 145 | 146 | By default, we use `7` viewpoints as input for the BlendedMVS training. Similarly, you can choose either the *distributed* training mode or the *general* one as mentioned in 2.1. 147 | 148 | ## ⚗️ Testing 149 | 150 | ### 3.1 DTU 151 | 152 | For DTU testing, we use the model trained on the DTU training dataset. You can simply download our [DTU pretrained model](https://drive.google.com/file/d/147_UbjE87E-HB9sZ5yLDbckynH825nJd/view?usp=sharing) and put it into `checkpoints/dtu/geomvsnet/`. Then perform *depth map estimation, point cloud fusion, and result evaluation* according to the following steps. 153 | 1. Run `bash scripts/dtu/test_dtu.sh` for depth map estimation. The results will be stored in `outputs/dtu/[THISNAME]/`, with each scan folder holding `depth_est`, `confidence`, etc. 154 | - Use `outputs/visual.ipynb` for depth map visualization. 155 | 2. Run `bash scripts/dtu/fusion_dtu.sh` for point cloud fusion. We provide 3 different fusion methods and recommend the `open3d` option by default. After fusion, you will find `[FUSION_METHOD]_fusion_plys` under the experiment output folder, which holds the point clouds of each testing scan. 156 | 157 |
158 | (Optional) If you want to use the "Gipuma" fusion method. 159 | 160 | 1. Clone the [edited fusibile repo](https://github.com/YoYo000/fusibile). 161 | 2. Refer to [fusibile configuration blog (Chinese)](https://zhuanlan.zhihu.com/p/460212787) for building details. 162 | 3. Create a new python2.7 conda env. 163 | ```bash 164 | conda create -n fusibile python=2.7 165 | conda install scipy matplotlib 166 | conda install tensorflow==1.14.0 167 | conda install -c https://conda.anaconda.org/menpo opencv 168 | ``` 169 | 4. Use the `fusibile` conda environment for `gipuma` fusion method. 170 | 171 |
172 | 173 | 3. Download the [ObsMask](http://roboimagedata2.compute.dtu.dk/data/MVS/SampleSet.zip) and [Points](http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip) of DTU GT point clouds from the official website and organize them as: 174 | 175 | ``` 176 | dtu-evaluation/ 177 | ├── ObsMask 178 | └── Points 179 | ``` 180 | 181 | 4. Setup `Matlab` in command line mode, and run `bash scripts/dtu/matlab_quan_dtu.sh`. You can adjust the `num_at_once` config according to your machine's CPU and memory ceiling. After quantitative evaluation, you will get `[FUSION_METHOD]_quantitative/` and `[THISNAME].log` just store the quantitative results. 182 | 183 | ### 3.2 Tanks and Temples 184 | 185 | For testing on [Tanks and Temples benchmark](https://www.tanksandtemples.org/leaderboard/), you can use any of the following configurations: 186 | - Only train on DTU training dataset. 187 | - Only train on BlendedMVS dataset. 188 | - Pretrained on DTU training dataset and finetune on BlendedMVS dataset. (Recommend) 189 | 190 | After your personal training, also follow these steps: 191 | 1. Run `bash scripts/tnt/test_tnt.sh` for depth map estimation. The results will be stored in `outputs/[TRAINING_DATASET]/[THISNAME]/`. 192 | - Use `outputs/visual.ipynb` for depth map visualization. 193 | 2. Run `bash scripts/tnt/fusion_tnt.sh` for point cloud fusion. We provide the popular dynamic fusion strategy, and you can tune the fusion threshold in `fusions/tnt/dypcd.py`. 194 | 3. Follow the *Upload Instructions* on the [T&T official website](https://www.tanksandtemples.org/submit/) to make online submissions. 195 | 196 | ### 3.3 Custom Data (TODO) 197 | 198 | GeoMVSNet can reconstruct on custom data. At present, you can refer to [MVSNet](https://github.com/YoYo000/MVSNet#file-formats) to organize your data, and refer to the same steps as above for *depth estimation* and *point cloud fusion*. 199 | 200 | ## 💡 Results 201 | 202 | Our results on DTU and Tanks and Temples Dataset are listed in the tables. 203 | 204 | | DTU Dataset | Acc. ↓ | Comp. ↓ | Overall ↓ | 205 | | ----------- | ------ | ------- | --------- | 206 | | GeoMVSNet | 0.3309 | 0.2593 | 0.2951 | 207 | 208 | | T&T (Intermediate) | Mean ↑ | Family | Francis | Horse | Lighthouse | M60 | Panther | Playground | Train | 209 | | ------------------ | ------ | ------ | ------- | ----- | ---------- | ----- | ------- | ---------- | ----- | 210 | | GeoMVSNet | 65.89 | 81.64 | 67.53 | 55.78 | 68.02 | 65.49 | 67.19 | 63.27 | 58.22 | 211 | 212 | | T&T (Advanced) | Mean ↑ | Auditorium | Ballroom | Courtroom | Museum | Palace | Temple | 213 | | -------------- | ------ | ---------- | -------- | --------- | ------ | ------ | ------ | 214 | | GeoMVSNet | 41.52 | 30.23 | 46.53 | 39.98 | 53.05 | 35.98 | 43.34 | 215 | 216 | And you can download our [Point Cloud](https://disk.pku.edu.cn:443/link/69D473126C509C8DCBCC7E233FAAEEAA) and [Estimated Depth](https://disk.pku.edu.cn:443/link/4217EB2F063D2B10EDC711F54A12B5F7) for academic usage. 217 | 218 |
219 | 🌟 About Reproducing Paper Results 220 | 221 | 222 | In our experiments, we found that reproducing MVS networks is relatively difficult. Therefore, we summarize some of the problems encountered in our experiments as follows, hoping they will be helpful to you. 223 | 224 | **Q1. GPU Architecture Matters.** 225 | 226 | There are two commonly used NVIDIA GPU series: GeForce RTX (e.g. 4090, 3090 Ti, 2080 Ti) and Tesla (e.g. V100, T4). We find that there is generally no performance degradation when training and testing on the same series of GPUs. In contrast, if you train on a V100 and test on a 3090 Ti, for example, the depth maps look visually identical, but the individual pixel values are not exactly the same. We conjecture that the two series or architectures differ in numerical computation and processing precision. 227 | 228 | > Our pretrained model is trained on NVIDIA V100 GPUs. 229 | 230 | **Q2. PyTorch Version Matters.** 231 | 232 | Different CUDA versions constrain which PyTorch versions you can install, and different PyTorch versions affect the accuracy of network training and testing. One of the reasons we found is that the implementation and default parameters of `F.grid_sample()` vary across PyTorch versions; see the `align_corners` example after this Q&A list. 233 | 234 | **Q3. Training Hyperparameters Matter.** 235 | 236 | In the era of neural networks, hyperparameters really matter. We did some hyperparameter tuning, but our settings may not match your configuration. Most fundamentally, due to differences in GPU memory, you need to adjust `batch_size` and `lr` together. The learning rate schedule also matters. 237 | 238 | **Q4. Testing Epoch Matters.** 239 | 240 | By default, our model trains for 16 epochs. But how do you select the best checkpoint for testing to achieve the best performance? One solution is to use [PyTorch-lightning](https://lightning.ai/docs/pytorch/latest/starter/introduction.html). For simplicity, you can decide which checkpoint to use based on the `.log` file we provide. 241 | 242 | **Q5. Fusion Hyperparameters Matter.** 243 | 244 | For both the DTU and T&T datasets, the hyperparameters of point cloud fusion greatly affect the final performance. We provide different fusion strategies and easy access to the parameters; you may need to get a feel for the temperament of your own model. 245 | 246 | Qx. For anything else, you can [raise an issue](https://github.com/doubleZ0108/GeoMVSNet/issues/new/choose) if you meet other problems. 247 | 248 |
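Regarding Q2: a concrete instance of this version dependence is the `align_corners` flag of `F.grid_sample()`, which the homography warping in `fusions/dtu/_open3d.py` (and the network's warping code) relies on. The snippet below is only an illustration and not part of the released code: PyTorch 1.3 and later expose `align_corners` explicitly with a default of `False`, whereas older releases always behaved as if it were `True`, so pinning both the flag and the PyTorch version from `requirements.txt` keeps warping results comparable across environments.

```python
import torch
import torch.nn.functional as F

# Toy inputs standing in for a feature map and a normalized sampling grid.
feat = torch.rand(1, 8, 64, 80)            # [B, C, H, W]
grid = torch.rand(1, 64, 80, 2) * 2 - 1    # values in [-1, 1], as grid_sample expects

# Passing align_corners explicitly avoids silently inheriting a version-dependent default.
warped = F.grid_sample(feat, grid, mode='bilinear',
                       padding_mode='zeros', align_corners=True)
print(warped.shape)                        # torch.Size([1, 8, 64, 80])
```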
249 | 250 |
251 | 252 | ## ⚖️ Citation 253 | ``` 254 | @InProceedings{zhe2023geomvsnet, 255 | title={GeoMVSNet: Learning Multi-View Stereo With Geometry Perception}, 256 | author={Zhang, Zhe and Peng, Rui and Hu, Yuxi and Wang, Ronggang}, 257 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 258 | pages={21508--21518}, 259 | year={2023} 260 | } 261 | ``` 262 | 263 | ## 💌 Acknowledgements 264 | 265 | This repository is partly based on [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [CVP-MVSNet](https://github.com/JiayuYANG/CVP-MVSNet), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [MVSTER](https://github.com/JeffWang987/MVSTER). 266 | 267 | We appreciate their contributions to the MVS community. -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Description: Main process of network training & evaluation. 3 | # @Author: Zhe Zhang (doublez@stu.pku.edu.cn) 4 | # @Affiliation: Peking University (PKU) 5 | # @LastEditDate: 2023-09-07 6 | 7 | import os, sys, time, gc, datetime, logging, json 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim as optim 13 | import torch.distributed as dist 14 | from torch.utils.data import DataLoader 15 | from tensorboardX import SummaryWriter 16 | 17 | from datasets.dtu import DTUDataset 18 | from datasets.blendedmvs import BlendedMVSDataset 19 | 20 | from models.geomvsnet import GeoMVSNet 21 | from models.loss import geomvsnet_loss 22 | from models.utils import * 23 | from models.utils.opts import get_opts 24 | 25 | 26 | cudnn.benchmark = True 27 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 28 | is_distributed = num_gpus > 1 29 | 30 | args = get_opts() 31 | 32 | 33 | def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args): 34 | if args.lr_scheduler == 'MS': 35 | milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')] 36 | lr_gamma = 1 / float(args.lrepochs.split(':')[1]) 37 | lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500, last_epoch=len(TrainImgLoader) * start_epoch - 1) 38 | elif args.lr_scheduler == 'cos': 39 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(args.epochs*len(TrainImgLoader)), eta_min=0) 40 | elif args.lr_scheduler == 'onecycle': 41 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=int(args.epochs*len(TrainImgLoader))) 42 | elif args.lr_scheduler == 'lambda': 43 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** ((epoch-1) / len(TrainImgLoader)), last_epoch=len(TrainImgLoader)*start_epoch-1) 44 | 45 | 46 | for epoch_idx in range(start_epoch, args.epochs): 47 | logger.info('Epoch {}:'.format(epoch_idx)) 48 | global_step = len(TrainImgLoader) * epoch_idx 49 | 50 | # training 51 | for batch_idx, sample in enumerate(TrainImgLoader): 52 | start_time = time.time() 53 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 54 | do_summary = global_step % args.summary_freq == 0 55 | loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args) 56 | lr_scheduler.step() 57 | if 
(not is_distributed) or (dist.get_rank() == 0): 58 | if do_summary: 59 | if not args.notensorboard: 60 | tb_save_scalars(tb_writer, 'train', scalar_outputs, global_step) 61 | tb_save_images(tb_writer, 'train', image_outputs, global_step) 62 | logger.info("Epoch {}/{}, Iter {}/{}, 2mm_err={:.3f} | lr={:.6f}, train_loss={:.3f}, abs_err={:.3f}, pw_loss={:.3f}, dds_loss={:.3f}, time={:.3f}".format( 63 | epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), 64 | scalar_outputs["thres2mm_error"], 65 | optimizer.param_groups[0]["lr"], 66 | loss, 67 | scalar_outputs["abs_depth_error"], 68 | scalar_outputs["s3_pw_loss"], 69 | scalar_outputs["s3_dds_loss"], 70 | time.time() - start_time)) 71 | del scalar_outputs, image_outputs 72 | 73 | # save checkpoint 74 | if (not is_distributed) or (dist.get_rank() == 0): 75 | if ((epoch_idx + 1) % args.save_freq == 0) or (epoch_idx == args.epochs-1): 76 | torch.save({ 77 | 'epoch': epoch_idx, 78 | 'model': model.module.state_dict(), 79 | 'optimizer': optimizer.state_dict()}, 80 | "{}/model_{:0>2}.ckpt".format(args.logdir, epoch_idx)) 81 | gc.collect() 82 | 83 | # testing 84 | if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1): 85 | avg_test_scalars = DictAverageMeter() 86 | for batch_idx, sample in enumerate(TestImgLoader): 87 | start_time = time.time() 88 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 89 | do_summary = global_step % args.summary_freq == 0 90 | loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) 91 | if (not is_distributed) or (dist.get_rank() == 0): 92 | if do_summary: 93 | if not args.notensorboard: 94 | tb_save_scalars(tb_writer, 'test', scalar_outputs, global_step) 95 | tb_save_images(tb_writer, 'test', image_outputs, global_step) 96 | logger.info( 97 | "Epoch {}/{}, Iter {}/{}, 2mm_err={:.3f} | lr={:.6f}, test_loss={:.3f}, abs_err={:.3f}, pw_loss={:.3f}, dds_loss={:.3f}, time={:.3f}".format( 98 | epoch_idx, args.epochs, batch_idx, len(TestImgLoader), 99 | scalar_outputs["thres2mm_error"], 100 | optimizer.param_groups[0]["lr"], 101 | loss, 102 | scalar_outputs["abs_depth_error"], 103 | scalar_outputs["s3_pw_loss"], 104 | scalar_outputs["s3_dds_loss"], 105 | time.time() - start_time)) 106 | avg_test_scalars.update(scalar_outputs) 107 | del scalar_outputs, image_outputs 108 | 109 | if (not is_distributed) or (dist.get_rank() == 0): 110 | if not args.notensorboard: 111 | tb_save_scalars(tb_writer, 'fulltest', avg_test_scalars.mean(), global_step) 112 | logger.info("avg_test_scalars: " + json.dumps(avg_test_scalars.mean())) 113 | gc.collect() 114 | 115 | 116 | def train_sample(model, model_loss, optimizer, sample, args): 117 | model.train() 118 | optimizer.zero_grad() 119 | 120 | sample_cuda = tocuda(sample) 121 | depth_gt_ms, mask_ms = sample_cuda["depth"], sample_cuda["mask"] 122 | depth_gt, mask = depth_gt_ms["stage{}".format(args.levels)], mask_ms["stage{}".format(args.levels)] 123 | 124 | # @Note GeoMVSNet main 125 | outputs = model( 126 | sample_cuda["imgs"], 127 | sample_cuda["proj_matrices"], sample_cuda["intrinsics_matrices"], 128 | sample_cuda["depth_values"] 129 | ) 130 | 131 | depth_est = outputs["depth"] 132 | 133 | loss, epe, pw_loss_stages, dds_loss_stages = model_loss( 134 | outputs, depth_gt_ms, mask_ms, 135 | stage_lw=[float(e) for e in args.stage_lw.split(",") if e], depth_values=sample_cuda["depth_values"] 136 | ) 137 | 138 | loss.backward() 139 | optimizer.step() 140 | 141 | scalar_outputs = { 142 | "loss": loss, 143 | "epe": epe, 144 | "s0_pw_loss": 
pw_loss_stages[0], 145 | "s1_pw_loss": pw_loss_stages[1], 146 | "s2_pw_loss": pw_loss_stages[2], 147 | "s3_pw_loss": pw_loss_stages[3], 148 | "s0_dds_loss": dds_loss_stages[0], 149 | "s1_dds_loss": dds_loss_stages[1], 150 | "s2_dds_loss": dds_loss_stages[2], 151 | "s3_dds_loss": dds_loss_stages[3], 152 | "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), 153 | "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), 154 | "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), 155 | "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8), 156 | } 157 | 158 | image_outputs = { 159 | "depth_est": depth_est * mask, 160 | "depth_est_nomask": depth_est, 161 | "depth_gt": sample["depth"]["stage1"], 162 | "ref_img": sample["imgs"][0], 163 | "mask": sample["mask"]["stage1"], 164 | "errormap": (depth_est - depth_gt).abs() * mask, 165 | } 166 | 167 | if is_distributed: 168 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 169 | 170 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 171 | 172 | 173 | @make_nograd_func 174 | def test_sample_depth(model, model_loss, sample, args): 175 | if is_distributed: 176 | model_eval = model.module 177 | else: 178 | model_eval = model 179 | model_eval.eval() 180 | 181 | sample_cuda = tocuda(sample) 182 | depth_gt_ms, mask_ms = sample_cuda["depth"], sample_cuda["mask"] 183 | depth_gt, mask = depth_gt_ms["stage{}".format(args.levels)], mask_ms["stage{}".format(args.levels)] 184 | 185 | outputs = model_eval( 186 | sample_cuda["imgs"], 187 | sample_cuda["proj_matrices"], sample_cuda["intrinsics_matrices"], 188 | sample_cuda["depth_values"] 189 | ) 190 | 191 | depth_est = outputs["depth"] 192 | 193 | loss, epe, pw_loss_stages, dds_loss_stages = model_loss( 194 | outputs, depth_gt_ms, mask_ms, 195 | stage_lw=[float(e) for e in args.stage_lw.split(",") if e], depth_values=sample_cuda["depth_values"] 196 | ) 197 | 198 | scalar_outputs = { 199 | "loss": loss, 200 | "epe": epe, 201 | "s0_pw_loss": pw_loss_stages[0], 202 | "s1_pw_loss": pw_loss_stages[1], 203 | "s2_pw_loss": pw_loss_stages[2], 204 | "s3_pw_loss": pw_loss_stages[3], 205 | "s0_dds_loss": dds_loss_stages[0], 206 | "s1_dds_loss": dds_loss_stages[1], 207 | "s2_dds_loss": dds_loss_stages[2], 208 | "s3_dds_loss": dds_loss_stages[3], 209 | "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), 210 | "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), 211 | "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), 212 | "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8), 213 | } 214 | 215 | image_outputs = { 216 | "depth_est": depth_est * mask, 217 | "depth_est_nomask": depth_est, 218 | "depth_gt": sample["depth"]["stage1"], 219 | "ref_img": sample["imgs"][0], 220 | "mask": sample["mask"]["stage1"], 221 | "errormap": (depth_est - depth_gt).abs() * mask 222 | } 223 | 224 | if is_distributed: 225 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 226 | 227 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 228 | 229 | 230 | def initLogger(): 231 | logger = logging.getLogger() 232 | logger.setLevel(logging.INFO) 233 | curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time())) 234 | logfile = os.path.join(args.logdir, 'train-' + curTime + '.log') 235 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") 236 | 
fileHandler = logging.FileHandler(logfile, mode='a') 237 | fileHandler.setFormatter(formatter) 238 | logger.addHandler(fileHandler) 239 | consoleHandler = logging.StreamHandler(sys.stdout) 240 | consoleHandler.setFormatter(formatter) 241 | logger.addHandler(consoleHandler) 242 | logger.info("Logger initialized.") 243 | logger.info("Writing logs to file: {}".format(logfile)) 244 | logger.info("Current time: {}".format(curTime)) 245 | 246 | settings_str = "All settings:\n" 247 | for k,v in vars(args).items(): 248 | settings_str += '{0}: {1}\n'.format(k,v) 249 | logger.info(settings_str) 250 | 251 | return logger 252 | 253 | 254 | if __name__ == '__main__': 255 | logger = initLogger() 256 | 257 | if args.resume: 258 | assert args.mode == "train" 259 | assert args.loadckpt is None 260 | 261 | if is_distributed: 262 | torch.cuda.set_device(args.local_rank) 263 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 264 | synchronize() 265 | 266 | set_random_seed(args.seed) 267 | device = torch.device(args.device) 268 | 269 | 270 | # tensorboard 271 | if (not is_distributed) or (dist.get_rank() == 0): 272 | if not os.path.isdir(args.logdir): 273 | os.makedirs(args.logdir) 274 | current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) 275 | logger.info("current time " + current_time_str) 276 | logger.info("creating new summary file") 277 | if not args.notensorboard: 278 | tb_writer = SummaryWriter(args.logdir) 279 | 280 | 281 | # @Note GeoMVSNet model 282 | model = GeoMVSNet( 283 | levels=args.levels, 284 | hypo_plane_num_stages=[int(n) for n in args.hypo_plane_num_stages.split(",")], 285 | depth_interal_ratio_stages=[float(ir) for ir in args.depth_interal_ratio_stages.split(",")], 286 | feat_base_channel=args.feat_base_channel, 287 | reg_base_channel=args.reg_base_channel, 288 | group_cor_dim_stages=[int(n) for n in args.group_cor_dim_stages.split(",")], 289 | ) 290 | model.to(device) 291 | 292 | model_loss = geomvsnet_loss 293 | 294 | # optimizer 295 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd) 296 | 297 | 298 | # load parameters 299 | start_epoch = 0 300 | if args.resume: 301 | saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")] 302 | saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) 303 | loadckpt = os.path.join(args.logdir, saved_models[-1]) 304 | logger.info("resuming: " + loadckpt) 305 | state_dict = torch.load(loadckpt, map_location=torch.device("cpu")) 306 | model.load_state_dict(state_dict['model']) 307 | optimizer.load_state_dict(state_dict['optimizer']) 308 | start_epoch = state_dict['epoch'] + 1 309 | 310 | 311 | # distributed 312 | if (not is_distributed) or (dist.get_rank() == 0): 313 | logger.info("start at epoch {}".format(start_epoch)) 314 | logger.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 315 | 316 | if is_distributed: 317 | if dist.get_rank() == 0: 318 | logger.info("Let's use {} GPUs in distributed mode!".format(torch.cuda.device_count())) 319 | model = torch.nn.parallel.DistributedDataParallel( 320 | model, device_ids=[args.local_rank], output_device=args.local_rank, 321 | find_unused_parameters=True, 322 | ) 323 | else: 324 | if torch.cuda.is_available(): 325 | logger.info("Let's use {} GPUs in parallel mode.".format(torch.cuda.device_count())) 326 | model = nn.DataParallel(model) 327 | 328 | 329 | # dataset, dataloader 330 | if 
args.which_dataset == "dtu": 331 | train_dataset = DTUDataset(args.trainpath, args.trainlist, "train", args.n_views, data_scale=args.data_scale, robust_train=args.robust_train) 332 | test_dataset = DTUDataset(args.testpath, args.testlist, "val", args.n_views, data_scale=args.data_scale) 333 | elif args.which_dataset == "blendedmvs": 334 | train_dataset = BlendedMVSDataset(args.trainpath, args.trainlist, "train", args.n_views, img_wh=(768, 576), robust_train=args.robust_train, augment=False) 335 | test_dataset = BlendedMVSDataset(args.testpath, args.testlist, "val", args.n_views, img_wh=(768, 576)) 336 | 337 | if is_distributed: 338 | train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) 339 | test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) 340 | 341 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=8, drop_last=True, pin_memory=args.pin_m) 342 | TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=8, drop_last=False, pin_memory=args.pin_m) 343 | else: 344 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=args.pin_m) 345 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=8, drop_last=False, pin_memory=args.pin_m) 346 | 347 | train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args) -------------------------------------------------------------------------------- /datasets/evaluations/dtu_parallel/plyread.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | function [Elements,varargout] = plyread(Path,Str) 3 | %PLYREAD Read a PLY 3D data file. 4 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 5 | % FILENAME and returns a structure DATA. The fields in this structure 6 | % are defined by the PLY header; each element type is a field and each 7 | % element property is a subfield. If the file contains any comments, 8 | % they are returned in a cell string array COMMENTS. 9 | % 10 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 11 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 12 | % and face data into triangular connectivity and vertex arrays. The 13 | % mesh can then be displayed using the TRISURF command. 14 | % 15 | % Note: This function is slow for large mesh files (+50K faces), 16 | % especially when reading data with list type properties. 
17 | % 18 | % Example: 19 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 20 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 21 | % colormap(gray); axis equal; 22 | % 23 | % See also: PLYWRITE 24 | 25 | % Pascal Getreuer 2004 26 | 27 | [fid,Msg] = fopen(Path,'rt'); % open file in read text mode 28 | 29 | if fid == -1, error(Msg); end 30 | 31 | Buf = fscanf(fid,'%s',1); 32 | if ~strcmp(Buf,'ply') 33 | fclose(fid); 34 | error('Not a PLY file.'); 35 | end 36 | 37 | 38 | %%% read header %%% 39 | 40 | Position = ftell(fid); 41 | Format = ''; 42 | NumComments = 0; 43 | Comments = {}; % for storing any file comments 44 | NumElements = 0; 45 | NumProperties = 0; 46 | Elements = []; % structure for holding the element data 47 | ElementCount = []; % number of each type of element in file 48 | PropertyTypes = []; % corresponding structure recording property types 49 | ElementNames = {}; % list of element names in the order they are stored in the file 50 | PropertyNames = []; % structure of lists of property names 51 | 52 | while 1 53 | Buf = fgetl(fid); % read one line from file 54 | BufRem = Buf; 55 | Token = {}; 56 | Count = 0; 57 | 58 | while ~isempty(BufRem) % split line into tokens 59 | [tmp,BufRem] = strtok(BufRem); 60 | 61 | if ~isempty(tmp) 62 | Count = Count + 1; % count tokens 63 | Token{Count} = tmp; 64 | end 65 | end 66 | 67 | if Count % parse line 68 | switch lower(Token{1}) 69 | case 'format' % read data format 70 | if Count >= 2 71 | Format = lower(Token{2}); 72 | 73 | if Count == 3 & ~strcmp(Token{3},'1.0') 74 | fclose(fid); 75 | error('Only PLY format version 1.0 supported.'); 76 | end 77 | end 78 | case 'comment' % read file comment 79 | NumComments = NumComments + 1; 80 | Comments{NumComments} = ''; 81 | for i = 2:Count 82 | Comments{NumComments} = [Comments{NumComments},Token{i},' ']; 83 | end 84 | case 'element' % element name 85 | if Count >= 3 86 | if isfield(Elements,Token{2}) 87 | fclose(fid); 88 | error(['Duplicate element name, ''',Token{2},'''.']); 89 | end 90 | 91 | NumElements = NumElements + 1; 92 | NumProperties = 0; 93 | Elements = setfield(Elements,Token{2},[]); 94 | PropertyTypes = setfield(PropertyTypes,Token{2},[]); 95 | ElementNames{NumElements} = Token{2}; 96 | PropertyNames = setfield(PropertyNames,Token{2},{}); 97 | CurElement = Token{2}; 98 | ElementCount(NumElements) = str2double(Token{3}); 99 | 100 | if isnan(ElementCount(NumElements)) 101 | fclose(fid); 102 | error(['Bad element definition: ',Buf]); 103 | end 104 | else 105 | error(['Bad element definition: ',Buf]); 106 | end 107 | case 'property' % element property 108 | if ~isempty(CurElement) & Count >= 3 109 | NumProperties = NumProperties + 1; 110 | eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 111 | 'fclose(fid);error([''Error reading property: '',Buf])'); 112 | 113 | if tmp 114 | error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); 115 | end 116 | 117 | % add property subfield to Elements 118 | eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 119 | 'fclose(fid);error([''Error reading property: '',Buf])'); 120 | % add property subfield to PropertyTypes and save type 121 | eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 122 | 'fclose(fid);error([''Error reading property: '',Buf])'); 123 | % record property name order 124 | eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 
125 | 'fclose(fid);error([''Error reading property: '',Buf])'); 126 | else 127 | fclose(fid); 128 | 129 | if isempty(CurElement) 130 | error(['Property definition without element definition: ',Buf]); 131 | else 132 | error(['Bad property definition: ',Buf]); 133 | end 134 | end 135 | case 'end_header' % end of header, break from while loop 136 | break; 137 | end 138 | end 139 | end 140 | 141 | %%% set reading for specified data format %%% 142 | 143 | if isempty(Format) 144 | warning('Data format unspecified, assuming ASCII.'); 145 | Format = 'ascii'; 146 | end 147 | 148 | switch Format 149 | case 'ascii' 150 | Format = 0; 151 | case 'binary_little_endian' 152 | Format = 1; 153 | case 'binary_big_endian' 154 | Format = 2; 155 | otherwise 156 | fclose(fid); 157 | error(['Data format ''',Format,''' not supported.']); 158 | end 159 | 160 | if ~Format 161 | Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data 162 | BufOff = 1; 163 | else 164 | % reopen the file in read binary mode 165 | fclose(fid); 166 | 167 | if Format == 1 168 | fid = fopen(Path,'r','ieee-le.l64'); % little endian 169 | else 170 | fid = fopen(Path,'r','ieee-be.l64'); % big endian 171 | end 172 | 173 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 174 | BufSize = 8192; 175 | Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; 176 | i = []; 177 | tmp = -11; 178 | 179 | while isempty(i) 180 | i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF 181 | i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF 182 | 183 | if isempty(i) 184 | tmp = tmp + BufSize; 185 | Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; 186 | end 187 | end 188 | 189 | % seek to just after the line feed 190 | fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); 191 | end 192 | 193 | 194 | %%% read element data %%% 195 | 196 | % PLY and MATLAB data types (for fread) 197 | PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 
198 | 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; 199 | MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; 200 | SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type 201 | 202 | for i = 1:NumElements 203 | % get current element property information 204 | eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); 205 | eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); 206 | NumProperties = size(CurPropertyNames,2); 207 | 208 | % fprintf('Reading %s...\n',ElementNames{i}); 209 | 210 | if ~Format %%% read ASCII data %%% 211 | for j = 1:NumProperties 212 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 213 | 214 | if strcmpi(Token{1},'list') 215 | Type(j) = 1; 216 | else 217 | Type(j) = 0; 218 | end 219 | end 220 | 221 | % parse buffer 222 | if ~any(Type) 223 | % no list types 224 | Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; 225 | BufOff = BufOff + ElementCount(i)*NumProperties; 226 | else 227 | ListData = cell(NumProperties,1); 228 | 229 | for k = 1:NumProperties 230 | ListData{k} = cell(ElementCount(i),1); 231 | end 232 | 233 | % list type 234 | for j = 1:ElementCount(i) 235 | for k = 1:NumProperties 236 | if ~Type(k) 237 | Data(j,k) = Buf(BufOff); 238 | BufOff = BufOff + 1; 239 | else 240 | tmp = Buf(BufOff); 241 | ListData{k}{j} = Buf(BufOff+(1:tmp))'; 242 | BufOff = BufOff + tmp + 1; 243 | end 244 | end 245 | end 246 | end 247 | else %%% read binary data %%% 248 | % translate PLY data type names to MATLAB data type names 249 | ListFlag = 0; % = 1 if there is a list type 250 | SameFlag = 1; % = 1 if all types are the same 251 | 252 | for j = 1:NumProperties 253 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 254 | 255 | if ~strcmp(Token{1},'list') % non-list type 256 | tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; 257 | 258 | if ~isempty(tmp) 259 | TypeSize(j) = SizeOf(tmp); 260 | Type{j} = MatlabTypeNames{tmp}; 261 | TypeSize2(j) = 0; 262 | Type2{j} = ''; 263 | 264 | SameFlag = SameFlag & strcmp(Type{1},Type{j}); 265 | else 266 | fclose(fid); 267 | error(['Unknown property data type, ''',Token{1},''', in ', ... 268 | ElementNames{i},'.',CurPropertyNames{j},'.']); 269 | end 270 | else % list type 271 | if length(Token) == 3 272 | ListFlag = 1; 273 | SameFlag = 0; 274 | tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; 275 | tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; 276 | 277 | if ~isempty(tmp) & ~isempty(tmp2) 278 | TypeSize(j) = SizeOf(tmp); 279 | Type{j} = MatlabTypeNames{tmp}; 280 | TypeSize2(j) = SizeOf(tmp2); 281 | Type2{j} = MatlabTypeNames{tmp2}; 282 | else 283 | fclose(fid); 284 | error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... 
285 | ElementNames{i},'.',CurPropertyNames{j},'.']); 286 | end 287 | else 288 | fclose(fid); 289 | error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); 290 | end 291 | end 292 | end 293 | 294 | % read file 295 | if ~ListFlag 296 | if SameFlag 297 | % no list types, all the same type (fast) 298 | Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; 299 | else 300 | % no list types, mixed type 301 | Data = zeros(ElementCount(i),NumProperties); 302 | 303 | for j = 1:ElementCount(i) 304 | for k = 1:NumProperties 305 | Data(j,k) = fread(fid,1,Type{k}); 306 | end 307 | end 308 | end 309 | else 310 | ListData = cell(NumProperties,1); 311 | 312 | for k = 1:NumProperties 313 | ListData{k} = cell(ElementCount(i),1); 314 | end 315 | 316 | if NumProperties == 1 317 | BufSize = 512; 318 | SkipNum = 4; 319 | j = 0; 320 | 321 | % list type, one property (fast if lists are usually the same length) 322 | while j < ElementCount(i) 323 | Position = ftell(fid); 324 | % read in BufSize count values, assuming all counts = SkipNum 325 | [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); 326 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 327 | fseek(fid,Position + TypeSize(1),-1); % seek back to after first count 328 | 329 | if isempty(Miss) % all counts are SkipNum 330 | Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 331 | fseek(fid,-TypeSize(1),0); % undo last skip 332 | 333 | for k = 1:BufSize 334 | ListData{1}{j+k} = Buf(k,:); 335 | end 336 | 337 | j = j + BufSize; 338 | BufSize = floor(1.5*BufSize); 339 | else 340 | if Miss(1) > 1 % some counts are SkipNum 341 | Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 342 | 343 | for k = 1:Miss(1)-1 344 | ListData{1}{j+k} = Buf2(k,:); 345 | end 346 | 347 | j = j + k; 348 | end 349 | 350 | % read in the list with the missed count 351 | SkipNum = Buf(Miss(1)); 352 | j = j + 1; 353 | ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); 354 | BufSize = ceil(0.6*BufSize); 355 | end 356 | end 357 | else 358 | % list type(s), multiple properties (slow) 359 | Data = zeros(ElementCount(i),NumProperties); 360 | 361 | for j = 1:ElementCount(i) 362 | for k = 1:NumProperties 363 | if isempty(Type2{k}) 364 | Data(j,k) = fread(fid,1,Type{k}); 365 | else 366 | tmp = fread(fid,1,Type{k}); 367 | ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); 368 | end 369 | end 370 | end 371 | end 372 | end 373 | end 374 | 375 | % put data into Elements structure 376 | for k = 1:NumProperties 377 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 378 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); 379 | else 380 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); 381 | end 382 | end 383 | end 384 | 385 | clear Data ListData; 386 | fclose(fid); 387 | 388 | if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 389 | % find vertex element field 390 | Name = {'vertex','Vertex','point','Point','pts','Pts'}; 391 | Names = []; 392 | 393 | for i = 1:length(Name) 394 | if any(strcmp(ElementNames,Name{i})) 395 | Names = getfield(PropertyNames,Name{i}); 396 | Name = Name{i}; 397 | break; 398 | end 399 | end 400 | 401 | if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) 402 | eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); 403 | else 404 | varargout{1} = zeros(1,3); 405 | end 406 | 407 | varargout{2} = Elements; 408 | varargout{3} = Comments; 409 | Elements 
= []; 410 | 411 | % find face element field 412 | Name = {'face','Face','poly','Poly','tri','Tri'}; 413 | Names = []; 414 | 415 | for i = 1:length(Name) 416 | if any(strcmp(ElementNames,Name{i})) 417 | Names = getfield(PropertyNames,Name{i}); 418 | Name = Name{i}; 419 | break; 420 | end 421 | end 422 | 423 | if ~isempty(Names) 424 | % find vertex indices property subfield 425 | PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; 426 | 427 | for i = 1:length(PropertyName) 428 | if any(strcmp(Names,PropertyName{i})) 429 | PropertyName = PropertyName{i}; 430 | break; 431 | end 432 | end 433 | 434 | if ~iscell(PropertyName) 435 | % convert face index lists to triangular connectivity 436 | eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); 437 | N = length(FaceIndices); 438 | Elements = zeros(N*2,3); 439 | Extra = 0; 440 | 441 | for k = 1:N 442 | Elements(k,:) = FaceIndices{k}(1:3); 443 | 444 | for j = 4:length(FaceIndices{k}) 445 | Extra = Extra + 1; 446 | Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; 447 | end 448 | end 449 | Elements = Elements(1:N+Extra,:) + 1; 450 | end 451 | end 452 | else 453 | varargout{1} = Comments; 454 | end --------------------------------------------------------------------------------
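Appendix note (not part of the repository files above): a minimal usage sketch for the plyread function defined in /datasets/evaluations/dtu_parallel/plyread.m, assuming a point-cloud PLY whose vertex element carries x/y/z properties. The file name fused.ply is only a placeholder, not an output produced by this repository.

% Hedged sketch: load a PLY point cloud with plyread and collect its vertex coordinates.
% 'fused.ply' is a hypothetical file name used purely for illustration.
[data, comments] = plyread('fused.ply');              % data.<element>.<property> fields mirror the PLY header
pts = [data.vertex.x, data.vertex.y, data.vertex.z];  % N x 3 array of vertex positions
fprintf('Loaded %d points\n', size(pts, 1));

For mesh files, the 'tri' calling form documented in the function header returns triangular connectivity and a vertex array directly, suitable for trisurf.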